Merge pull request #1509 from lnbits/stop_extension_background_work

Extension Upgrade - stop background work
This commit is contained in:
Arc 2023-02-16 10:01:15 +00:00 committed by GitHub
commit 8637e7fd22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 53 additions and 8 deletions

View File

@ -2,10 +2,12 @@ import importlib
import re import re
from typing import Any from typing import Any
import httpx
from loguru import logger from loguru import logger
from lnbits.db import Connection from lnbits.db import Connection
from lnbits.extension_manager import Extension from lnbits.extension_manager import Extension
from lnbits.settings import settings
from . import db as core_db from . import db as core_db
from .crud import update_migration_version from .crud import update_migration_version
@ -42,3 +44,22 @@ async def run_migration(db: Connection, migrations_module: Any, current_version:
else: else:
async with core_db.connect() as conn: async with core_db.connect() as conn:
await update_migration_version(conn, db_name, version) await update_migration_version(conn, db_name, version)
async def stop_extension_background_work(ext_id: str, user: str):
"""
Stop background work for extension (like asyncio.Tasks, WebSockets, etc).
Extensions SHOULD expose a DELETE enpoint at the root level of their API.
This function tries first to call the endpoint using `http` and if if fails it tries using `https`.
"""
async with httpx.AsyncClient() as client:
try:
url = f"http://{settings.host}:{settings.port}/{ext_id}/api/v1?usr={user}"
await client.delete(url)
except Exception as ex:
logger.warning(ex)
try:
# try https
url = f"https://{settings.host}:{settings.port}/{ext_id}/api/v1?usr={user}"
except Exception as ex:
logger.warning(ex)

View File

@ -29,7 +29,10 @@ from sse_starlette.sse import EventSourceResponse
from starlette.responses import RedirectResponse, StreamingResponse from starlette.responses import RedirectResponse, StreamingResponse
from lnbits import bolt11, lnurl from lnbits import bolt11, lnurl
from lnbits.core.helpers import migrate_extension_database from lnbits.core.helpers import (
migrate_extension_database,
stop_extension_background_work,
)
from lnbits.core.models import Payment, User, Wallet from lnbits.core.models import Payment, User, Wallet
from lnbits.decorators import ( from lnbits.decorators import (
WalletTypeInfo, WalletTypeInfo,
@ -729,7 +732,6 @@ async def websocket_update_get(item_id: str, data: str):
async def api_install_extension( async def api_install_extension(
data: CreateExtension, user: User = Depends(check_admin) data: CreateExtension, user: User = Depends(check_admin)
): ):
release = await InstallableExtension.get_extension_release( release = await InstallableExtension.get_extension_release(
data.ext_id, data.source_repo, data.archive data.ext_id, data.source_repo, data.archive
) )
@ -752,6 +754,10 @@ async def api_install_extension(
await migrate_extension_database(extension, db_version) await migrate_extension_database(extension, db_version)
await add_installed_extension(ext_info) await add_installed_extension(ext_info)
# call stop while the old routes are still active
await stop_extension_background_work(data.ext_id, user.id)
if data.ext_id not in settings.lnbits_deactivated_extensions: if data.ext_id not in settings.lnbits_deactivated_extensions:
settings.lnbits_deactivated_extensions += [data.ext_id] settings.lnbits_deactivated_extensions += [data.ext_id]
@ -798,6 +804,9 @@ async def api_uninstall_extension(ext_id: str, user: User = Depends(check_admin)
) )
try: try:
# call stop while the old routes are still active
await stop_extension_background_work(ext_id, user.id)
if ext_id not in settings.lnbits_deactivated_extensions: if ext_id not in settings.lnbits_deactivated_extensions:
settings.lnbits_deactivated_extensions += [ext_id] settings.lnbits_deactivated_extensions += [ext_id]

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
from typing import List
from fastapi import APIRouter from fastapi import APIRouter
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
@ -16,6 +17,7 @@ lnurlp_static_files = [
"name": "lnurlp_static", "name": "lnurlp_static",
} }
] ]
scheduled_tasks: List[asyncio.Task] = []
lnurlp_ext: APIRouter = APIRouter(prefix="/lnurlp", tags=["lnurlp"]) lnurlp_ext: APIRouter = APIRouter(prefix="/lnurlp", tags=["lnurlp"])
@ -32,4 +34,5 @@ from .views_api import * # noqa: F401,F403
def lnurlp_start(): def lnurlp_start():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.create_task(catch_everything_and_restart(wait_for_paid_invoices)) task = loop.create_task(catch_everything_and_restart(wait_for_paid_invoices))
scheduled_tasks.append(task)

View File

@ -61,9 +61,9 @@ class PayLink(BaseModel):
def success_action(self, payment_hash: str) -> Optional[Dict]: def success_action(self, payment_hash: str) -> Optional[Dict]:
if self.success_url: if self.success_url:
url: ParseResult = urlparse(self.success_url) url: ParseResult = urlparse(self.success_url)
#qs = parse_qs(url.query) # qs = parse_qs(url.query)
#setattr(qs, "payment_hash", payment_hash) # setattr(qs, "payment_hash", payment_hash)
#url = url._replace(query=urlencode(qs, doseq=True)) # url = url._replace(query=urlencode(qs, doseq=True))
return { return {
"tag": "url", "tag": "url",
"description": self.success_text or "~", "description": self.success_text or "~",

View File

@ -1,4 +1,5 @@
import json import json
from asyncio.log import logger
from http import HTTPStatus from http import HTTPStatus
from fastapi import Depends, Query, Request from fastapi import Depends, Query, Request
@ -6,10 +7,10 @@ from lnurl.exceptions import InvalidUrl as LnurlInvalidUrl
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from lnbits.core.crud import get_user from lnbits.core.crud import get_user
from lnbits.decorators import WalletTypeInfo, get_key_type from lnbits.decorators import WalletTypeInfo, check_admin, get_key_type
from lnbits.utils.exchange_rates import currencies, get_fiat_rate_satoshis from lnbits.utils.exchange_rates import currencies, get_fiat_rate_satoshis
from . import lnurlp_ext from . import lnurlp_ext, scheduled_tasks
from .crud import ( from .crud import (
create_pay_link, create_pay_link,
delete_pay_link, delete_pay_link,
@ -166,3 +167,14 @@ async def api_check_fiat_rate(currency):
rate = None rate = None
return {"rate": rate} return {"rate": rate}
@lnurlp_ext.delete("/api/v1", status_code=HTTPStatus.OK)
async def api_stop(wallet: WalletTypeInfo = Depends(check_admin)):
for t in scheduled_tasks:
try:
t.cancel()
except Exception as ex:
logger.warning(ex)
return {"success": True}