diff --git a/.env.example b/.env.example index b7d88d41..993b6cd8 100644 --- a/.env.example +++ b/.env.example @@ -2,6 +2,7 @@ FLASK_APP=lnbits FLASK_ENV=development LNBITS_SITE_TITLE=LNbits +LNBITS_ALLOWED_USERS="all" LNBITS_DEFAULT_WALLET_NAME="LNbits wallet" LNBITS_DATA_FOLDER="/your_custom_data_folder" LNBITS_DISABLED_EXTENSIONS="amilk,events" diff --git a/lnbits/core/views/generic.py b/lnbits/core/views/generic.py index 4e1a9571..f712cb48 100644 --- a/lnbits/core/views/generic.py +++ b/lnbits/core/views/generic.py @@ -1,6 +1,6 @@ from flask import g, abort, redirect, request, render_template, send_from_directory, url_for from http import HTTPStatus -from os import path +from os import getenv, path from lnbits.core import core_app from lnbits.decorators import check_user_exists, validate_uuids @@ -61,6 +61,10 @@ def wallet(): user = get_user(create_account().id) else: user = get_user(user_id) or abort(HTTPStatus.NOT_FOUND, "User does not exist.") + allowed_users = getenv("LNBITS_ALLOWED_USERS", "all") + + if allowed_users != "all" and user_id not in allowed_users.split(","): + abort(HTTPStatus.UNAUTHORIZED, f"User not authorized.") if not wallet_id: if user.wallets and not wallet_name: diff --git a/lnbits/decorators.py b/lnbits/decorators.py index b8f61e15..ef1ef66d 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -2,6 +2,7 @@ from cerberus import Validator # type: ignore from flask import g, abort, jsonify, request from functools import wraps from http import HTTPStatus +from os import getenv from typing import List, Union from uuid import UUID @@ -51,7 +52,12 @@ def check_user_exists(param: str = "usr"): def wrap(view): @wraps(view) def wrapped_view(**kwargs): - g.user = get_user(request.args.get(param, type=str)) or abort(HTTPStatus.NOT_FOUND, "User not found.") + g.user = get_user(request.args.get(param, type=str)) or abort(HTTPStatus.NOT_FOUND, "User does not exist.") + allowed_users = getenv("LNBITS_ALLOWED_USERS", "all") + + if allowed_users != "all" and g.user.id not in allowed_users.split(","): + abort(HTTPStatus.UNAUTHORIZED, f"User not authorized.") + return view(**kwargs) return wrapped_view