chore: migrate after major changes on main

This commit is contained in:
Vlad Stan 2022-12-22 16:30:37 +02:00
parent 331e93195d
commit 66c908e600
13 changed files with 847 additions and 99 deletions

View File

@ -1,10 +1,15 @@
import asyncio
import glob
import importlib
import logging
import os
import signal
import sys
import traceback
import zipfile
from http import HTTPStatus
from pathlib import Path
from typing import Callable
from fastapi import FastAPI, Request
from fastapi.exceptions import HTTPException, RequestValidationError
@ -18,10 +23,12 @@ from lnbits.core.tasks import register_task_listeners
from lnbits.settings import get_wallet_class, set_wallet_class, settings
from .commands import migrate_databases
from .core import core_app
from .core import core_app, core_app_extra
from .core.services import check_admin_settings
from .core.views.generic import core_html_routes
from .helpers import (
EnabledExtensionMiddleware,
Extension,
get_css_vendored,
get_js_vendored,
get_valid_extensions,
@ -65,6 +72,7 @@ def create_app() -> FastAPI:
)
app.add_middleware(GZipMiddleware, minimum_size=1000)
app.add_middleware(EnabledExtensionMiddleware)
register_startup(app)
register_assets(app)
@ -72,6 +80,8 @@ def create_app() -> FastAPI:
register_async_tasks(app)
register_exception_handlers(app)
setattr(core_app_extra, "register_new_ext_routes", register_new_ext_routes(app))
return app
@ -105,6 +115,22 @@ async def check_funding_source() -> None:
)
def check_installed_extensions():
"""
Check extensions that have been installed, but for some reason no longer present in the 'lnbits/extensions' directory.
One reason might be a docker-container that was re-created.
The 'data' directory (where the '.zip' files live) is expected to persist state.
"""
extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions")
zip_files = glob.glob(f"{extensions_data_dir}/*.zip")
for zip_file in zip_files:
ext_name = Path(zip_file).stem
if not Path(f"lnbits/extensions/{ext_name}").is_dir():
with zipfile.ZipFile(zip_file, "r") as zip_ref:
zip_ref.extractall("lnbits/extensions/")
def register_routes(app: FastAPI) -> None:
"""Register FastAPI routes / LNbits extensions."""
app.include_router(core_app)
@ -112,20 +138,7 @@ def register_routes(app: FastAPI) -> None:
for ext in get_valid_extensions():
try:
ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}")
ext_route = getattr(ext_module, f"{ext.code}_ext")
if hasattr(ext_module, f"{ext.code}_start"):
ext_start_func = getattr(ext_module, f"{ext.code}_start")
ext_start_func()
if hasattr(ext_module, f"{ext.code}_static_files"):
ext_statics = getattr(ext_module, f"{ext.code}_static_files")
for s in ext_statics:
app.mount(s["path"], s["app"], s["name"])
logger.trace(f"adding route for extension {ext_module}")
app.include_router(ext_route)
register_ext_routes(app, ext)
except Exception as e:
logger.error(str(e))
raise ImportError(
@ -133,6 +146,31 @@ def register_routes(app: FastAPI) -> None:
)
def register_new_ext_routes(app: FastAPI) -> Callable:
def register_new_ext_routes_fn(ext: Extension):
register_ext_routes(app, ext)
return register_new_ext_routes_fn
def register_ext_routes(app: FastAPI, ext: Extension) -> None:
"""Register FastAPI routes for extension."""
ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}")
ext_route = getattr(ext_module, f"{ext.code}_ext")
if hasattr(ext_module, f"{ext.code}_start"):
ext_start_func = getattr(ext_module, f"{ext.code}_start")
ext_start_func()
if hasattr(ext_module, f"{ext.code}_static_files"):
ext_statics = getattr(ext_module, f"{ext.code}_static_files")
for s in ext_statics:
app.mount(s["path"], s["app"], s["name"])
logger.trace(f"adding route for extension {ext_module}")
app.include_router(ext_route)
def register_startup(app: FastAPI):
@app.on_event("startup")
async def lnbits_startup():
@ -151,6 +189,9 @@ def register_startup(app: FastAPI):
# 4. initialize funding source
await check_funding_source()
# 5. check extensions in `data` directory
await check_installed_extensions()
except Exception as e:
logger.error(str(e))
raise ImportError("Failed to run 'startup' event.")

View File

@ -11,6 +11,8 @@ from lnbits.settings import settings
from .core import db as core_db
from .core import migrations as core_migrations
from .core.crud import USER_ID_ALL, get_dbversions, get_inactive_extensions
from .core.helpers import migrate_extension_database, run_migration
from .db import COCKROACH, POSTGRES, SQLITE
from .helpers import (
get_css_vendored,
@ -59,30 +61,6 @@ def bundle_vendored():
async def migrate_databases():
"""Creates the necessary databases if they don't exist already; or migrates them."""
async def set_migration_version(conn, db_name, version):
await conn.execute(
"""
INSERT INTO dbversions (db, version) VALUES (?, ?)
ON CONFLICT (db) DO UPDATE SET version = ?
""",
(db_name, version, version),
)
async def run_migration(db, migrations_module, db_name):
for key, migrate in migrations_module.__dict__.items():
match = match = matcher.match(key)
if match:
version = int(match.group(1))
if version > current_versions.get(db_name, 0):
logger.debug(f"running migration {db_name}.{version}")
await migrate(db)
if db.schema == None:
await set_migration_version(db, db_name, version)
else:
async with core_db.connect() as conn:
await set_migration_version(conn, db_name, version)
async with core_db.connect() as conn:
if conn.type == SQLITE:
exists = await conn.fetchone(
@ -96,27 +74,18 @@ async def migrate_databases():
if not exists:
await core_migrations.m000_create_migrations_table(conn)
rows = await (await conn.execute("SELECT * FROM dbversions")).fetchall()
current_versions = {row["db"]: row["version"] for row in rows}
matcher = re.compile(r"^m(\d\d\d)_")
db_name = core_migrations.__name__.split(".")[-2]
await run_migration(conn, core_migrations, db_name)
current_versions = await get_dbversions(conn)
core_version = current_versions.get("core", 0)
await run_migration(conn, core_migrations, core_version)
for ext in get_valid_extensions():
try:
module_str = (
ext.migration_module or f"lnbits.extensions.{ext.code}.migrations"
)
ext_migrations = importlib.import_module(module_str)
ext_db = importlib.import_module(f"lnbits.extensions.{ext.code}").db
db_name = ext.db_name or module_str.split(".")[-2]
except ImportError:
raise ImportError(
f"Please make sure that the extension `{ext.code}` has a migrations file."
)
async with ext_db.connect() as ext_conn:
await run_migration(ext_conn, ext_migrations, db_name)
current_version = current_versions.get(ext.code, 0)
await migrate_extension_database(ext, current_version)
logger.info("✔️ All migrations done.")
async def load_disabled_extension_list() -> None:
"""Update list of extensions that have been explicitly disabled"""
inactive_extensions = await get_inactive_extensions(user_id=USER_ID_ALL)
settings.lnbits_disabled_extensions += inactive_extensions

View File

@ -1,11 +1,14 @@
from fastapi.routing import APIRouter
from lnbits.core.models import CoreAppExtra
from lnbits.db import Database
db = Database("database")
core_app: APIRouter = APIRouter()
core_app_extra: CoreAppExtra = CoreAppExtra()
from .views.admin_api import * # noqa
from .views.api import * # noqa
from .views.generic import * # noqa

View File

@ -11,6 +11,8 @@ from lnbits.settings import AdminSettings, EditableSettings, SuperSettings, sett
from . import db
from .models import BalanceCheck, Payment, User, Wallet
USER_ID_ALL = "all"
# accounts
# --------
@ -78,6 +80,18 @@ async def update_user_extension(
)
async def get_inactive_extensions(
*, user_id: str, conn: Optional[Connection] = None
) -> List[str]:
inactive_extensions = await (conn or db).fetchall(
"""SELECT extension FROM extensions WHERE "user" = ? AND NOT active""",
(user_id,),
)
return (
[ext[0] for ext in inactive_extensions] if len(inactive_extensions) != 0 else []
)
# wallets
# -------
@ -620,3 +634,20 @@ async def create_admin_settings(super_user: str, new_settings: dict):
sql = f"INSERT INTO settings (super_user, editable_settings) VALUES (?, ?)"
await db.execute(sql, (super_user, json.dumps(new_settings)))
return await get_super_settings()
# db versions
# --------------
async def get_dbversions(conn: Optional[Connection] = None):
rows = await (conn or db).fetchall("SELECT * FROM dbversions")
return {row["db"]: row["version"] for row in rows}
async def update_migration_version(conn, db_name, version):
await (conn or db).execute(
"""
INSERT INTO dbversions (db, version) VALUES (?, ?)
ON CONFLICT (db) DO UPDATE SET version = ?
""",
(db_name, version, version),
)

93
lnbits/core/helpers.py Normal file
View File

@ -0,0 +1,93 @@
import hashlib
import importlib
import re
import urllib.request
from typing import List
import httpx
from fastapi.exceptions import HTTPException
from loguru import logger
from lnbits.helpers import InstallableExtension
from lnbits.settings import settings
from . import db as core_db
from .crud import update_migration_version
async def migrate_extension_database(ext, current_version):
try:
ext_migrations = importlib.import_module(
f"lnbits.extensions.{ext.code}.migrations"
)
ext_db = importlib.import_module(f"lnbits.extensions.{ext.code}").db
except ImportError:
raise ImportError(
f"Please make sure that the extension `{ext.code}` has a migrations file."
)
async with ext_db.connect() as ext_conn:
await run_migration(ext_conn, ext_migrations, current_version)
async def run_migration(db, migrations_module, current_version):
matcher = re.compile(r"^m(\d\d\d)_")
db_name = migrations_module.__name__.split(".")[-2]
for key, migrate in migrations_module.__dict__.items():
match = match = matcher.match(key)
if match:
version = int(match.group(1))
if version > current_version:
logger.debug(f"running migration {db_name}.{version}")
print(f"running migration {db_name}.{version}")
await migrate(db)
if db.schema == None:
await update_migration_version(db, db_name, version)
else:
async with core_db.connect() as conn:
await update_migration_version(conn, db_name, version)
async def get_installable_extensions() -> List[InstallableExtension]:
extension_list: List[InstallableExtension] = []
async with httpx.AsyncClient() as client:
for url in settings.lnbits_extensions_manifests:
resp = await client.get(url)
if resp.status_code != 200:
raise HTTPException(
status_code=404,
detail=f"Unable to fetch extension list for repository: {url}",
)
for e in resp.json()["extensions"]:
extension_list += [
InstallableExtension(
id=e["id"],
name=e["name"],
archive=e["archive"],
hash=e["hash"],
short_description=e["shortDescription"],
details=e["details"] if "details" in e else "",
icon=e["icon"],
dependencies=e["dependencies"] if "dependencies" in e else [],
)
]
return extension_list
def download_url(url, save_path):
with urllib.request.urlopen(url) as dl_file:
with open(save_path, "wb") as out_file:
out_file.write(dl_file.read())
def file_hash(filename):
h = hashlib.sha256()
b = bytearray(128 * 1024)
mv = memoryview(b)
with open(filename, "rb", buffering=0) as f:
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest()

View File

@ -4,7 +4,7 @@ import hmac
import json
import time
from sqlite3 import Row
from typing import Dict, List, Optional
from typing import Callable, Dict, List, Optional
from ecdsa import SECP256k1, SigningKey
from fastapi import Query
@ -213,3 +213,7 @@ class BalanceCheck(BaseModel):
@classmethod
def from_row(cls, row: Row):
return cls(wallet=row["wallet"], service=row["service"], url=row["url"])
class CoreAppExtra:
register_new_ext_routes: Callable

View File

@ -4,6 +4,15 @@
{% endblock %} {% block page %}
<div class="row q-col-gutter-md q-mb-md">
<div class="col-sm-3 col-xs-8 q-ml-auto">
<div class="col-sm-7 col-xs-6 mt-lg">
<q-btn
v-if="g.user.admin"
type="a"
:href="['/install?usr=', g.user.id].join('')"
color="primary unelevated mt-lg pt-lg"
>Add or Remove Extensions</q-btn
>
</div>
<q-input v-model="searchTerm" label="Search extensions">
<q-icon
v-if="searchTerm !== ''"

View File

@ -0,0 +1,286 @@
{% extends "base.html" %} {% from "macros.jinja" import window_vars with context
%} {% block page %}
<div class="row q-col-gutter-md q-mb-md">
<div class="col-sm-1 col-xs-4 mt-lg">
<q-btn
type="a"
:href="['/extensions?usr=', g.user.id].join('')"
color="primary unelevated mt-lg pt-lg"
>Back</q-btn
>
</div>
<div class="col-sm-6 col-xs-4 mt-lg">
<q-toggle
label="Installed Only"
color="secodary"
class="float-left"
v-model="showOnlyInstalledExtensions"
></q-toggle>
</div>
<div class="col-sm-3 col-xs-4 q-ml-auto">
<q-input v-model="searchTerm" label="Search extensions">
<q-icon
v-if="searchTerm !== ''"
name="close"
@click="searchTerm = ''"
class="cursor-pointer q-mt-lg"
/>
</q-input>
</div>
</div>
<div class="row q-col-gutter-md">
<div
class="col-xs-12 col-md-6 col-lg-4"
v-for="extension in filteredExtensions"
:key="extension.id"
>
<q-card>
<q-card-section>
<div class="row">
<div class="col-3">
<!-- hack must find better solution -->
<q-icon
:name="extension.icon"
color="grey-5"
style="font-size: 4rem"
></q-icon>
</div>
<div class="col-9">
<q-badge
v-if="extension.details"
@click="showExtensionDetails(extension)"
color="secondary"
class="cursor-pointer float-right"
>
<small>More</small>
</q-badge>
</div>
</div>
{% raw %}
<h5 class="q-mt-lg q-mb-xs">{{ extension.name}}</h5>
<small>{{ extension.shortDescription }} </small>
<div>
<small v-if="extension.dependencies?.length">Depends on:</small>
<small v-else>&nbsp;</small>
<q-badge
v-for="dep in extension.dependencies"
:key="dep"
color="orange"
>
<small v-text="dep"></small>
</q-badge>
</div>
{% endraw %}
</q-card-section>
<q-separator></q-separator>
<q-card-actions>
<div class="col-6">
<div v-if="extension.isInstalled">
<q-btn @click="showUninstall(extension)" flat color="grey-5">
Uninstall</q-btn
>
<q-toggle
:label="extension.isActive ? 'Activated': 'Deactivated' "
color="secodary"
v-model="extension.isActive"
@input="toggleExtension(extension)"
></q-toggle>
</div>
<div v-else>
<q-spinner
v-if="extension.inProgress "
color="primary"
size="2.55em"
></q-spinner>
<q-btn
v-else
@click="installExtension(extension)"
flat
color="primary"
>
Install</q-btn
>
</div>
</div>
<div class="col-6">
<q-rating
max="5"
v-model="maxStars"
size="3.5em"
color="yellow"
icon="star_border"
icon-selected="star"
icon-half="star_half"
readonly
no-dimming
class="float-right"
>
<template v-slot:tip-1>
<q-tooltip>User Review Comming Soon</q-tooltip>
</template>
<template v-slot:tip-2>
<q-tooltip>User Review Comming Soon</q-tooltip>
</template>
<template v-slot:tip-3>
<q-tooltip>User Review Comming Soon</q-tooltip>
</template>
<template v-slot:tip-4>
<q-tooltip>User Review Comming Soon</q-tooltip>
</template>
<template v-slot:tip-5>
<q-tooltip>User Review Comming Soon</q-tooltip>
</template>
</q-rating>
</div>
</q-card-actions>
</q-card>
</div>
<q-dialog v-model="showUninstallDialog">
<q-card class="q-pa-lg">
<h6 class="q-my-md text-primary">Warning</h6>
<p>
You are about to remove the extension for all users. <br />
Are you sure you want to continue?
</p>
<div class="row q-mt-lg">
<q-btn outline color="grey" @click="uninstallExtension()"
>Yes, Uninstall</q-btn
>
<q-btn v-close-popup flat color="grey" class="q-ml-auto">Cancel</q-btn>
</div>
</q-card>
</q-dialog>
{%raw%}
<q-dialog v-model="showDetailsDialog">
<q-card v-if="selectedExtension" class="q-pa-lg">
<h6 class="q-my-md text-primary">{{selectedExtension.name}}</h6>
<div
v-if="selectedExtension.details"
v-html="selectedExtension.details"
></div>
<div class="row q-mt-lg">
<q-btn v-close-popup flat color="grey" class="q-ml-auto">Done</q-btn>
</div>
</q-card>
</q-dialog>
{%endraw%}
</div>
{% endblock %} {% block scripts %} {{ window_vars(user) }}
<script>
new Vue({
el: '#vue',
data: function () {
return {
searchTerm: '',
showOnlyInstalledExtensions: false,
filteredExtensions: null,
showUninstallDialog: false,
showDetailsDialog: false,
selectedExtension: null,
maxStars: 0
}
},
watch: {
searchTerm(term) {
this.filterExtensions(term, this.onlyInstalled)
},
showOnlyInstalledExtensions(onlyInstalled) {
this.filterExtensions(this.searchTerm, onlyInstalled)
}
},
methods: {
filterExtensions: function (term, onlyInstalled) {
// Filter the extensions list
function extensionNameContains(searchTerm) {
return function (extension) {
return (
extension.name.toLowerCase().includes(searchTerm.toLowerCase()) ||
extension.shortDescription
?.toLowerCase()
.includes(searchTerm.toLowerCase())
)
}
}
this.filteredExtensions = this.extensions
.filter(e =>
this.showOnlyInstalledExtensions ? e.isInstalled : true
)
.filter(extensionNameContains(term))
},
installExtension: async function (extension) {
try {
extension.inProgress = true
await LNbits.api.request(
'POST',
`/api/v1/extension/${extension.id}/${extension.hash}?usr=${this.g.user.id}`,
this.g.user.wallets[0].adminkey
)
window.location.href = [
"{{ url_for('install.extensions') }}",
'?usr=',
this.g.user.id
].join('')
} catch (error) {
LNbits.utils.notifyApiError(error)
extension.inProgress = false
}
},
uninstallExtension: async function () {
extension = this.selectedExtension
try {
extension.inProgress = true
await LNbits.api.request(
'DELETE',
`/api/v1/extension/${extension.id}?usr=${this.g.user.id}`,
this.g.user.wallets[0].adminkey
)
window.location.href = [
"{{ url_for('install.extensions') }}",
'?usr=',
this.g.user.id
].join('')
} catch (error) {
LNbits.utils.notifyApiError(error)
extension.inProgress = false
}
},
toggleExtension: function (extension) {
const action = extension.isActive ? 'activate' : 'deactivate'
window.location.href = [
"{{ url_for('install.extensions') }}",
'?usr=',
this.g.user.id,
`&${action}=`,
extension.id
].join('')
},
showUninstall: function (extension) {
this.selectedExtension = extension
this.showUninstallDialog = true
},
showExtensionDetails: function (extension) {
this.selectedExtension = extension
this.showDetailsDialog = true
}
},
created: function () {
this.extensions = JSON.parse('{{extensions | tojson | safe}}').map(e => ({
...e,
inProgress: false
}))
this.filteredExtensions = this.extensions.concat([])
},
mixins: [windowMixin]
})
</script>
{% endblock %}

View File

@ -1,11 +1,18 @@
import asyncio
import hashlib
import importlib
import inspect
import json
import os
import shutil
import sys
import time
import uuid
import zipfile
from http import HTTPStatus
from io import BytesIO
from typing import Dict, Optional, Tuple, Union
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from urllib.parse import ParseResult, parse_qs, urlencode, urlparse, urlunparse
import async_timeout
@ -22,6 +29,8 @@ from fastapi import (
WebSocketDisconnect,
)
from fastapi.exceptions import HTTPException
from fastapi.params import Body
from genericpath import isfile
from loguru import logger
from pydantic import BaseModel
from pydantic.fields import Field
@ -29,15 +38,27 @@ from sse_starlette.sse import EventSourceResponse
from starlette.responses import StreamingResponse
from lnbits import bolt11, lnurl
from lnbits.core.models import Payment, Wallet
from lnbits.core.helpers import (
download_url,
file_hash,
get_installable_extensions,
migrate_extension_database,
)
from lnbits.core.models import Payment, User, Wallet
from lnbits.decorators import (
WalletTypeInfo,
check_admin,
check_user_exists,
get_key_type,
require_admin_key,
require_invoice_key,
)
from lnbits.helpers import url_for
from lnbits.helpers import (
Extension,
InstallableExtension,
get_valid_extensions,
url_for,
)
from lnbits.settings import get_wallet_class, settings
from lnbits.utils.exchange_rates import (
currencies,
@ -45,13 +66,16 @@ from lnbits.utils.exchange_rates import (
satoshis_amount_as_fiat,
)
from .. import core_app, db
from .. import core_app, core_app_extra, db
from ..crud import (
USER_ID_ALL,
get_dbversions,
get_payments,
get_standalone_payment,
get_total_balance,
get_wallet_for_key,
save_balance_check,
update_user_extension,
update_wallet,
)
from ..services import (
@ -706,3 +730,185 @@ async def websocket_update_get(item_id: str, data: str):
return {"sent": True, "data": data}
except:
return {"sent": False, "data": data}
@core_app.post("/api/v1/extension/{ext_id}/{hash}")
async def api_install_extension(
ext_id: str, hash: str, user: User = Depends(check_user_exists)
):
if not user.admin:
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED, detail="Only for admin users"
)
try:
extension_list: List[InstallableExtension] = await get_installable_extensions()
except Exception as ex:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Cannot fetch installable extension list",
)
extensions = [e for e in extension_list if e.id == ext_id and e.hash == hash]
if len(extensions) == 0:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"Unknown extension id: {ext_id}",
)
extension = extensions[0]
# check that all dependecies are installed
installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True)))
if not set(extension.dependencies).issubset(installed_extensions):
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail=f"Not all dependencies are installed: {extension.dependencies}",
)
# move files to the right location
extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions")
os.makedirs(extensions_data_dir, exist_ok=True)
ext_data_dir = os.path.join(extensions_data_dir, ext_id)
shutil.rmtree(ext_data_dir, True)
ext_zip_file = os.path.join(extensions_data_dir, f"{ext_id}.zip")
if os.path.isfile(ext_zip_file):
os.remove(ext_zip_file)
try:
download_url(extension.archive, ext_zip_file)
except Exception as ex:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Cannot fetch extension archive file",
)
archive_hash = file_hash(ext_zip_file)
if extension.hash != archive_hash:
# remove downloaded archive
if os.path.isfile(ext_zip_file):
os.remove(ext_zip_file)
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="File hash missmatch. Will not install.",
)
try:
ext_dir = os.path.join("lnbits/extensions", ext_id)
shutil.rmtree(ext_dir, True)
with zipfile.ZipFile(ext_zip_file, "r") as zip_ref:
zip_ref.extractall("lnbits/extensions")
# todo: is admin only
ext = Extension(extension.id, True, extension.is_admin_only, extension.name)
current_versions = await get_dbversions()
current_version = current_versions.get(ext.code, 0)
module_name = f"lnbits.extensions.{ext.code}"
# if module_name in sys.modules:
# importlib.reload(sys.modules[module_name])
ext_module = importlib.import_module(module_name)
# sys.modules[module_name] = importlib.reload(ext_module)
modules_to_reload = list_modules_for_extension(ext_id)
print("### modules_to_reload", modules_to_reload)
for m in modules_to_reload:
importlib.reload(sys.modules[m])
await migrate_extension_database(ext, current_version)
# disable by default
await update_user_extension(user_id=USER_ID_ALL, extension=ext_id, active=False)
settings.lnbits_disabled_extensions += [ext_id]
# mount routes at the very end
core_app_extra.register_new_ext_routes(ext)
except Exception as ex:
logger.warning(ex)
# remove downloaded archive
if os.path.isfile(ext_zip_file):
os.remove(ext_zip_file)
# remove module from extensions
shutil.rmtree(ext_dir, True)
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(ex)
)
@core_app.delete("/api/v1/extension/{ext_id}")
async def api_uninstall_extension(ext_id: str, user: User = Depends(check_user_exists)):
if not user.admin:
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED, detail="Only for admin users"
)
try:
extension_list: List[InstallableExtension] = await get_installable_extensions()
except Exception as ex:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Cannot fetch installable extension list",
)
extensions = [e for e in extension_list if e.id == ext_id]
if len(extensions) == 0:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"Unknown extension id: {ext_id}",
)
# check that other extensions do not depend on this one
for active_ext_id in list(map(lambda e: e.code, get_valid_extensions(True))):
active_ext = next(
(ext for ext in extension_list if ext.id == active_ext_id), None
)
if active_ext and ext_id in active_ext.dependencies:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"Cannot uninstall. Extension '{active_ext.name}' depends on this one.",
)
try:
settings.lnbits_disabled_extensions += [ext_id]
# remove downloaded archive
ext_zip_file = os.path.join(
settings.lnbits_data_folder, "extensions", f"{ext_id}.zip"
)
if os.path.isfile(ext_zip_file):
os.remove(ext_zip_file)
# module_name = f"lnbits.extensions.{ext_id}"
# modules_to_delete = list_modules_for_extension(ext_id)
# print('### modules_to_delete', modules_to_delete)
# for m in modules_to_delete:
# module = sys.modules[m]
# del sys.modules[m]
# del module
# remove module from extensions
ext_dir = os.path.join("lnbits/extensions", ext_id)
shutil.rmtree(ext_dir, True)
except Exception as ex:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(ex)
)
def list_modules_for_extension(ext_id: str) -> List[str]:
modules_for_extension = []
for key in sys.modules.keys():
try:
module = sys.modules[key]
moduleFilePath = inspect.getfile(module).lower()
dir_name = str(Path(moduleFilePath).parent.absolute())
if dir_name.endswith(f"lnbits/extensions/{ext_id}"):
print("## moduleFilePath", moduleFilePath)
modules_for_extension += [key]
except:
pass # built in modules throw if queried
return modules_for_extension

View File

@ -1,6 +1,6 @@
import asyncio
from http import HTTPStatus
from typing import Optional
from typing import List, Optional
from fastapi import Depends, Query, Request, status
from fastapi.exceptions import HTTPException
@ -11,17 +11,20 @@ from pydantic.types import UUID4
from starlette.responses import HTMLResponse, JSONResponse
from lnbits.core import db
from lnbits.core.helpers import get_installable_extensions
from lnbits.core.models import User
from lnbits.decorators import check_admin, check_user_exists
from lnbits.helpers import template_renderer, url_for
from lnbits.settings import get_wallet_class, settings
from ...helpers import get_valid_extensions
from ...helpers import InstallableExtension, get_valid_extensions
from ..crud import (
USER_ID_ALL,
create_account,
create_wallet,
delete_wallet,
get_balance_check,
get_inactive_extensions,
get_user,
save_balance_notify,
update_user_extension,
@ -52,35 +55,10 @@ async def extensions(
enable: str = Query(None),
disable: str = Query(None),
):
extension_to_enable = enable
extension_to_disable = disable
if extension_to_enable and extension_to_disable:
raise HTTPException(
HTTPStatus.BAD_REQUEST, "You can either `enable` or `disable` an extension."
)
# check if extension exists
if extension_to_enable or extension_to_disable:
ext = extension_to_enable or extension_to_disable
if ext not in [e.code for e in get_valid_extensions()]:
raise HTTPException(
HTTPStatus.BAD_REQUEST, f"Extension '{ext}' doesn't exist."
)
if extension_to_enable:
logger.info(f"Enabling extension: {extension_to_enable} for user {user.id}")
await update_user_extension(
user_id=user.id, extension=extension_to_enable, active=True
)
elif extension_to_disable:
logger.info(f"Disabling extension: {extension_to_disable} for user {user.id}")
await update_user_extension(
user_id=user.id, extension=extension_to_disable, active=False
)
await toggle_extension(enable, disable, user.id)
# Update user as his extensions have been updated
if extension_to_enable or extension_to_disable:
if enable or disable:
user = await get_user(user.id) # type: ignore
return template_renderer().TemplateResponse(
@ -88,6 +66,70 @@ async def extensions(
)
@core_html_routes.get(
"/install", name="install.extensions", response_class=HTMLResponse
)
async def extensions_install(
request: Request,
user: User = Depends(check_user_exists), # type: ignore
activate: str = Query(None), # type: ignore
deactivate: str = Query(None), # type: ignore
):
if not user.admin:
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED, detail="Only for admin users"
)
try:
extension_list: List[InstallableExtension] = await get_installable_extensions()
except Exception as ex:
logger.warning(ex)
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Cannot fetch installable extension list",
)
try:
if deactivate:
settings.lnbits_disabled_extensions += [deactivate]
elif activate:
settings.lnbits_disabled_extensions = list(
filter(lambda e: e != activate, settings.lnbits_disabled_extensions)
)
await toggle_extension(activate, deactivate, USER_ID_ALL)
installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True)))
inactive_extensions = await get_inactive_extensions(user_id=USER_ID_ALL)
extensions = list(
map(
lambda ext: {
"id": ext.id,
"name": ext.name,
"hash": ext.hash,
"icon": ext.icon,
"shortDescription": ext.short_description,
"details": ext.details,
"dependencies": ext.dependencies,
"isInstalled": ext.id in installed_extensions,
"isActive": not ext.id in inactive_extensions,
},
extension_list,
)
)
return template_renderer().TemplateResponse(
"core/install.html",
{
"request": request,
"user": user.dict(),
"extensions": extensions,
},
)
except Exception as e:
logger.warning(e)
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(e))
@core_html_routes.get(
"/wallet",
response_class=HTMLResponse,
@ -327,3 +369,29 @@ async def index(request: Request, user: User = Depends(check_admin)):
"balance": balance,
},
)
async def toggle_extension(extension_to_enable, extension_to_disable, user_id):
if extension_to_enable and extension_to_disable:
raise HTTPException(
HTTPStatus.BAD_REQUEST, "You can either `enable` or `disable` an extension."
)
# check if extension exists
if extension_to_enable or extension_to_disable:
ext = extension_to_enable or extension_to_disable
if ext not in [e.code for e in get_valid_extensions(True)]:
raise HTTPException(
HTTPStatus.BAD_REQUEST, f"Extension '{ext}' doesn't exist."
)
if extension_to_enable:
logger.info(f"Enabling extension: {extension_to_enable} for user {user_id}")
await update_user_extension(
user_id=user_id, extension=extension_to_enable, active=True
)
elif extension_to_disable:
logger.info(f"Disabling extension: {extension_to_disable} for user {user_id}")
await update_user_extension(
user_id=user_id, extension=extension_to_disable, active=False
)

View File

@ -1,10 +1,13 @@
import glob
import json
import os
from http import HTTPStatus
from typing import Any, List, NamedTuple, Optional
import jinja2
import shortuuid
import shortuuid # type: ignore
from fastapi.responses import JSONResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from lnbits.jinja2_templating import Jinja2Templates
from lnbits.requestvars import g
@ -25,8 +28,10 @@ class Extension(NamedTuple):
class ExtensionManager:
def __init__(self):
self._disabled: List[str] = settings.lnbits_disabled_extensions
def __init__(self, include_disabled_exts=False):
self._disabled: List[str] = (
[] if include_disabled_exts else settings.lnbits_disabled_extensions
)
self._admin_only: List[str] = settings.lnbits_admin_extensions
self._extension_folders: List[str] = [
x[1] for x in os.walk(os.path.join(settings.lnbits_path, "extensions"))
@ -74,9 +79,40 @@ class ExtensionManager:
return output
def get_valid_extensions() -> List[Extension]:
class EnabledExtensionMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
pathname = scope["path"].split("/")[1]
if pathname in settings.lnbits_disabled_extensions:
response = JSONResponse(
status_code=HTTPStatus.NOT_FOUND,
content={"detail": f"Extension '{pathname}' disabled"},
)
await response(scope, receive, send)
return
await self.app(scope, receive, send)
class InstallableExtension(NamedTuple):
id: str
name: str
archive: str
hash: str
short_description: Optional[str] = None
details: Optional[str] = None
icon: Optional[str] = None
dependencies: List[str] = []
is_admin_only: bool = False
def get_valid_extensions(include_disabled_exts=False) -> List[Extension]:
return [
extension for extension in ExtensionManager().extensions if extension.is_valid
extension
for extension in ExtensionManager(include_disabled_exts).extensions
if extension.is_valid
]

View File

@ -40,6 +40,7 @@ class UsersSettings(LNbitsSettings):
lnbits_allowed_users: List[str] = Field(default=[])
lnbits_admin_extensions: List[str] = Field(default=[])
lnbits_disabled_extensions: List[str] = Field(default=[])
lnbits_extensions_manifests: List[str] = Field(default=[])
class ThemesSettings(LNbitsSettings):

View File

@ -141,7 +141,8 @@ window.LNbits = {
admin: data.admin,
email: data.email,
extensions: data.extensions,
wallets: data.wallets
wallets: data.wallets,
admin: data.admin
}
var mapWallet = this.wallet
obj.wallets = obj.wallets