refactor: clean up __init__ file following some Flask conventions

Flask extensions are loaded in a way that makes them easily reusable by blueprints.
In this commit we are also adding `environs` to manage .env and settings:
breaking changes!

- FLASK_APP=lnbits.app
- LNBITS_ALLOWED_USERS needs to be empty now to allow all users (NOT "all")
This commit is contained in:
Eneko Illarramendi 2020-09-05 08:00:44 +02:00 committed by fiatjaf
parent ffa3c3f6a6
commit 1bc5e144d3
13 changed files with 226 additions and 138 deletions

View File

@ -10,12 +10,13 @@ python_version = "3.7"
bitstring = "*" bitstring = "*"
cerberus = "*" cerberus = "*"
ecdsa = "*" ecdsa = "*"
lnurl = "*" environs = "*"
flask = "*" flask = "*"
flask-assets = "*" flask-assets = "*"
flask-compress = "*" flask-compress = "*"
flask-cors = "*" flask-cors = "*"
flask-talisman = "*" flask-talisman = "*"
lnurl = "*"
pyscss = "*" pyscss = "*"
requests = "*" requests = "*"
shortuuid = "*" shortuuid = "*"

53
Pipfile.lock generated
View File

@ -1,11 +1,11 @@
{ {
"_meta": { "_meta": {
"hash": { "hash": {
"sha256": "d21f745fb8f799aaca868b4c97000f31e455063a8241366b1e0b0cd381489a0e" "sha256": "2270f2525e54e976b09491e458033d25ec5bbdea9e74d417e787df33031c6948"
}, },
"pipfile-spec": 6, "pipfile-spec": 6,
"requires": { "requires": {
"python_version": "3.8" "python_version": "3.7"
}, },
"sources": [ "sources": [
{ {
@ -100,6 +100,14 @@
"index": "pypi", "index": "pypi",
"version": "==0.16.0" "version": "==0.16.0"
}, },
"environs": {
"hashes": [
"sha256:a98005aab7613b6fe7a1af7192a5163f72a52d3348d3918e6c7a2a32e4012779",
"sha256:bf3fd6bc54fcfd7f512ddcb80a7781f0ced2b0c83dd123d619e9468ecdaaf537"
],
"index": "pypi",
"version": "==8.0.0"
},
"flask": { "flask": {
"hashes": [ "hashes": [
"sha256:4efa1ae2d7c9865af48986de8aeb8504bf32c7f3d6fdc9353d34b21f4b127060", "sha256:4efa1ae2d7c9865af48986de8aeb8504bf32c7f3d6fdc9353d34b21f4b127060",
@ -209,6 +217,14 @@
"markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
"version": "==1.1.1" "version": "==1.1.1"
}, },
"marshmallow": {
"hashes": [
"sha256:67bf4cae9d3275b3fc74bd7ff88a7c98ee8c57c94b251a67b031dc293ecc4b76",
"sha256:a2a5eefb4b75a3b43f05be1cca0b6686adf56af7465c3ca629e5ad8d1e1fe13d"
],
"markers": "python_version >= '3.5'",
"version": "==3.7.1"
},
"pydantic": { "pydantic": {
"hashes": [ "hashes": [
"sha256:1783c1d927f9e1366e0e0609ae324039b2479a1a282a98ed6a6836c9ed02002c", "sha256:1783c1d927f9e1366e0e0609ae324039b2479a1a282a98ed6a6836c9ed02002c",
@ -239,6 +255,13 @@
"index": "pypi", "index": "pypi",
"version": "==1.3.7" "version": "==1.3.7"
}, },
"python-dotenv": {
"hashes": [
"sha256:8c10c99a1b25d9a68058a1ad6f90381a62ba68230ca93966882a4dbc3bc9c33d",
"sha256:c10863aee750ad720f4f43436565e4c1698798d763b63234fb5021b6c616e423"
],
"version": "==0.14.0"
},
"requests": { "requests": {
"hashes": [ "hashes": [
"sha256:b3559a131db72c33ee969480840fff4bb6dd111de7dd27c8ee1f820f4f00231b", "sha256:b3559a131db72c33ee969480840fff4bb6dd111de7dd27c8ee1f820f4f00231b",
@ -263,6 +286,15 @@
"markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
"version": "==1.15.0" "version": "==1.15.0"
}, },
"typing-extensions": {
"hashes": [
"sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918",
"sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c",
"sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f"
],
"markers": "python_version < '3.8'",
"version": "==3.7.4.3"
},
"urllib3": { "urllib3": {
"hashes": [ "hashes": [
"sha256:91056c15fa70756691db97756772bb1eb9678fa585d9184f24534b100dc60f4a", "sha256:91056c15fa70756691db97756772bb1eb9678fa585d9184f24534b100dc60f4a",
@ -375,6 +407,14 @@
"index": "pypi", "index": "pypi",
"version": "==17.8.0" "version": "==17.8.0"
}, },
"importlib-metadata": {
"hashes": [
"sha256:90bb658cdbbf6d1735b6341ce708fc7024a3e14e99ffdc5783edea9f9b077f83",
"sha256:dc15b2969b4ce36305c51eebe62d418ac7791e9a157911d58bfb1f9ccd8e2070"
],
"markers": "python_version < '3.8'",
"version": "==1.7.0"
},
"iniconfig": { "iniconfig": {
"hashes": [ "hashes": [
"sha256:80cf40c597eb564e86346103f609d74efce0f6b4d4f30ec8ce9e2c26411ba437", "sha256:80cf40c597eb564e86346103f609d74efce0f6b4d4f30ec8ce9e2c26411ba437",
@ -568,7 +608,16 @@
"sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c", "sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c",
"sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f" "sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f"
], ],
"markers": "python_version < '3.8'",
"version": "==3.7.4.3" "version": "==3.7.4.3"
},
"zipp": {
"hashes": [
"sha256:aa36550ff0c0b7ef7fa639055d797116ee891440eac1a56f378e2d3179e0320b",
"sha256:c599e4d75c98f6798c509911d08a22e6c021d074469042177c8c86fb92eefd96"
],
"markers": "python_version >= '3.6'",
"version": "==3.1.0"
} }
} }
} }

View File

@ -1,116 +0,0 @@
import re
import importlib
import sqlite3
from flask import Flask
from flask_assets import Environment, Bundle # type: ignore
from flask_compress import Compress # type: ignore
from flask_cors import CORS # type: ignore
from flask_talisman import Talisman # type: ignore
from os import getenv
from werkzeug.middleware.proxy_fix import ProxyFix
from .core import core_app
from .helpers import ExtensionManager
from .settings import FORCE_HTTPS
from .db import open_db, open_ext_db
disabled_extensions = getenv("LNBITS_DISABLED_EXTENSIONS", "").split(",")
app = Flask(__name__)
app.wsgi_app = ProxyFix(app.wsgi_app, x_proto=1, x_host=1) # type: ignore
valid_extensions = [ext for ext in ExtensionManager(disabled=disabled_extensions).extensions if ext.is_valid]
# optimization & security
# -----------------------
Compress(app)
CORS(app)
Talisman(
app,
force_https=FORCE_HTTPS,
content_security_policy={
"default-src": [
"'self'",
"'unsafe-eval'",
"'unsafe-inline'",
"blob:",
"api.opennode.co",
]
},
)
# blueprints / extensions
# -----------------------
app.register_blueprint(core_app)
for ext in valid_extensions:
try:
ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}")
app.register_blueprint(getattr(ext_module, f"{ext.code}_ext"), url_prefix=f"/{ext.code}")
except Exception:
raise ImportError(f"Please make sure that the extension `{ext.code}` follows conventions.")
# filters
# -------
app.jinja_env.globals["DEBUG"] = app.config["DEBUG"]
app.jinja_env.globals["EXTENSIONS"] = valid_extensions
app.jinja_env.globals["SITE_TITLE"] = getenv("LNBITS_SITE_TITLE", "LNbits")
# assets
# ------
assets = Environment(app)
assets.url = app.static_url_path
assets.register("base_css", Bundle("scss/base.scss", filters="pyscss", output="css/base.css"))
# commands
# --------
def migrate_databases():
"""Creates the necessary databases if they don't exist already; or migrates them."""
from .core import migrations as core_migrations
with open_db() as core_db:
try:
rows = core_db.fetchall("SELECT * FROM dbversions")
except sqlite3.OperationalError:
# migration 3 wasn't ran
core_migrations.m000_create_migrations_table(core_db)
rows = core_db.fetchall("SELECT * FROM dbversions")
current_versions = {row["db"]: row["version"] for row in rows}
matcher = re.compile(r"^m(\d\d\d)_")
def run_migration(db, migrations_module):
db_name = migrations_module.__name__.split(".")[-2]
for key, run_migration in migrations_module.__dict__.items():
match = match = matcher.match(key)
if match:
version = int(match.group(1))
if version > current_versions.get(db_name, 0):
print(f"running migration {db_name}.{version}")
run_migration(db)
core_db.execute(
"INSERT OR REPLACE INTO dbversions (db, version) VALUES (?, ?)", (db_name, version)
)
run_migration(core_db, core_migrations)
for ext in valid_extensions:
try:
ext_migrations = importlib.import_module(f"lnbits.extensions.{ext.code}.migrations")
with open_ext_db(ext.code) as db:
run_migration(db, ext_migrations)
except ImportError:
raise ImportError(f"Please make sure that the extension `{ext.code}` has a migrations file.")

View File

@ -1,4 +1,8 @@
from lnbits import app, migrate_databases from .app import create_app
from .commands import migrate_databases
migrate_databases() migrate_databases()
app = create_app()
app.run() app.run()

75
lnbits/app.py Normal file
View File

@ -0,0 +1,75 @@
import importlib
from flask import Flask
from flask_assets import Bundle # type: ignore
from flask_cors import CORS # type: ignore
from flask_talisman import Talisman # type: ignore
from werkzeug.middleware.proxy_fix import ProxyFix
from .commands import legacy_migrate
from .core import core_app
from .ext import assets, compress
from .helpers import get_valid_extensions
def create_app(config_object="lnbits.settings") -> Flask:
"""Create application factory, as explained here: http://flask.pocoo.org/docs/patterns/appfactories/.
:param config_object: The configuration object to use.
"""
app = Flask(__name__, static_folder="static")
app.wsgi_app = ProxyFix(app.wsgi_app, x_proto=1, x_host=1) # type: ignore
app.config.from_object(config_object)
register_flask_extensions(app)
register_blueprints(app)
register_filters(app)
register_commands(app)
return app
def register_blueprints(app) -> None:
"""Register Flask blueprints / LNbits extensions."""
app.register_blueprint(core_app)
for ext in get_valid_extensions():
try:
ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}")
app.register_blueprint(getattr(ext_module, f"{ext.code}_ext"), url_prefix=f"/{ext.code}")
except Exception:
raise ImportError(f"Please make sure that the extension `{ext.code}` follows conventions.")
def register_commands(app):
"""Register Click commands."""
app.cli.add_command(legacy_migrate)
def register_flask_extensions(app):
"""Register Flask extensions."""
"""If possible we use the .init_app() option so that Blueprints can also use extensions."""
CORS(app)
Talisman(
app,
force_https=app.config["FORCE_HTTPS"],
content_security_policy={
"default-src": [
"'self'",
"'unsafe-eval'",
"'unsafe-inline'",
"blob:",
"api.opennode.co",
]
},
)
assets.init_app(app)
assets.register("base_css", Bundle("scss/base.scss", filters="pyscss", output="css/base.css"))
compress.init_app(app)
def register_filters(app):
"""Jinja filters."""
app.jinja_env.globals["DEBUG"] = app.config["DEBUG"]
app.jinja_env.globals["EXTENSIONS"] = get_valid_extensions()
app.jinja_env.globals["SITE_TITLE"] = app.config["LNBITS_SITE_TITLE"]

51
lnbits/commands.py Normal file
View File

@ -0,0 +1,51 @@
import click
import importlib
import re
import sqlite3
from .core import migrations as core_migrations
from .db import open_db, open_ext_db
from .helpers import get_valid_extensions
@click.command("migrate")
def legacy_migrate():
migrate_databases()
def migrate_databases():
"""Creates the necessary databases if they don't exist already; or migrates them."""
with open_db() as core_db:
try:
rows = core_db.fetchall("SELECT * FROM dbversions")
except sqlite3.OperationalError:
# migration 3 wasn't ran
core_migrations.m000_create_migrations_table(core_db)
rows = core_db.fetchall("SELECT * FROM dbversions")
current_versions = {row["db"]: row["version"] for row in rows}
matcher = re.compile(r"^m(\d\d\d)_")
def run_migration(db, migrations_module):
db_name = migrations_module.__name__.split(".")[-2]
for key, run_migration in migrations_module.__dict__.items():
match = match = matcher.match(key)
if match:
version = int(match.group(1))
if version > current_versions.get(db_name, 0):
print(f"running migration {db_name}.{version}")
run_migration(db)
core_db.execute(
"INSERT OR REPLACE INTO dbversions (db, version) VALUES (?, ?)", (db_name, version)
)
run_migration(core_db, core_migrations)
for ext in get_valid_extensions():
try:
ext_migrations = importlib.import_module(f"lnbits.extensions.{ext.code}.migrations")
with open_ext_db(ext.code) as db:
run_migration(db, ext_migrations)
except ImportError:
raise ImportError(f"Please make sure that the extension `{ext.code}` has a migrations file.")

View File

@ -1,10 +1,10 @@
from flask import g, abort, redirect, request, render_template, send_from_directory, url_for from flask import g, abort, redirect, request, render_template, send_from_directory, url_for
from http import HTTPStatus from http import HTTPStatus
from os import getenv, path from os import path
from lnbits.core import core_app from lnbits.core import core_app
from lnbits.decorators import check_user_exists, validate_uuids from lnbits.decorators import check_user_exists, validate_uuids
from lnbits.settings import SERVICE_FEE from lnbits.settings import LNBITS_ALLOWED_USERS, SERVICE_FEE
from ..crud import ( from ..crud import (
create_account, create_account,
@ -61,9 +61,8 @@ def wallet():
user = get_user(create_account().id) user = get_user(create_account().id)
else: else:
user = get_user(user_id) or abort(HTTPStatus.NOT_FOUND, "User does not exist.") user = get_user(user_id) or abort(HTTPStatus.NOT_FOUND, "User does not exist.")
allowed_users = getenv("LNBITS_ALLOWED_USERS", "all")
if allowed_users != "all" and user_id not in allowed_users.split(","): if LNBITS_ALLOWED_USERS and user_id not in LNBITS_ALLOWED_USERS:
abort(HTTPStatus.UNAUTHORIZED, "User not authorized.") abort(HTTPStatus.UNAUTHORIZED, "User not authorized.")
if not wallet_id: if not wallet_id:

View File

@ -2,11 +2,11 @@ from cerberus import Validator # type: ignore
from flask import g, abort, jsonify, request from flask import g, abort, jsonify, request
from functools import wraps from functools import wraps
from http import HTTPStatus from http import HTTPStatus
from os import getenv
from typing import List, Union from typing import List, Union
from uuid import UUID from uuid import UUID
from lnbits.core.crud import get_user, get_wallet_for_key from lnbits.core.crud import get_user, get_wallet_for_key
from lnbits.settings import LNBITS_ALLOWED_USERS
def api_check_wallet_key(key_type: str = "invoice"): def api_check_wallet_key(key_type: str = "invoice"):
@ -62,9 +62,8 @@ def check_user_exists(param: str = "usr"):
@wraps(view) @wraps(view)
def wrapped_view(**kwargs): def wrapped_view(**kwargs):
g.user = get_user(request.args.get(param, type=str)) or abort(HTTPStatus.NOT_FOUND, "User does not exist.") g.user = get_user(request.args.get(param, type=str)) or abort(HTTPStatus.NOT_FOUND, "User does not exist.")
allowed_users = getenv("LNBITS_ALLOWED_USERS", "all")
if allowed_users != "all" and g.user.id not in allowed_users.split(","): if LNBITS_ALLOWED_USERS and g.user.id not in LNBITS_ALLOWED_USERS:
abort(HTTPStatus.UNAUTHORIZED, "User not authorized.") abort(HTTPStatus.UNAUTHORIZED, "User not authorized.")
return view(**kwargs) return view(**kwargs)

6
lnbits/ext.py Normal file
View File

@ -0,0 +1,6 @@
from flask_assets import Environment # type: ignore
from flask_compress import Compress # type: ignore
assets = Environment()
compress = Compress()

View File

@ -4,7 +4,7 @@ import shortuuid # type: ignore
from typing import List, NamedTuple, Optional from typing import List, NamedTuple, Optional
from .settings import LNBITS_PATH from .settings import LNBITS_DISABLED_EXTENSIONS, LNBITS_PATH
class Extension(NamedTuple): class Extension(NamedTuple):
@ -17,8 +17,8 @@ class Extension(NamedTuple):
class ExtensionManager: class ExtensionManager:
def __init__(self, *, disabled: list = []): def __init__(self):
self._disabled = disabled self._disabled: List[str] = LNBITS_DISABLED_EXTENSIONS
self._extension_folders: List[str] = [x[1] for x in os.walk(os.path.join(LNBITS_PATH, "extensions"))][0] self._extension_folders: List[str] = [x[1] for x in os.walk(os.path.join(LNBITS_PATH, "extensions"))][0]
@property @property
@ -48,5 +48,9 @@ class ExtensionManager:
return output return output
def get_valid_extensions() -> List[Extension]:
return [extension for extension in ExtensionManager().extensions if extension.is_valid]
def urlsafe_short_hash() -> str: def urlsafe_short_hash() -> str:
return shortuuid.uuid() return shortuuid.uuid()

View File

@ -1,14 +1,26 @@
import importlib import importlib
import os
from environs import Env # type: ignore
from os import path
from typing import List
env = Env()
env.read_env()
wallets_module = importlib.import_module("lnbits.wallets") wallets_module = importlib.import_module("lnbits.wallets")
wallet_class = getattr(wallets_module, os.getenv("LNBITS_BACKEND_WALLET_CLASS", "VoidWallet")) wallet_class = getattr(wallets_module, env.str("LNBITS_BACKEND_WALLET_CLASS", default="VoidWallet"))
LNBITS_PATH = os.path.dirname(os.path.realpath(__file__)) ENV = env.str("FLASK_ENV", default="production")
LNBITS_DATA_FOLDER = os.getenv("LNBITS_DATA_FOLDER", os.path.join(LNBITS_PATH, "data")) DEBUG = ENV == "development"
LNBITS_PATH = path.dirname(path.realpath(__file__))
LNBITS_DATA_FOLDER = env.str("LNBITS_DATA_FOLDER", default=path.join(LNBITS_PATH, "data"))
LNBITS_ALLOWED_USERS: List[str] = env.list("LNBITS_ALLOWED_USERS", default=[], subcast=str)
LNBITS_DISABLED_EXTENSIONS: List[str] = env.list("LNBITS_DISABLED_EXTENSIONS", default=[], subcast=str)
LNBITS_SITE_TITLE = env.str("LNBITS_SITE_TITLE", default="LNbits")
WALLET = wallet_class() WALLET = wallet_class()
DEFAULT_WALLET_NAME = os.getenv("LNBITS_DEFAULT_WALLET_NAME", "LNbits wallet") DEFAULT_WALLET_NAME = env.str("LNBITS_DEFAULT_WALLET_NAME", default="LNbits wallet")
FORCE_HTTPS = os.getenv("LNBITS_FORCE_HTTPS", "1") == "1" FORCE_HTTPS = env.bool("LNBITS_FORCE_HTTPS", default=True)
SERVICE_FEE = float(os.getenv("LNBITS_SERVICE_FEE", "0.0")) SERVICE_FEE = env.float("LNBITS_SERVICE_FEE", default=0.0)

View File

@ -6,6 +6,7 @@ certifi==2020.6.20
chardet==3.0.4 chardet==3.0.4
click==7.1.2; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4' click==7.1.2; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
ecdsa==0.16.0 ecdsa==0.16.0
environs==8.0.0
flask-assets==2.0 flask-assets==2.0
flask-compress==1.5.0 flask-compress==1.5.0
flask-cors==3.0.9 flask-cors==3.0.9
@ -16,8 +17,10 @@ itsdangerous==1.1.0; python_version >= '2.7' and python_version not in '3.0, 3.1
jinja2==2.11.2; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4' jinja2==2.11.2; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
lnurl==0.3.5 lnurl==0.3.5
markupsafe==1.1.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' markupsafe==1.1.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
marshmallow==3.7.1; python_version >= '3.5'
pydantic==1.6.1; python_version >= '3.6' pydantic==1.6.1; python_version >= '3.6'
pyscss==1.3.7 pyscss==1.3.7
python-dotenv==0.14.0
requests==2.24.0 requests==2.24.0
shortuuid==1.0.1 shortuuid==1.0.1
six==1.15.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' six==1.15.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'

View File

@ -1,10 +1,11 @@
import pytest import pytest
from lnbits import app from lnbits.app import create_app
@pytest.fixture @pytest.fixture
def client(): def client():
app = create_app()
app.config["TESTING"] = True app.config["TESTING"] = True
with app.test_client() as client: with app.test_client() as client: