From 66c908e60061e1865e2d2c6796e6e75c36c793b7 Mon Sep 17 00:00:00 2001 From: Vlad Stan Date: Thu, 22 Dec 2022 16:30:37 +0200 Subject: [PATCH] chore: migrate after major changes on main --- lnbits/app.py | 71 +++-- lnbits/commands.py | 57 +--- lnbits/core/__init__.py | 3 + lnbits/core/crud.py | 31 +++ lnbits/core/helpers.py | 93 +++++++ lnbits/core/models.py | 6 +- lnbits/core/templates/core/extensions.html | 9 + lnbits/core/templates/core/install.html | 286 +++++++++++++++++++++ lnbits/core/views/api.py | 214 ++++++++++++++- lnbits/core/views/generic.py | 126 ++++++--- lnbits/helpers.py | 46 +++- lnbits/settings.py | 1 + lnbits/static/js/base.js | 3 +- 13 files changed, 847 insertions(+), 99 deletions(-) create mode 100644 lnbits/core/helpers.py create mode 100644 lnbits/core/templates/core/install.html diff --git a/lnbits/app.py b/lnbits/app.py index 1b1292ce..9fb388c0 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -1,10 +1,15 @@ import asyncio +import glob import importlib import logging +import os import signal import sys import traceback +import zipfile from http import HTTPStatus +from pathlib import Path +from typing import Callable from fastapi import FastAPI, Request from fastapi.exceptions import HTTPException, RequestValidationError @@ -18,10 +23,12 @@ from lnbits.core.tasks import register_task_listeners from lnbits.settings import get_wallet_class, set_wallet_class, settings from .commands import migrate_databases -from .core import core_app +from .core import core_app, core_app_extra from .core.services import check_admin_settings from .core.views.generic import core_html_routes from .helpers import ( + EnabledExtensionMiddleware, + Extension, get_css_vendored, get_js_vendored, get_valid_extensions, @@ -65,6 +72,7 @@ def create_app() -> FastAPI: ) app.add_middleware(GZipMiddleware, minimum_size=1000) + app.add_middleware(EnabledExtensionMiddleware) register_startup(app) register_assets(app) @@ -72,6 +80,8 @@ def create_app() -> FastAPI: register_async_tasks(app) register_exception_handlers(app) + setattr(core_app_extra, "register_new_ext_routes", register_new_ext_routes(app)) + return app @@ -105,6 +115,22 @@ async def check_funding_source() -> None: ) +def check_installed_extensions(): + """ + Check extensions that have been installed, but for some reason no longer present in the 'lnbits/extensions' directory. + One reason might be a docker-container that was re-created. + The 'data' directory (where the '.zip' files live) is expected to persist state. + """ + extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions") + + zip_files = glob.glob(f"{extensions_data_dir}/*.zip") + for zip_file in zip_files: + ext_name = Path(zip_file).stem + if not Path(f"lnbits/extensions/{ext_name}").is_dir(): + with zipfile.ZipFile(zip_file, "r") as zip_ref: + zip_ref.extractall("lnbits/extensions/") + + def register_routes(app: FastAPI) -> None: """Register FastAPI routes / LNbits extensions.""" app.include_router(core_app) @@ -112,20 +138,7 @@ def register_routes(app: FastAPI) -> None: for ext in get_valid_extensions(): try: - ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}") - ext_route = getattr(ext_module, f"{ext.code}_ext") - - if hasattr(ext_module, f"{ext.code}_start"): - ext_start_func = getattr(ext_module, f"{ext.code}_start") - ext_start_func() - - if hasattr(ext_module, f"{ext.code}_static_files"): - ext_statics = getattr(ext_module, f"{ext.code}_static_files") - for s in ext_statics: - app.mount(s["path"], s["app"], s["name"]) - - logger.trace(f"adding route for extension {ext_module}") - app.include_router(ext_route) + register_ext_routes(app, ext) except Exception as e: logger.error(str(e)) raise ImportError( @@ -133,6 +146,31 @@ def register_routes(app: FastAPI) -> None: ) +def register_new_ext_routes(app: FastAPI) -> Callable: + def register_new_ext_routes_fn(ext: Extension): + register_ext_routes(app, ext) + + return register_new_ext_routes_fn + + +def register_ext_routes(app: FastAPI, ext: Extension) -> None: + """Register FastAPI routes for extension.""" + ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}") + ext_route = getattr(ext_module, f"{ext.code}_ext") + + if hasattr(ext_module, f"{ext.code}_start"): + ext_start_func = getattr(ext_module, f"{ext.code}_start") + ext_start_func() + + if hasattr(ext_module, f"{ext.code}_static_files"): + ext_statics = getattr(ext_module, f"{ext.code}_static_files") + for s in ext_statics: + app.mount(s["path"], s["app"], s["name"]) + + logger.trace(f"adding route for extension {ext_module}") + app.include_router(ext_route) + + def register_startup(app: FastAPI): @app.on_event("startup") async def lnbits_startup(): @@ -151,6 +189,9 @@ def register_startup(app: FastAPI): # 4. initialize funding source await check_funding_source() + + # 5. check extensions in `data` directory + await check_installed_extensions() except Exception as e: logger.error(str(e)) raise ImportError("Failed to run 'startup' event.") diff --git a/lnbits/commands.py b/lnbits/commands.py index 82ea1430..66b45b93 100644 --- a/lnbits/commands.py +++ b/lnbits/commands.py @@ -11,6 +11,8 @@ from lnbits.settings import settings from .core import db as core_db from .core import migrations as core_migrations +from .core.crud import USER_ID_ALL, get_dbversions, get_inactive_extensions +from .core.helpers import migrate_extension_database, run_migration from .db import COCKROACH, POSTGRES, SQLITE from .helpers import ( get_css_vendored, @@ -59,30 +61,6 @@ def bundle_vendored(): async def migrate_databases(): """Creates the necessary databases if they don't exist already; or migrates them.""" - async def set_migration_version(conn, db_name, version): - await conn.execute( - """ - INSERT INTO dbversions (db, version) VALUES (?, ?) - ON CONFLICT (db) DO UPDATE SET version = ? - """, - (db_name, version, version), - ) - - async def run_migration(db, migrations_module, db_name): - for key, migrate in migrations_module.__dict__.items(): - match = match = matcher.match(key) - if match: - version = int(match.group(1)) - if version > current_versions.get(db_name, 0): - logger.debug(f"running migration {db_name}.{version}") - await migrate(db) - - if db.schema == None: - await set_migration_version(db, db_name, version) - else: - async with core_db.connect() as conn: - await set_migration_version(conn, db_name, version) - async with core_db.connect() as conn: if conn.type == SQLITE: exists = await conn.fetchone( @@ -96,27 +74,18 @@ async def migrate_databases(): if not exists: await core_migrations.m000_create_migrations_table(conn) - rows = await (await conn.execute("SELECT * FROM dbversions")).fetchall() - current_versions = {row["db"]: row["version"] for row in rows} - matcher = re.compile(r"^m(\d\d\d)_") - db_name = core_migrations.__name__.split(".")[-2] - await run_migration(conn, core_migrations, db_name) + current_versions = await get_dbversions(conn) + core_version = current_versions.get("core", 0) + await run_migration(conn, core_migrations, core_version) for ext in get_valid_extensions(): - try: - - module_str = ( - ext.migration_module or f"lnbits.extensions.{ext.code}.migrations" - ) - ext_migrations = importlib.import_module(module_str) - ext_db = importlib.import_module(f"lnbits.extensions.{ext.code}").db - db_name = ext.db_name or module_str.split(".")[-2] - except ImportError: - raise ImportError( - f"Please make sure that the extension `{ext.code}` has a migrations file." - ) - - async with ext_db.connect() as ext_conn: - await run_migration(ext_conn, ext_migrations, db_name) + current_version = current_versions.get(ext.code, 0) + await migrate_extension_database(ext, current_version) logger.info("✔️ All migrations done.") + + +async def load_disabled_extension_list() -> None: + """Update list of extensions that have been explicitly disabled""" + inactive_extensions = await get_inactive_extensions(user_id=USER_ID_ALL) + settings.lnbits_disabled_extensions += inactive_extensions diff --git a/lnbits/core/__init__.py b/lnbits/core/__init__.py index dec15d78..75b6d587 100644 --- a/lnbits/core/__init__.py +++ b/lnbits/core/__init__.py @@ -1,11 +1,14 @@ from fastapi.routing import APIRouter +from lnbits.core.models import CoreAppExtra from lnbits.db import Database db = Database("database") core_app: APIRouter = APIRouter() +core_app_extra: CoreAppExtra = CoreAppExtra() + from .views.admin_api import * # noqa from .views.api import * # noqa from .views.generic import * # noqa diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py index a80fadf2..1289c33a 100644 --- a/lnbits/core/crud.py +++ b/lnbits/core/crud.py @@ -11,6 +11,8 @@ from lnbits.settings import AdminSettings, EditableSettings, SuperSettings, sett from . import db from .models import BalanceCheck, Payment, User, Wallet +USER_ID_ALL = "all" + # accounts # -------- @@ -78,6 +80,18 @@ async def update_user_extension( ) +async def get_inactive_extensions( + *, user_id: str, conn: Optional[Connection] = None +) -> List[str]: + inactive_extensions = await (conn or db).fetchall( + """SELECT extension FROM extensions WHERE "user" = ? AND NOT active""", + (user_id,), + ) + return ( + [ext[0] for ext in inactive_extensions] if len(inactive_extensions) != 0 else [] + ) + + # wallets # ------- @@ -620,3 +634,20 @@ async def create_admin_settings(super_user: str, new_settings: dict): sql = f"INSERT INTO settings (super_user, editable_settings) VALUES (?, ?)" await db.execute(sql, (super_user, json.dumps(new_settings))) return await get_super_settings() + + +# db versions +# -------------- +async def get_dbversions(conn: Optional[Connection] = None): + rows = await (conn or db).fetchall("SELECT * FROM dbversions") + return {row["db"]: row["version"] for row in rows} + + +async def update_migration_version(conn, db_name, version): + await (conn or db).execute( + """ + INSERT INTO dbversions (db, version) VALUES (?, ?) + ON CONFLICT (db) DO UPDATE SET version = ? + """, + (db_name, version, version), + ) diff --git a/lnbits/core/helpers.py b/lnbits/core/helpers.py new file mode 100644 index 00000000..3675d438 --- /dev/null +++ b/lnbits/core/helpers.py @@ -0,0 +1,93 @@ +import hashlib +import importlib +import re +import urllib.request +from typing import List + +import httpx +from fastapi.exceptions import HTTPException +from loguru import logger + +from lnbits.helpers import InstallableExtension +from lnbits.settings import settings + +from . import db as core_db +from .crud import update_migration_version + + +async def migrate_extension_database(ext, current_version): + try: + ext_migrations = importlib.import_module( + f"lnbits.extensions.{ext.code}.migrations" + ) + ext_db = importlib.import_module(f"lnbits.extensions.{ext.code}").db + except ImportError: + raise ImportError( + f"Please make sure that the extension `{ext.code}` has a migrations file." + ) + + async with ext_db.connect() as ext_conn: + await run_migration(ext_conn, ext_migrations, current_version) + + +async def run_migration(db, migrations_module, current_version): + matcher = re.compile(r"^m(\d\d\d)_") + db_name = migrations_module.__name__.split(".")[-2] + for key, migrate in migrations_module.__dict__.items(): + match = match = matcher.match(key) + if match: + version = int(match.group(1)) + if version > current_version: + logger.debug(f"running migration {db_name}.{version}") + print(f"running migration {db_name}.{version}") + await migrate(db) + + if db.schema == None: + await update_migration_version(db, db_name, version) + else: + async with core_db.connect() as conn: + await update_migration_version(conn, db_name, version) + + +async def get_installable_extensions() -> List[InstallableExtension]: + extension_list: List[InstallableExtension] = [] + + async with httpx.AsyncClient() as client: + for url in settings.lnbits_extensions_manifests: + resp = await client.get(url) + if resp.status_code != 200: + raise HTTPException( + status_code=404, + detail=f"Unable to fetch extension list for repository: {url}", + ) + for e in resp.json()["extensions"]: + extension_list += [ + InstallableExtension( + id=e["id"], + name=e["name"], + archive=e["archive"], + hash=e["hash"], + short_description=e["shortDescription"], + details=e["details"] if "details" in e else "", + icon=e["icon"], + dependencies=e["dependencies"] if "dependencies" in e else [], + ) + ] + + return extension_list + + +def download_url(url, save_path): + with urllib.request.urlopen(url) as dl_file: + with open(save_path, "wb") as out_file: + out_file.write(dl_file.read()) + + +def file_hash(filename): + h = hashlib.sha256() + b = bytearray(128 * 1024) + mv = memoryview(b) + with open(filename, "rb", buffering=0) as f: + while n := f.readinto(mv): + h.update(mv[:n]) + return h.hexdigest() diff --git a/lnbits/core/models.py b/lnbits/core/models.py index eca1bf50..7b147208 100644 --- a/lnbits/core/models.py +++ b/lnbits/core/models.py @@ -4,7 +4,7 @@ import hmac import json import time from sqlite3 import Row -from typing import Dict, List, Optional +from typing import Callable, Dict, List, Optional from ecdsa import SECP256k1, SigningKey from fastapi import Query @@ -213,3 +213,7 @@ class BalanceCheck(BaseModel): @classmethod def from_row(cls, row: Row): return cls(wallet=row["wallet"], service=row["service"], url=row["url"]) + + +class CoreAppExtra: + register_new_ext_routes: Callable diff --git a/lnbits/core/templates/core/extensions.html b/lnbits/core/templates/core/extensions.html index 88e50269..dc0037e2 100644 --- a/lnbits/core/templates/core/extensions.html +++ b/lnbits/core/templates/core/extensions.html @@ -4,6 +4,15 @@ {% endblock %} {% block page %}
+
+ Add or Remove Extensions +
+
+ Back +
+
+ +
+ +
+ + + +
+
+ +
+
+ + +
+
+ + +
+
+ + More + +
+
+ + {% raw %} +
{{ extension.name}}
+ {{ extension.shortDescription }} +
+ Depends on: +   + + + +
+ {% endraw %} +
+ + +
+
+ + Uninstall + +
+
+ + + Install +
+
+ +
+ + + + + + + +
+
+
+
+ + +
Warning
+

+ You are about to remove the extension for all users.
+ Are you sure you want to continue? +

+ +
+ Yes, Uninstall + Cancel +
+
+
+ + {%raw%} + + +
{{selectedExtension.name}}
+
+ +
+ Done +
+
+
+ {%endraw%} +
+{% endblock %} {% block scripts %} {{ window_vars(user) }} + +{% endblock %} diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index d545df9a..7e552e8a 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -1,11 +1,18 @@ import asyncio import hashlib +import importlib +import inspect import json +import os +import shutil +import sys import time import uuid +import zipfile from http import HTTPStatus from io import BytesIO -from typing import Dict, Optional, Tuple, Union +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union from urllib.parse import ParseResult, parse_qs, urlencode, urlparse, urlunparse import async_timeout @@ -22,6 +29,8 @@ from fastapi import ( WebSocketDisconnect, ) from fastapi.exceptions import HTTPException +from fastapi.params import Body +from genericpath import isfile from loguru import logger from pydantic import BaseModel from pydantic.fields import Field @@ -29,15 +38,27 @@ from sse_starlette.sse import EventSourceResponse from starlette.responses import StreamingResponse from lnbits import bolt11, lnurl -from lnbits.core.models import Payment, Wallet +from lnbits.core.helpers import ( + download_url, + file_hash, + get_installable_extensions, + migrate_extension_database, +) +from lnbits.core.models import Payment, User, Wallet from lnbits.decorators import ( WalletTypeInfo, check_admin, + check_user_exists, get_key_type, require_admin_key, require_invoice_key, ) -from lnbits.helpers import url_for +from lnbits.helpers import ( + Extension, + InstallableExtension, + get_valid_extensions, + url_for, +) from lnbits.settings import get_wallet_class, settings from lnbits.utils.exchange_rates import ( currencies, @@ -45,13 +66,16 @@ from lnbits.utils.exchange_rates import ( satoshis_amount_as_fiat, ) -from .. import core_app, db +from .. import core_app, core_app_extra, db from ..crud import ( + USER_ID_ALL, + get_dbversions, get_payments, get_standalone_payment, get_total_balance, get_wallet_for_key, save_balance_check, + update_user_extension, update_wallet, ) from ..services import ( @@ -706,3 +730,185 @@ async def websocket_update_get(item_id: str, data: str): return {"sent": True, "data": data} except: return {"sent": False, "data": data} + + +@core_app.post("/api/v1/extension/{ext_id}/{hash}") +async def api_install_extension( + ext_id: str, hash: str, user: User = Depends(check_user_exists) +): + if not user.admin: + raise HTTPException( + status_code=HTTPStatus.UNAUTHORIZED, detail="Only for admin users" + ) + + try: + extension_list: List[InstallableExtension] = await get_installable_extensions() + except Exception as ex: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="Cannot fetch installable extension list", + ) + + extensions = [e for e in extension_list if e.id == ext_id and e.hash == hash] + if len(extensions) == 0: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=f"Unknown extension id: {ext_id}", + ) + extension = extensions[0] + + # check that all dependecies are installed + installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True))) + if not set(extension.dependencies).issubset(installed_extensions): + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail=f"Not all dependencies are installed: {extension.dependencies}", + ) + + # move files to the right location + extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions") + os.makedirs(extensions_data_dir, exist_ok=True) + ext_data_dir = os.path.join(extensions_data_dir, ext_id) + shutil.rmtree(ext_data_dir, True) + ext_zip_file = os.path.join(extensions_data_dir, f"{ext_id}.zip") + if os.path.isfile(ext_zip_file): + os.remove(ext_zip_file) + + try: + download_url(extension.archive, ext_zip_file) + except Exception as ex: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="Cannot fetch extension archive file", + ) + + archive_hash = file_hash(ext_zip_file) + if extension.hash != archive_hash: + # remove downloaded archive + if os.path.isfile(ext_zip_file): + os.remove(ext_zip_file) + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="File hash missmatch. Will not install.", + ) + + try: + ext_dir = os.path.join("lnbits/extensions", ext_id) + shutil.rmtree(ext_dir, True) + with zipfile.ZipFile(ext_zip_file, "r") as zip_ref: + zip_ref.extractall("lnbits/extensions") + + # todo: is admin only + ext = Extension(extension.id, True, extension.is_admin_only, extension.name) + + current_versions = await get_dbversions() + current_version = current_versions.get(ext.code, 0) + + module_name = f"lnbits.extensions.{ext.code}" + # if module_name in sys.modules: + # importlib.reload(sys.modules[module_name]) + ext_module = importlib.import_module(module_name) + # sys.modules[module_name] = importlib.reload(ext_module) + + modules_to_reload = list_modules_for_extension(ext_id) + print("### modules_to_reload", modules_to_reload) + for m in modules_to_reload: + importlib.reload(sys.modules[m]) + + await migrate_extension_database(ext, current_version) + + # disable by default + await update_user_extension(user_id=USER_ID_ALL, extension=ext_id, active=False) + settings.lnbits_disabled_extensions += [ext_id] + + # mount routes at the very end + core_app_extra.register_new_ext_routes(ext) + except Exception as ex: + logger.warning(ex) + # remove downloaded archive + if os.path.isfile(ext_zip_file): + os.remove(ext_zip_file) + + # remove module from extensions + shutil.rmtree(ext_dir, True) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(ex) + ) + + +@core_app.delete("/api/v1/extension/{ext_id}") +async def api_uninstall_extension(ext_id: str, user: User = Depends(check_user_exists)): + if not user.admin: + raise HTTPException( + status_code=HTTPStatus.UNAUTHORIZED, detail="Only for admin users" + ) + + try: + extension_list: List[InstallableExtension] = await get_installable_extensions() + except Exception as ex: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="Cannot fetch installable extension list", + ) + + extensions = [e for e in extension_list if e.id == ext_id] + if len(extensions) == 0: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=f"Unknown extension id: {ext_id}", + ) + + # check that other extensions do not depend on this one + for active_ext_id in list(map(lambda e: e.code, get_valid_extensions(True))): + active_ext = next( + (ext for ext in extension_list if ext.id == active_ext_id), None + ) + if active_ext and ext_id in active_ext.dependencies: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=f"Cannot uninstall. Extension '{active_ext.name}' depends on this one.", + ) + + try: + settings.lnbits_disabled_extensions += [ext_id] + + # remove downloaded archive + ext_zip_file = os.path.join( + settings.lnbits_data_folder, "extensions", f"{ext_id}.zip" + ) + if os.path.isfile(ext_zip_file): + os.remove(ext_zip_file) + + # module_name = f"lnbits.extensions.{ext_id}" + + # modules_to_delete = list_modules_for_extension(ext_id) + # print('### modules_to_delete', modules_to_delete) + # for m in modules_to_delete: + # module = sys.modules[m] + # del sys.modules[m] + # del module + + # remove module from extensions + ext_dir = os.path.join("lnbits/extensions", ext_id) + shutil.rmtree(ext_dir, True) + except Exception as ex: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(ex) + ) + + +def list_modules_for_extension(ext_id: str) -> List[str]: + modules_for_extension = [] + for key in sys.modules.keys(): + try: + module = sys.modules[key] + moduleFilePath = inspect.getfile(module).lower() + + dir_name = str(Path(moduleFilePath).parent.absolute()) + if dir_name.endswith(f"lnbits/extensions/{ext_id}"): + print("## moduleFilePath", moduleFilePath) + modules_for_extension += [key] + + except: + pass # built in modules throw if queried + return modules_for_extension diff --git a/lnbits/core/views/generic.py b/lnbits/core/views/generic.py index ab19feb8..d14a43f6 100644 --- a/lnbits/core/views/generic.py +++ b/lnbits/core/views/generic.py @@ -1,6 +1,6 @@ import asyncio from http import HTTPStatus -from typing import Optional +from typing import List, Optional from fastapi import Depends, Query, Request, status from fastapi.exceptions import HTTPException @@ -11,17 +11,20 @@ from pydantic.types import UUID4 from starlette.responses import HTMLResponse, JSONResponse from lnbits.core import db +from lnbits.core.helpers import get_installable_extensions from lnbits.core.models import User from lnbits.decorators import check_admin, check_user_exists from lnbits.helpers import template_renderer, url_for from lnbits.settings import get_wallet_class, settings -from ...helpers import get_valid_extensions +from ...helpers import InstallableExtension, get_valid_extensions from ..crud import ( + USER_ID_ALL, create_account, create_wallet, delete_wallet, get_balance_check, + get_inactive_extensions, get_user, save_balance_notify, update_user_extension, @@ -52,35 +55,10 @@ async def extensions( enable: str = Query(None), disable: str = Query(None), ): - extension_to_enable = enable - extension_to_disable = disable - - if extension_to_enable and extension_to_disable: - raise HTTPException( - HTTPStatus.BAD_REQUEST, "You can either `enable` or `disable` an extension." - ) - - # check if extension exists - if extension_to_enable or extension_to_disable: - ext = extension_to_enable or extension_to_disable - if ext not in [e.code for e in get_valid_extensions()]: - raise HTTPException( - HTTPStatus.BAD_REQUEST, f"Extension '{ext}' doesn't exist." - ) - - if extension_to_enable: - logger.info(f"Enabling extension: {extension_to_enable} for user {user.id}") - await update_user_extension( - user_id=user.id, extension=extension_to_enable, active=True - ) - elif extension_to_disable: - logger.info(f"Disabling extension: {extension_to_disable} for user {user.id}") - await update_user_extension( - user_id=user.id, extension=extension_to_disable, active=False - ) + await toggle_extension(enable, disable, user.id) # Update user as his extensions have been updated - if extension_to_enable or extension_to_disable: + if enable or disable: user = await get_user(user.id) # type: ignore return template_renderer().TemplateResponse( @@ -88,6 +66,70 @@ async def extensions( ) +@core_html_routes.get( + "/install", name="install.extensions", response_class=HTMLResponse +) +async def extensions_install( + request: Request, + user: User = Depends(check_user_exists), # type: ignore + activate: str = Query(None), # type: ignore + deactivate: str = Query(None), # type: ignore +): + if not user.admin: + raise HTTPException( + status_code=HTTPStatus.UNAUTHORIZED, detail="Only for admin users" + ) + + try: + extension_list: List[InstallableExtension] = await get_installable_extensions() + except Exception as ex: + logger.warning(ex) + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="Cannot fetch installable extension list", + ) + + try: + if deactivate: + settings.lnbits_disabled_extensions += [deactivate] + elif activate: + settings.lnbits_disabled_extensions = list( + filter(lambda e: e != activate, settings.lnbits_disabled_extensions) + ) + await toggle_extension(activate, deactivate, USER_ID_ALL) + + installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True))) + inactive_extensions = await get_inactive_extensions(user_id=USER_ID_ALL) + extensions = list( + map( + lambda ext: { + "id": ext.id, + "name": ext.name, + "hash": ext.hash, + "icon": ext.icon, + "shortDescription": ext.short_description, + "details": ext.details, + "dependencies": ext.dependencies, + "isInstalled": ext.id in installed_extensions, + "isActive": not ext.id in inactive_extensions, + }, + extension_list, + ) + ) + + return template_renderer().TemplateResponse( + "core/install.html", + { + "request": request, + "user": user.dict(), + "extensions": extensions, + }, + ) + except Exception as e: + logger.warning(e) + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(e)) + + @core_html_routes.get( "/wallet", response_class=HTMLResponse, @@ -327,3 +369,29 @@ async def index(request: Request, user: User = Depends(check_admin)): "balance": balance, }, ) + + +async def toggle_extension(extension_to_enable, extension_to_disable, user_id): + if extension_to_enable and extension_to_disable: + raise HTTPException( + HTTPStatus.BAD_REQUEST, "You can either `enable` or `disable` an extension." + ) + + # check if extension exists + if extension_to_enable or extension_to_disable: + ext = extension_to_enable or extension_to_disable + if ext not in [e.code for e in get_valid_extensions(True)]: + raise HTTPException( + HTTPStatus.BAD_REQUEST, f"Extension '{ext}' doesn't exist." + ) + + if extension_to_enable: + logger.info(f"Enabling extension: {extension_to_enable} for user {user_id}") + await update_user_extension( + user_id=user_id, extension=extension_to_enable, active=True + ) + elif extension_to_disable: + logger.info(f"Disabling extension: {extension_to_disable} for user {user_id}") + await update_user_extension( + user_id=user_id, extension=extension_to_disable, active=False + ) diff --git a/lnbits/helpers.py b/lnbits/helpers.py index 4804bdea..52a7f6ab 100644 --- a/lnbits/helpers.py +++ b/lnbits/helpers.py @@ -1,10 +1,13 @@ import glob import json import os +from http import HTTPStatus from typing import Any, List, NamedTuple, Optional import jinja2 -import shortuuid +import shortuuid # type: ignore +from fastapi.responses import JSONResponse +from starlette.types import ASGIApp, Receive, Scope, Send from lnbits.jinja2_templating import Jinja2Templates from lnbits.requestvars import g @@ -25,8 +28,10 @@ class Extension(NamedTuple): class ExtensionManager: - def __init__(self): - self._disabled: List[str] = settings.lnbits_disabled_extensions + def __init__(self, include_disabled_exts=False): + self._disabled: List[str] = ( + [] if include_disabled_exts else settings.lnbits_disabled_extensions + ) self._admin_only: List[str] = settings.lnbits_admin_extensions self._extension_folders: List[str] = [ x[1] for x in os.walk(os.path.join(settings.lnbits_path, "extensions")) @@ -74,9 +79,40 @@ class ExtensionManager: return output -def get_valid_extensions() -> List[Extension]: +class EnabledExtensionMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + pathname = scope["path"].split("/")[1] + if pathname in settings.lnbits_disabled_extensions: + response = JSONResponse( + status_code=HTTPStatus.NOT_FOUND, + content={"detail": f"Extension '{pathname}' disabled"}, + ) + await response(scope, receive, send) + return + + await self.app(scope, receive, send) + + +class InstallableExtension(NamedTuple): + id: str + name: str + archive: str + hash: str + short_description: Optional[str] = None + details: Optional[str] = None + icon: Optional[str] = None + dependencies: List[str] = [] + is_admin_only: bool = False + + +def get_valid_extensions(include_disabled_exts=False) -> List[Extension]: return [ - extension for extension in ExtensionManager().extensions if extension.is_valid + extension + for extension in ExtensionManager(include_disabled_exts).extensions + if extension.is_valid ] diff --git a/lnbits/settings.py b/lnbits/settings.py index 6ec4db0c..d00d038d 100644 --- a/lnbits/settings.py +++ b/lnbits/settings.py @@ -40,6 +40,7 @@ class UsersSettings(LNbitsSettings): lnbits_allowed_users: List[str] = Field(default=[]) lnbits_admin_extensions: List[str] = Field(default=[]) lnbits_disabled_extensions: List[str] = Field(default=[]) + lnbits_extensions_manifests: List[str] = Field(default=[]) class ThemesSettings(LNbitsSettings): diff --git a/lnbits/static/js/base.js b/lnbits/static/js/base.js index 32b075b7..d424d563 100644 --- a/lnbits/static/js/base.js +++ b/lnbits/static/js/base.js @@ -141,7 +141,8 @@ window.LNbits = { admin: data.admin, email: data.email, extensions: data.extensions, - wallets: data.wallets + wallets: data.wallets, + admin: data.admin } var mapWallet = this.wallet obj.wallets = obj.wallets