diff --git a/Makefile b/Makefile index f80747a6..2ac497b5 100644 --- a/Makefile +++ b/Makefile @@ -12,6 +12,8 @@ black: $(shell find lnbits -name "*.py") mypy: $(shell find lnbits -name "*.py") ./venv/bin/mypy lnbits + ./venv/bin/mypy lnbits/core + ./venv/bin/mypy lnbits/extensions/* checkprettier: $(shell find lnbits -name "*.js" -name ".html") ./node_modules/.bin/prettier --check lnbits/static/js/*.js lnbits/core/static/js/*.js lnbits/extensions/*/templates/*/*.html ./lnbits/core/templates/core/*.html lnbits/templates/*.html lnbits/extensions/*/static/js/*.js diff --git a/lnbits/__main__.py b/lnbits/__main__.py index 66f8dbc9..89fc6163 100644 --- a/lnbits/__main__.py +++ b/lnbits/__main__.py @@ -1,13 +1,17 @@ -from .app import create_app -from .commands import migrate_databases, transpile_scss, bundle_vendored -from .settings import LNBITS_SITE_TITLE, SERVICE_FEE, DEBUG, LNBITS_DATA_FOLDER, WALLET, LNBITS_COMMIT +import trio # type: ignore -migrate_databases() +from .commands import migrate_databases, transpile_scss, bundle_vendored + +trio.run(migrate_databases) transpile_scss() bundle_vendored() +from .app import create_app + app = create_app() +from .settings import LNBITS_SITE_TITLE, SERVICE_FEE, DEBUG, LNBITS_DATA_FOLDER, WALLET, LNBITS_COMMIT + print( f"""Starting LNbits with - git version: {LNBITS_COMMIT} diff --git a/lnbits/app.py b/lnbits/app.py index 8528b898..b1562f62 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -9,7 +9,6 @@ from secure import SecureHeaders # type: ignore from .commands import db_migrate, handle_assets from .core import core_app -from .db import open_db, open_ext_db from .helpers import get_valid_extensions, get_js_vendored, get_css_vendored, url_for_vendored from .proxy_fix import ASGIProxyFix from .tasks import run_deferred_async, invoice_listener, internal_invoice_listener, webhook_handler, grab_app_for_later @@ -63,13 +62,9 @@ def register_blueprints(app: QuartTrio) -> None: ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}") bp = getattr(ext_module, f"{ext.code}_ext") - @bp.before_request - async def before_request(): - g.ext_db = open_ext_db(ext.code) - @bp.teardown_request async def after_request(exc): - g.ext_db.close() + await ext_module.db.close_session() app.register_blueprint(bp, url_prefix=f"/{ext.code}") except Exception: @@ -106,18 +101,19 @@ def register_request_hooks(app: QuartTrio): @app.before_request async def before_request(): - g.db = open_db() g.nursery = app.nursery + @app.teardown_request + async def after_request(exc): + from lnbits.core import db + + await db.close_session() + @app.after_request async def set_secure_headers(response): secure_headers.quart(response) return response - @app.teardown_request - async def after_request(exc): - g.db.close() - def register_async_tasks(app): @app.route("/wallet/webhook", methods=["GET", "POST", "PUT", "PATCH", "DELETE"]) diff --git a/lnbits/commands.py b/lnbits/commands.py index a3dde9c2..58899bb9 100644 --- a/lnbits/commands.py +++ b/lnbits/commands.py @@ -1,19 +1,19 @@ +import trio # type: ignore import warnings import click import importlib import re import os -import sqlite3 +from sqlalchemy.exc import OperationalError # type: ignore -from .core import migrations as core_migrations -from .db import open_db, open_ext_db +from .core import db as core_db, migrations as core_migrations from .helpers import get_valid_extensions, get_css_vendored, get_js_vendored, url_for_vendored from .settings import LNBITS_PATH @click.command("migrate") def db_migrate(): - migrate_databases() + trio.run(migrate_databases) @click.command("assets") @@ -45,39 +45,44 @@ def bundle_vendored(): f.write(output) -def migrate_databases(): +async def migrate_databases(): """Creates the necessary databases if they don't exist already; or migrates them.""" - with open_db() as core_db: + core_conn = await core_db.connect() + core_txn = await core_conn.begin() + + try: + rows = await (await core_conn.execute("SELECT * FROM dbversions")).fetchall() + except OperationalError: + # migration 3 wasn't ran + core_migrations.m000_create_migrations_table(core_conn) + rows = await (await core_conn.execute("SELECT * FROM dbversions")).fetchall() + + current_versions = {row["db"]: row["version"] for row in rows} + matcher = re.compile(r"^m(\d\d\d)_") + + async def run_migration(db, migrations_module): + db_name = migrations_module.__name__.split(".")[-2] + for key, migrate in migrations_module.__dict__.items(): + match = match = matcher.match(key) + if match: + version = int(match.group(1)) + if version > current_versions.get(db_name, 0): + print(f"running migration {db_name}.{version}") + await migrate(db) + await core_conn.execute( + "INSERT OR REPLACE INTO dbversions (db, version) VALUES (?, ?)", (db_name, version) + ) + + await run_migration(core_conn, core_migrations) + + for ext in get_valid_extensions(): try: - rows = core_db.fetchall("SELECT * FROM dbversions") - except sqlite3.OperationalError: - # migration 3 wasn't ran - core_migrations.m000_create_migrations_table(core_db) - rows = core_db.fetchall("SELECT * FROM dbversions") + ext_migrations = importlib.import_module(f"lnbits.extensions.{ext.code}.migrations") + ext_db = importlib.import_module(f"lnbits.extensions.{ext.code}").db + await run_migration(ext_db, ext_migrations) + except ImportError: + raise ImportError(f"Please make sure that the extension `{ext.code}` has a migrations file.") - current_versions = {row["db"]: row["version"] for row in rows} - matcher = re.compile(r"^m(\d\d\d)_") - - def run_migration(db, migrations_module): - db_name = migrations_module.__name__.split(".")[-2] - for key, run_migration in migrations_module.__dict__.items(): - match = match = matcher.match(key) - if match: - version = int(match.group(1)) - if version > current_versions.get(db_name, 0): - print(f"running migration {db_name}.{version}") - run_migration(db) - core_db.execute( - "INSERT OR REPLACE INTO dbversions (db, version) VALUES (?, ?)", (db_name, version) - ) - - run_migration(core_db, core_migrations) - - for ext in get_valid_extensions(): - try: - ext_migrations = importlib.import_module(f"lnbits.extensions.{ext.code}.migrations") - with open_ext_db(ext.code) as db: - run_migration(db, ext_migrations) - except ImportError: - raise ImportError(f"Please make sure that the extension `{ext.code}` has a migrations file.") + await core_txn.commit() + await core_conn.close() diff --git a/lnbits/core/__init__.py b/lnbits/core/__init__.py index e863d0af..a2ea1ddf 100644 --- a/lnbits/core/__init__.py +++ b/lnbits/core/__init__.py @@ -1,5 +1,7 @@ from quart import Blueprint +from lnbits.db import Database +db = Database("database") core_app: Blueprint = Blueprint( "core", __name__, template_folder="templates", static_folder="static", static_url_path="/core/static" diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py index 984492a5..c9c3b107 100644 --- a/lnbits/core/crud.py +++ b/lnbits/core/crud.py @@ -2,11 +2,11 @@ import json import datetime from uuid import uuid4 from typing import List, Optional, Dict -from quart import g from lnbits import bolt11 from lnbits.settings import DEFAULT_WALLET_NAME +from . import db from .models import User, Wallet, Payment @@ -14,28 +14,28 @@ from .models import User, Wallet, Payment # -------- -def create_account() -> User: +async def create_account() -> User: user_id = uuid4().hex - g.db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,)) + await db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,)) - new_account = get_account(user_id=user_id) + new_account = await get_account(user_id=user_id) assert new_account, "Newly created account couldn't be retrieved" return new_account -def get_account(user_id: str) -> Optional[User]: - row = g.db.fetchone("SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,)) +async def get_account(user_id: str) -> Optional[User]: + row = await db.fetchone("SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,)) return User(**row) if row else None -def get_user(user_id: str) -> Optional[User]: - user = g.db.fetchone("SELECT id, email FROM accounts WHERE id = ?", (user_id,)) +async def get_user(user_id: str) -> Optional[User]: + user = await db.fetchone("SELECT id, email FROM accounts WHERE id = ?", (user_id,)) if user: - extensions = g.db.fetchall("SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,)) - wallets = g.db.fetchall( + extensions = await db.fetchall("SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,)) + wallets = await db.fetchall( """ SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat FROM wallets @@ -51,8 +51,8 @@ def get_user(user_id: str) -> Optional[User]: ) -def update_user_extension(*, user_id: str, extension: str, active: int) -> None: - g.db.execute( +async def update_user_extension(*, user_id: str, extension: str, active: int) -> None: + await db.execute( """ INSERT OR REPLACE INTO extensions (user, extension, active) VALUES (?, ?, ?) @@ -65,9 +65,9 @@ def update_user_extension(*, user_id: str, extension: str, active: int) -> None: # ------- -def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet: +async def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet: wallet_id = uuid4().hex - g.db.execute( + await db.execute( """ INSERT INTO wallets (id, name, user, adminkey, inkey) VALUES (?, ?, ?, ?, ?) @@ -75,14 +75,14 @@ def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet: (wallet_id, wallet_name or DEFAULT_WALLET_NAME, user_id, uuid4().hex, uuid4().hex), ) - new_wallet = get_wallet(wallet_id=wallet_id) + new_wallet = await get_wallet(wallet_id=wallet_id) assert new_wallet, "Newly created wallet couldn't be retrieved" return new_wallet -def delete_wallet(*, user_id: str, wallet_id: str) -> None: - g.db.execute( +async def delete_wallet(*, user_id: str, wallet_id: str) -> None: + await db.execute( """ UPDATE wallets AS w SET @@ -95,8 +95,8 @@ def delete_wallet(*, user_id: str, wallet_id: str) -> None: ) -def get_wallet(wallet_id: str) -> Optional[Wallet]: - row = g.db.fetchone( +async def get_wallet(wallet_id: str) -> Optional[Wallet]: + row = await db.fetchone( """ SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat FROM wallets @@ -108,8 +108,8 @@ def get_wallet(wallet_id: str) -> Optional[Wallet]: return Wallet(**row) if row else None -def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]: - row = g.db.fetchone( +async def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]: + row = await db.fetchone( """ SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat FROM wallets @@ -131,8 +131,8 @@ def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]: # --------------- -def get_standalone_payment(checking_id: str) -> Optional[Payment]: - row = g.db.fetchone( +async def get_standalone_payment(checking_id: str) -> Optional[Payment]: + row = await db.fetchone( """ SELECT * FROM apipayments @@ -144,8 +144,8 @@ def get_standalone_payment(checking_id: str) -> Optional[Payment]: return Payment.from_row(row) if row else None -def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]: - row = g.db.fetchone( +async def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]: + row = await db.fetchone( """ SELECT * FROM apipayments @@ -157,7 +157,7 @@ def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]: return Payment.from_row(row) if row else None -def get_wallet_payments( +async def get_wallet_payments( wallet_id: str, *, complete: bool = False, @@ -197,7 +197,7 @@ def get_wallet_payments( clause += "AND checking_id NOT LIKE 'temp_%' " clause += "AND checking_id NOT LIKE 'internal_%' " - rows = g.db.fetchall( + rows = await db.fetchall( f""" SELECT * FROM apipayments @@ -210,8 +210,8 @@ def get_wallet_payments( return [Payment.from_row(row) for row in rows] -def delete_expired_invoices() -> None: - rows = g.db.fetchall( +async def delete_expired_invoices() -> None: + rows = await db.fetchall( """ SELECT bolt11 FROM apipayments @@ -228,7 +228,7 @@ def delete_expired_invoices() -> None: if expiration_date > datetime.datetime.utcnow(): continue - g.db.execute( + await db.execute( """ DELETE FROM apipayments WHERE pending = 1 AND hash = ? @@ -241,7 +241,7 @@ def delete_expired_invoices() -> None: # -------- -def create_payment( +async def create_payment( *, wallet_id: str, checking_id: str, @@ -254,7 +254,7 @@ def create_payment( pending: bool = True, extra: Optional[Dict] = None, ) -> Payment: - g.db.execute( + await db.execute( """ INSERT INTO apipayments (wallet, checking_id, bolt11, hash, preimage, @@ -275,14 +275,14 @@ def create_payment( ), ) - new_payment = get_wallet_payment(wallet_id, payment_hash) + new_payment = await get_wallet_payment(wallet_id, payment_hash) assert new_payment, "Newly created payment couldn't be retrieved" return new_payment -def update_payment_status(checking_id: str, pending: bool) -> None: - g.db.execute( +async def update_payment_status(checking_id: str, pending: bool) -> None: + await db.execute( "UPDATE apipayments SET pending = ? WHERE checking_id = ?", ( int(pending), @@ -291,12 +291,12 @@ def update_payment_status(checking_id: str, pending: bool) -> None: ) -def delete_payment(checking_id: str) -> None: - g.db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,)) +async def delete_payment(checking_id: str) -> None: + await db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,)) -def check_internal(payment_hash: str) -> Optional[str]: - row = g.db.fetchone( +async def check_internal(payment_hash: str) -> Optional[str]: + row = await db.fetchone( """ SELECT checking_id FROM apipayments WHERE hash = ? AND pending AND amount > 0 diff --git a/lnbits/core/migrations.py b/lnbits/core/migrations.py index 828d9d42..8a37b652 100644 --- a/lnbits/core/migrations.py +++ b/lnbits/core/migrations.py @@ -1,8 +1,8 @@ -import sqlite3 +from sqlalchemy.exc import OperationalError # type: ignore -def m000_create_migrations_table(db): - db.execute( +async def m000_create_migrations_table(db): + await db.execute( """ CREATE TABLE dbversions ( db TEXT PRIMARY KEY, @@ -12,11 +12,11 @@ def m000_create_migrations_table(db): ) -def m001_initial(db): +async def m001_initial(db): """ Initial LNbits tables. """ - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS accounts ( id TEXT PRIMARY KEY, @@ -25,7 +25,7 @@ def m001_initial(db): ); """ ) - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS extensions ( user TEXT NOT NULL, @@ -36,7 +36,7 @@ def m001_initial(db): ); """ ) - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS wallets ( id TEXT PRIMARY KEY, @@ -47,7 +47,7 @@ def m001_initial(db): ); """ ) - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS apipayments ( payhash TEXT NOT NULL, @@ -63,7 +63,7 @@ def m001_initial(db): """ ) - db.execute( + await db.execute( """ CREATE VIEW IF NOT EXISTS balances AS SELECT wallet, COALESCE(SUM(s), 0) AS balance FROM ( @@ -82,22 +82,22 @@ def m001_initial(db): ) -def m002_add_fields_to_apipayments(db): +async def m002_add_fields_to_apipayments(db): """ Adding fields to apipayments for better accounting, and renaming payhash to checking_id since that is what it really is. """ try: - db.execute("ALTER TABLE apipayments RENAME COLUMN payhash TO checking_id") - db.execute("ALTER TABLE apipayments ADD COLUMN hash TEXT") - db.execute("CREATE INDEX by_hash ON apipayments (hash)") - db.execute("ALTER TABLE apipayments ADD COLUMN preimage TEXT") - db.execute("ALTER TABLE apipayments ADD COLUMN bolt11 TEXT") - db.execute("ALTER TABLE apipayments ADD COLUMN extra TEXT") + await db.execute("ALTER TABLE apipayments RENAME COLUMN payhash TO checking_id") + await db.execute("ALTER TABLE apipayments ADD COLUMN hash TEXT") + await db.execute("CREATE INDEX by_hash ON apipayments (hash)") + await db.execute("ALTER TABLE apipayments ADD COLUMN preimage TEXT") + await db.execute("ALTER TABLE apipayments ADD COLUMN bolt11 TEXT") + await db.execute("ALTER TABLE apipayments ADD COLUMN extra TEXT") import json - rows = db.fetchall("SELECT * FROM apipayments") + rows = await (await db.execute("SELECT * FROM apipayments")).fetchall() for row in rows: if not row["memo"] or not row["memo"].startswith("#"): continue @@ -106,7 +106,7 @@ def m002_add_fields_to_apipayments(db): prefix = "#" + ext + " " if row["memo"].startswith(prefix): new = row["memo"][len(prefix) :] - db.execute( + await db.execute( """ UPDATE apipayments SET extra = ?, memo = ? WHERE checking_id = ? AND memo = ? @@ -114,7 +114,7 @@ def m002_add_fields_to_apipayments(db): (json.dumps({"tag": ext}), new, row["checking_id"], row["memo"]), ) break - except sqlite3.OperationalError: + except OperationalError: # this is necessary now because it may be the case that this migration will # run twice in some environments. # catching errors like this won't be necessary in anymore now that we diff --git a/lnbits/core/models.py b/lnbits/core/models.py index 4655c256..89ffc1c1 100644 --- a/lnbits/core/models.py +++ b/lnbits/core/models.py @@ -40,18 +40,14 @@ class Wallet(NamedTuple): hashing_key = hashlib.sha256(self.id.encode("utf-8")).digest() linking_key = hmac.digest(hashing_key, domain.encode("utf-8"), "sha256") - return SigningKey.from_string( - linking_key, - curve=SECP256k1, - hashfunc=hashlib.sha256, - ) + return SigningKey.from_string(linking_key, curve=SECP256k1, hashfunc=hashlib.sha256,) - def get_payment(self, payment_hash: str) -> Optional["Payment"]: + async def get_payment(self, payment_hash: str) -> Optional["Payment"]: from .crud import get_wallet_payment - return get_wallet_payment(self.id, payment_hash) + return await get_wallet_payment(self.id, payment_hash) - def get_payments( + async def get_payments( self, *, complete: bool = True, @@ -62,7 +58,7 @@ class Wallet(NamedTuple): ) -> List["Payment"]: from .crud import get_wallet_payments - return get_wallet_payments( + return await get_wallet_payments( self.id, complete=complete, pending=pending, @@ -125,12 +121,12 @@ class Payment(NamedTuple): def is_uncheckable(self) -> bool: return self.checking_id.startswith("temp_") or self.checking_id.startswith("internal_") - def set_pending(self, pending: bool) -> None: + async def set_pending(self, pending: bool) -> None: from .crud import update_payment_status - update_payment_status(self.checking_id, pending) + await update_payment_status(self.checking_id, pending) - def check_pending(self) -> None: + async def check_pending(self) -> None: if self.is_uncheckable: return @@ -139,9 +135,9 @@ class Payment(NamedTuple): else: pending = WALLET.get_invoice_status(self.checking_id) - self.set_pending(pending.pending) + await self.set_pending(pending.pending) - def delete(self) -> None: + async def delete(self) -> None: from .crud import delete_payment - delete_payment(self.checking_id) + await delete_payment(self.checking_id) diff --git a/lnbits/core/services.py b/lnbits/core/services.py index 991e278b..718fa976 100644 --- a/lnbits/core/services.py +++ b/lnbits/core/services.py @@ -1,4 +1,3 @@ -import trio # type: ignore import json import httpx from io import BytesIO @@ -18,10 +17,11 @@ from lnbits.helpers import urlsafe_short_hash from lnbits.settings import WALLET from lnbits.wallets.base import PaymentStatus, PaymentResponse +from . import db from .crud import get_wallet, create_payment, delete_payment, check_internal, update_payment_status, get_wallet_payment -def create_invoice( +async def create_invoice( *, wallet_id: str, amount: int, # in satoshis @@ -29,6 +29,7 @@ def create_invoice( description_hash: Optional[bytes] = None, extra: Optional[Dict] = None, ) -> Tuple[str, str]: + await db.begin() invoice_memo = None if description_hash else memo storeable_memo = memo @@ -41,7 +42,7 @@ def create_invoice( invoice = bolt11.decode(payment_request) amount_msat = amount * 1000 - create_payment( + await create_payment( wallet_id=wallet_id, checking_id=checking_id, payment_request=payment_request, @@ -51,11 +52,11 @@ def create_invoice( extra=extra, ) - g.db.commit() + await db.commit() return invoice.payment_hash, payment_request -def pay_invoice( +async def pay_invoice( *, wallet_id: str, payment_request: str, @@ -63,6 +64,7 @@ def pay_invoice( extra: Optional[Dict] = None, description: str = "", ) -> str: + await db.begin() temp_id = f"temp_{urlsafe_short_hash()}" internal_id = f"internal_{urlsafe_short_hash()}" @@ -94,58 +96,53 @@ def pay_invoice( ) # check_internal() returns the checking_id of the invoice we're waiting for - internal_checking_id = check_internal(invoice.payment_hash) + internal_checking_id = await check_internal(invoice.payment_hash) if internal_checking_id: # create a new payment from this wallet - create_payment(checking_id=internal_id, fee=0, pending=False, **payment_kwargs) + await create_payment(checking_id=internal_id, fee=0, pending=False, **payment_kwargs) else: # create a temporary payment here so we can check if # the balance is enough in the next step fee_reserve = max(1000, int(invoice.amount_msat * 0.01)) - create_payment(checking_id=temp_id, fee=-fee_reserve, **payment_kwargs) + await create_payment(checking_id=temp_id, fee=-fee_reserve, **payment_kwargs) # do the balance check - wallet = get_wallet(wallet_id) + wallet = await get_wallet(wallet_id) assert wallet if wallet.balance_msat < 0: - g.db.rollback() + await db.rollback() raise PermissionError("Insufficient balance.") else: - g.db.commit() + await db.commit() + await db.begin() if internal_checking_id: # mark the invoice from the other side as not pending anymore # so the other side only has access to his new money when we are sure # the payer has enough to deduct from - update_payment_status(checking_id=internal_checking_id, pending=False) + await update_payment_status(checking_id=internal_checking_id, pending=False) # notify receiver asynchronously from lnbits.tasks import internal_invoice_paid - try: - internal_invoice_paid.send_nowait(internal_checking_id) - except trio.WouldBlock: - pass + await internal_invoice_paid.send(internal_checking_id) else: # actually pay the external invoice payment: PaymentResponse = WALLET.pay_invoice(payment_request) if payment.ok and payment.checking_id: - create_payment( - checking_id=payment.checking_id, - fee=payment.fee_msat, - preimage=payment.preimage, - **payment_kwargs, + await create_payment( + checking_id=payment.checking_id, fee=payment.fee_msat, preimage=payment.preimage, **payment_kwargs, ) - delete_payment(temp_id) + await delete_payment(temp_id) else: raise Exception(payment.error_message or "Failed to pay_invoice on backend.") - g.db.commit() + await db.commit() return invoice.payment_hash async def redeem_lnurl_withdraw(wallet_id: str, res: LnurlWithdrawResponse, memo: Optional[str] = None) -> None: - _, payment_request = create_invoice( + _, payment_request = await create_invoice( wallet_id=wallet_id, amount=res.max_sats, memo=memo or res.default_description or "", @@ -154,8 +151,7 @@ async def redeem_lnurl_withdraw(wallet_id: str, res: LnurlWithdrawResponse, memo async with httpx.AsyncClient() as client: await client.get( - res.callback.base, - params={**res.callback.query_params, **{"k1": res.k1, "pr": payment_request}}, + res.callback.base, params={**res.callback.query_params, **{"k1": res.k1, "pr": payment_request}}, ) @@ -212,11 +208,7 @@ async def perform_lnurlauth(callback: str) -> Optional[LnurlErrorResponse]: async with httpx.AsyncClient() as client: r = await client.get( callback, - params={ - "k1": k1.hex(), - "key": key.verifying_key.to_string("compressed").hex(), - "sig": sig.hex(), - }, + params={"k1": k1.hex(), "key": key.verifying_key.to_string("compressed").hex(), "sig": sig.hex(),}, ) try: resp = json.loads(r.text) @@ -225,13 +217,11 @@ async def perform_lnurlauth(callback: str) -> Optional[LnurlErrorResponse]: return LnurlErrorResponse(reason=resp["reason"]) except (KeyError, json.decoder.JSONDecodeError): - return LnurlErrorResponse( - reason=r.text[:200] + "..." if len(r.text) > 200 else r.text, - ) + return LnurlErrorResponse(reason=r.text[:200] + "..." if len(r.text) > 200 else r.text,) -def check_invoice_status(wallet_id: str, payment_hash: str) -> PaymentStatus: - payment = get_wallet_payment(wallet_id, payment_hash) +async def check_invoice_status(wallet_id: str, payment_hash: str) -> PaymentStatus: + payment = await get_wallet_payment(wallet_id, payment_hash) if not payment: return PaymentStatus(None) diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index fd4a1159..ca56c881 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -2,7 +2,6 @@ import trio # type: ignore import json import lnurl # type: ignore import httpx -import traceback from urllib.parse import urlparse, urlunparse, urlencode, parse_qs, ParseResult from quart import g, jsonify, request, make_response from http import HTTPStatus @@ -12,7 +11,7 @@ from typing import Dict, Union from lnbits import bolt11 from lnbits.decorators import api_check_wallet_key, api_validate_post_request -from .. import core_app +from .. import core_app, db from ..services import create_invoice, pay_invoice, perform_lnurlauth from ..crud import delete_expired_invoices from ..tasks import sse_listeners @@ -22,13 +21,7 @@ from ..tasks import sse_listeners @api_check_wallet_key("invoice") async def api_wallet(): return ( - jsonify( - { - "id": g.wallet.id, - "name": g.wallet.name, - "balance": g.wallet.balance_msat, - } - ), + jsonify({"id": g.wallet.id, "name": g.wallet.name, "balance": g.wallet.balance_msat,}), HTTPStatus.OK, ) @@ -37,12 +30,12 @@ async def api_wallet(): @api_check_wallet_key("invoice") async def api_payments(): if "check_pending" in request.args: - delete_expired_invoices() + await delete_expired_invoices() - for payment in g.wallet.get_payments(complete=False, pending=True, exclude_uncheckable=True): - payment.check_pending() + for payment in await g.wallet.get_payments(complete=False, pending=True, exclude_uncheckable=True): + await payment.check_pending() - return jsonify(g.wallet.get_payments(pending=True)), HTTPStatus.OK + return jsonify(await g.wallet.get_payments(pending=True)), HTTPStatus.OK @api_check_wallet_key("invoice") @@ -63,12 +56,14 @@ async def api_payments_create_invoice(): memo = g.data["memo"] try: - payment_hash, payment_request = create_invoice( + payment_hash, payment_request = await create_invoice( wallet_id=g.wallet.id, amount=g.data["amount"], memo=memo, description_hash=description_hash ) - except Exception as e: - g.db.rollback() - return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR + except Exception as exc: + await db.rollback() + raise exc + + await db.commit() invoice = bolt11.decode(payment_request) @@ -76,11 +71,7 @@ async def api_payments_create_invoice(): if g.data.get("lnurl_callback"): async with httpx.AsyncClient() as client: try: - r = await client.get( - g.data["lnurl_callback"], - params={"pr": payment_request}, - timeout=10, - ) + r = await client.get(g.data["lnurl_callback"], params={"pr": payment_request}, timeout=10,) if r.is_error: lnurl_response = r.text else: @@ -110,15 +101,14 @@ async def api_payments_create_invoice(): @api_validate_post_request(schema={"bolt11": {"type": "string", "empty": False, "required": True}}) async def api_payments_pay_invoice(): try: - payment_hash = pay_invoice(wallet_id=g.wallet.id, payment_request=g.data["bolt11"]) + payment_hash = await pay_invoice(wallet_id=g.wallet.id, payment_request=g.data["bolt11"]) except ValueError as e: return jsonify({"message": str(e)}), HTTPStatus.BAD_REQUEST except PermissionError as e: return jsonify({"message": str(e)}), HTTPStatus.FORBIDDEN except Exception as exc: - traceback.print_exc(7) - g.db.rollback() - return jsonify({"message": str(exc)}), HTTPStatus.INTERNAL_SERVER_ERROR + await db.rollback() + raise exc return ( jsonify( @@ -157,9 +147,7 @@ async def api_payments_pay_lnurl(): async with httpx.AsyncClient() as client: try: r = await client.get( - g.data["callback"], - params={"amount": g.data["amount"], "comment": g.data["comment"]}, - timeout=40, + g.data["callback"], params={"amount": g.data["amount"], "comment": g.data["comment"]}, timeout=40, ) if r.is_error: return jsonify({"message": "failed to connect"}), HTTPStatus.BAD_REQUEST @@ -198,16 +186,12 @@ async def api_payments_pay_lnurl(): if g.data["comment"]: extra["comment"] = g.data["comment"] - payment_hash = pay_invoice( - wallet_id=g.wallet.id, - payment_request=params["pr"], - description=g.data.get("description", ""), - extra=extra, + payment_hash = await pay_invoice( + wallet_id=g.wallet.id, payment_request=params["pr"], description=g.data.get("description", ""), extra=extra, ) except Exception as exc: - traceback.print_exc(7) - g.db.rollback() - return jsonify({"message": str(exc)}), HTTPStatus.INTERNAL_SERVER_ERROR + await db.rollback() + raise exc return ( jsonify( @@ -225,7 +209,7 @@ async def api_payments_pay_lnurl(): @core_app.route("/api/v1/payments/", methods=["GET"]) @api_check_wallet_key("invoice") async def api_payment(payment_hash): - payment = g.wallet.get_payment(payment_hash) + payment = await g.wallet.get_payment(payment_hash) if not payment: return jsonify({"message": "Payment does not exist."}), HTTPStatus.NOT_FOUND @@ -233,7 +217,7 @@ async def api_payment(payment_hash): return jsonify({"paid": True, "preimage": payment.preimage}), HTTPStatus.OK try: - payment.check_pending() + await payment.check_pending() except Exception: return jsonify({"paid": False}), HTTPStatus.OK @@ -243,7 +227,6 @@ async def api_payment(payment_hash): @core_app.route("/api/v1/payments/sse", methods=["GET"]) @api_check_wallet_key("invoice", accept_querystring=True) async def api_payments_sse(): - g.db.close() this_wallet_id = g.wallet.id send_payment, receive_payment = trio.open_memory_channel(0) @@ -364,9 +347,7 @@ async def api_lnurlscan(code: str): @core_app.route("/api/v1/lnurlauth", methods=["POST"]) @api_check_wallet_key("admin") @api_validate_post_request( - schema={ - "callback": {"type": "string", "required": True}, - } + schema={"callback": {"type": "string", "required": True},} ) async def api_perform_lnurlauth(): err = await perform_lnurlauth(g.data["callback"]) diff --git a/lnbits/core/views/generic.py b/lnbits/core/views/generic.py index 4ff17981..3156169d 100644 --- a/lnbits/core/views/generic.py +++ b/lnbits/core/views/generic.py @@ -8,8 +8,8 @@ from lnurl import LnurlResponse, LnurlWithdrawResponse, decode as decode_lnurl from lnbits.core import core_app from lnbits.decorators import check_user_exists, validate_uuids from lnbits.settings import LNBITS_ALLOWED_USERS, SERVICE_FEE -from lnbits.tasks import run_on_pseudo_request +from .. import db from ..crud import ( create_account, get_user, @@ -41,11 +41,11 @@ async def extensions(): abort(HTTPStatus.BAD_REQUEST, "You can either `enable` or `disable` an extension.") if extension_to_enable: - update_user_extension(user_id=g.user.id, extension=extension_to_enable, active=1) + await update_user_extension(user_id=g.user.id, extension=extension_to_enable, active=1) elif extension_to_disable: - update_user_extension(user_id=g.user.id, extension=extension_to_disable, active=0) + await update_user_extension(user_id=g.user.id, extension=extension_to_disable, active=0) - return await render_template("core/extensions.html", user=get_user(g.user.id)) + return await render_template("core/extensions.html", user=await get_user(g.user.id)) @core_app.route("/wallet") @@ -63,9 +63,12 @@ async def wallet(): # nothing: create everything if not user_id: - user = get_user(create_account().id) + user = await get_user((await create_account()).id) else: - user = get_user(user_id) or abort(HTTPStatus.NOT_FOUND, "User does not exist.") + user = await get_user(user_id) + if not user: + abort(HTTPStatus.NOT_FOUND, "User does not exist.") + return if LNBITS_ALLOWED_USERS and user_id not in LNBITS_ALLOWED_USERS: abort(HTTPStatus.UNAUTHORIZED, "User not authorized.") @@ -74,7 +77,7 @@ async def wallet(): if user.wallets and not wallet_name: wallet = user.wallets[0] else: - wallet = create_wallet(user_id=user.id, wallet_name=wallet_name) + wallet = await create_wallet(user_id=user.id, wallet_name=wallet_name) return redirect(url_for("core.wallet", usr=user.id, wal=wallet.id)) @@ -95,7 +98,7 @@ async def deletewallet(): if wallet_id not in user_wallet_ids: abort(HTTPStatus.FORBIDDEN, "Not your wallet.") else: - delete_wallet(user_id=g.user.id, wallet_id=wallet_id) + await delete_wallet(user_id=g.user.id, wallet_id=wallet_id) user_wallet_ids.remove(wallet_id) if user_wallet_ids: @@ -120,14 +123,12 @@ async def lnurlwallet(): except Exception as exc: return f"Could not process lnurl-withdraw: {exc}", HTTPStatus.INTERNAL_SERVER_ERROR - account = create_account() - user = get_user(account.id) - wallet = create_wallet(user_id=user.id) - g.db.commit() + account = await create_account() + user = await get_user(account.id) + wallet = await create_wallet(user_id=user.id) + await db.commit() - await run_on_pseudo_request( - redeem_lnurl_withdraw, wallet.id, withdraw_res, "LNbits initial funding: voucher redeem." - ) + g.nursery.start_soon(redeem_lnurl_withdraw, wallet.id, withdraw_res, "LNbits initial funding: voucher redeem.") await trio.sleep(3) return redirect(url_for("core.wallet", usr=user.id, wal=wallet.id)) diff --git a/lnbits/db.py b/lnbits/db.py index b7c9e023..61891cad 100644 --- a/lnbits/db.py +++ b/lnbits/db.py @@ -1,66 +1,85 @@ import os -import sqlite3 +from typing import Tuple, Optional, Any +from sqlalchemy_aio import TRIO_STRATEGY # type: ignore +from sqlalchemy import create_engine # type: ignore +from quart import g from .settings import LNBITS_DATA_FOLDER class Database: - def __init__(self, db_path: str): - self.path = db_path - self.connection = sqlite3.connect(db_path) - self.connection.row_factory = sqlite3.Row - self.cursor = self.connection.cursor() - self.closed = False + def __init__(self, db_name: str): + self.db_name = db_name + db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3") + self.engine = create_engine(f"sqlite:///{db_path}", strategy=TRIO_STRATEGY) - def close(self): - self.__exit__(None, None, None) + def connect(self): + return self.engine.connect() - def __enter__(self): - return self + def session_connection(self) -> Tuple[Optional[Any], Optional[Any]]: + try: + return getattr(g, f"{self.db_name}_conn", None), getattr(g, f"{self.db_name}_txn", None) + except RuntimeError: + return None, None - def __exit__(self, exc_type, exc_val, exc_tb): - if self.closed: + async def begin(self): + conn, _ = self.session_connection() + if conn: return - if exc_val: - self.connection.rollback() - self.cursor.close() - self.connection.close() - else: - self.connection.commit() - self.cursor.close() - self.connection.close() + conn = await self.engine.connect() + setattr(g, f"{self.db_name}_conn", conn) + txn = await conn.begin() + setattr(g, f"{self.db_name}_txn", txn) - self.closed = True + async def fetchall(self, query: str, values: tuple = ()) -> list: + conn, _ = self.session_connection() + if conn: + result = await conn.execute(query, values) + return await result.fetchall() - def commit(self): - self.connection.commit() + async with self.connect() as conn: + result = await conn.execute(query, values) + return await result.fetchall() - def rollback(self): - self.connection.rollback() + async def fetchone(self, query: str, values: tuple = ()): + conn, _ = self.session_connection() + if conn: + result = await conn.execute(query, values) + row = await result.fetchone() + await result.close() + return row - def fetchall(self, query: str, values: tuple = ()) -> list: - """Given a query, return cursor.fetchall() rows.""" - self.execute(query, values) - return self.cursor.fetchall() + async with self.connect() as conn: + result = await conn.execute(query, values) + row = await result.fetchone() + await result.close() + return row - def fetchone(self, query: str, values: tuple = ()): - self.execute(query, values) - return self.cursor.fetchone() + async def execute(self, query: str, values: tuple = ()): + conn, _ = self.session_connection() + if conn: + return await conn.execute(query, values) - def execute(self, query: str, values: tuple = ()) -> None: - """Given a query, cursor.execute() it.""" - try: - self.cursor.execute(query, values) - except sqlite3.Error as exc: - self.connection.rollback() - raise exc + async with self.connect() as conn: + return await conn.execute(query, values) + async def commit(self): + conn, txn = self.session_connection() + if conn and txn: + await txn.commit() + await self.close_session() -def open_db(db_name: str = "database") -> Database: - db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3") - return Database(db_path=db_path) + async def rollback(self): + conn, txn = self.session_connection() + if conn and txn: + await txn.rollback() + await self.close_session() - -def open_ext_db(extension_name: str) -> Database: - return open_db(f"ext_{extension_name}") + async def close_session(self): + conn, txn = self.session_connection() + if conn and txn: + await txn.close() + await conn.close() + delattr(g, f"{self.db_name}_conn") + delattr(g, f"{self.db_name}_txn") diff --git a/lnbits/decorators.py b/lnbits/decorators.py index 34d132e9..1e659e09 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -15,7 +15,7 @@ def api_check_wallet_key(key_type: str = "invoice", accept_querystring=False): async def wrapped_view(**kwargs): try: key_value = request.headers.get("X-Api-Key") or request.args["api-key"] - g.wallet = get_wallet_for_key(key_value, key_type) + g.wallet = await get_wallet_for_key(key_value, key_type) except KeyError: return ( jsonify({"message": "`X-Api-Key` header missing."}), @@ -63,7 +63,9 @@ def check_user_exists(param: str = "usr"): def wrap(view): @wraps(view) async def wrapped_view(**kwargs): - g.user = get_user(request.args.get(param, type=str)) or abort(HTTPStatus.NOT_FOUND, "User does not exist.") + g.user = await get_user(request.args.get(param, type=str)) or abort( + HTTPStatus.NOT_FOUND, "User does not exist." + ) if LNBITS_ALLOWED_USERS and g.user.id not in LNBITS_ALLOWED_USERS: abort(HTTPStatus.UNAUTHORIZED, "User not authorized.") diff --git a/lnbits/extensions/amilk/__init__.py b/lnbits/extensions/amilk/__init__.py index 182f0235..9aa7047c 100644 --- a/lnbits/extensions/amilk/__init__.py +++ b/lnbits/extensions/amilk/__init__.py @@ -1,5 +1,7 @@ from quart import Blueprint +from lnbits.db import Database +db = Database("ext_amilk") amilk_ext: Blueprint = Blueprint("amilk", __name__, static_folder="static", template_folder="templates") diff --git a/lnbits/extensions/amilk/crud.py b/lnbits/extensions/amilk/crud.py index 43577765..5170ca1f 100644 --- a/lnbits/extensions/amilk/crud.py +++ b/lnbits/extensions/amilk/crud.py @@ -2,43 +2,39 @@ from base64 import urlsafe_b64encode from uuid import uuid4 from typing import List, Optional, Union -from lnbits.db import open_ext_db - +from . import db from .models import AMilk -def create_amilk(*, wallet_id: str, lnurl: str, atime: int, amount: int) -> AMilk: - with open_ext_db("amilk") as db: - amilk_id = urlsafe_b64encode(uuid4().bytes_le).decode("utf-8") - db.execute( - """ - INSERT INTO amilks (id, wallet, lnurl, atime, amount) - VALUES (?, ?, ?, ?, ?) - """, - (amilk_id, wallet_id, lnurl, atime, amount), - ) +async def create_amilk(*, wallet_id: str, lnurl: str, atime: int, amount: int) -> AMilk: + amilk_id = urlsafe_b64encode(uuid4().bytes_le).decode("utf-8") + await db.execute( + """ + INSERT INTO amilks (id, wallet, lnurl, atime, amount) + VALUES (?, ?, ?, ?, ?) + """, + (amilk_id, wallet_id, lnurl, atime, amount), + ) - return get_amilk(amilk_id) + amilk = await get_amilk(amilk_id) + assert amilk, "Newly created amilk_id couldn't be retrieved" + return amilk -def get_amilk(amilk_id: str) -> Optional[AMilk]: - with open_ext_db("amilk") as db: - row = db.fetchone("SELECT * FROM amilks WHERE id = ?", (amilk_id,)) - +async def get_amilk(amilk_id: str) -> Optional[AMilk]: + row = await db.fetchone("SELECT * FROM amilks WHERE id = ?", (amilk_id,)) return AMilk(**row) if row else None -def get_amilks(wallet_ids: Union[str, List[str]]) -> List[AMilk]: +async def get_amilks(wallet_ids: Union[str, List[str]]) -> List[AMilk]: if isinstance(wallet_ids, str): wallet_ids = [wallet_ids] - with open_ext_db("amilk") as db: - q = ",".join(["?"] * len(wallet_ids)) - rows = db.fetchall(f"SELECT * FROM amilks WHERE wallet IN ({q})", (*wallet_ids,)) + q = ",".join(["?"] * len(wallet_ids)) + rows = await db.fetchall(f"SELECT * FROM amilks WHERE wallet IN ({q})", (*wallet_ids,)) return [AMilk(**row) for row in rows] -def delete_amilk(amilk_id: str) -> None: - with open_ext_db("amilk") as db: - db.execute("DELETE FROM amilks WHERE id = ?", (amilk_id,)) +async def delete_amilk(amilk_id: str) -> None: + await db.execute("DELETE FROM amilks WHERE id = ?", (amilk_id,)) diff --git a/lnbits/extensions/amilk/migrations.py b/lnbits/extensions/amilk/migrations.py index 3ab2d4ab..f096ccdb 100644 --- a/lnbits/extensions/amilk/migrations.py +++ b/lnbits/extensions/amilk/migrations.py @@ -1,8 +1,8 @@ -def m001_initial(db): +async def m001_initial(db): """ Initial amilks table. """ - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS amilks ( id TEXT PRIMARY KEY, diff --git a/lnbits/extensions/amilk/views.py b/lnbits/extensions/amilk/views.py index fa214e32..2f61df77 100644 --- a/lnbits/extensions/amilk/views.py +++ b/lnbits/extensions/amilk/views.py @@ -2,8 +2,8 @@ from quart import g, abort, render_template from http import HTTPStatus from lnbits.decorators import check_user_exists, validate_uuids -from lnbits.extensions.amilk import amilk_ext +from . import amilk_ext from .crud import get_amilk @@ -16,5 +16,8 @@ async def index(): @amilk_ext.route("/") async def wall(amilk_id): - amilk = get_amilk(amilk_id) or abort(HTTPStatus.NOT_FOUND, "AMilk does not exist.") + amilk = await get_amilk(amilk_id) + if not amilk: + abort(HTTPStatus.NOT_FOUND, "AMilk does not exist.") + return await render_template("amilk/wall.html", amilk=amilk) diff --git a/lnbits/extensions/amilk/views_api.py b/lnbits/extensions/amilk/views_api.py index 2ad85c3a..8ffaa4db 100644 --- a/lnbits/extensions/amilk/views_api.py +++ b/lnbits/extensions/amilk/views_api.py @@ -1,15 +1,15 @@ import httpx from quart import g, jsonify, request, abort from http import HTTPStatus -from lnurl import LnurlWithdrawResponse, handle as handle_lnurl -from lnurl.exceptions import LnurlException +from lnurl import LnurlWithdrawResponse, handle as handle_lnurl # type: ignore +from lnurl.exceptions import LnurlException # type: ignore from time import sleep from lnbits.core.crud import get_user from lnbits.decorators import api_check_wallet_key, api_validate_post_request from lnbits.core.services import create_invoice, check_invoice_status -from lnbits.extensions.amilk import amilk_ext +from . import amilk_ext from .crud import create_amilk, get_amilk, get_amilks, delete_amilk @@ -19,14 +19,14 @@ async def api_amilks(): wallet_ids = [g.wallet.id] if "all_wallets" in request.args: - wallet_ids = get_user(g.wallet.user).wallet_ids + wallet_ids = (await get_user(g.wallet.user)).wallet_ids - return jsonify([amilk._asdict() for amilk in get_amilks(wallet_ids)]), HTTPStatus.OK + return jsonify([amilk._asdict() for amilk in await get_amilks(wallet_ids)]), HTTPStatus.OK @amilk_ext.route("/api/v1/amilk/milk/", methods=["GET"]) async def api_amilkit(amilk_id): - milk = get_amilk(amilk_id) + milk = await get_amilk(amilk_id) memo = milk.id try: @@ -34,7 +34,7 @@ async def api_amilkit(amilk_id): except LnurlException: abort(HTTPStatus.INTERNAL_SERVER_ERROR, "Could not process withdraw LNURL.") - payment_hash, payment_request = create_invoice( + payment_hash, payment_request = await create_invoice( wallet_id=milk.wallet, amount=withdraw_res.max_sats, memo=memo, extra={"tag": "amilk"} ) @@ -48,7 +48,7 @@ async def api_amilkit(amilk_id): for i in range(10): sleep(i) - invoice_status = check_invoice_status(milk.wallet, payment_hash) + invoice_status = await check_invoice_status(milk.wallet, payment_hash) if invoice_status.paid: return jsonify({"paid": True}), HTTPStatus.OK else: @@ -67,7 +67,9 @@ async def api_amilkit(amilk_id): } ) async def api_amilk_create(): - amilk = create_amilk(wallet_id=g.wallet.id, lnurl=g.data["lnurl"], atime=g.data["atime"], amount=g.data["amount"]) + amilk = await create_amilk( + wallet_id=g.wallet.id, lnurl=g.data["lnurl"], atime=g.data["atime"], amount=g.data["amount"] + ) return jsonify(amilk._asdict()), HTTPStatus.CREATED @@ -75,7 +77,7 @@ async def api_amilk_create(): @amilk_ext.route("/api/v1/amilk/", methods=["DELETE"]) @api_check_wallet_key("invoice") async def api_amilk_delete(amilk_id): - amilk = get_amilk(amilk_id) + amilk = await get_amilk(amilk_id) if not amilk: return jsonify({"message": "Paywall does not exist."}), HTTPStatus.NOT_FOUND @@ -83,6 +85,6 @@ async def api_amilk_delete(amilk_id): if amilk.wallet != g.wallet.id: return jsonify({"message": "Not your amilk."}), HTTPStatus.FORBIDDEN - delete_amilk(amilk_id) + await delete_amilk(amilk_id) return "", HTTPStatus.NO_CONTENT diff --git a/lnbits/extensions/diagonalley/migrations.py b/lnbits/extensions/diagonalley/migrations.py index afec1a6a..a70368fc 100644 --- a/lnbits/extensions/diagonalley/migrations.py +++ b/lnbits/extensions/diagonalley/migrations.py @@ -1,8 +1,8 @@ -def m001_initial(db): +async def m001_initial(db): """ Initial products table. """ - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS products ( id TEXT PRIMARY KEY, @@ -20,7 +20,7 @@ def m001_initial(db): """ Initial indexers table. """ - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS indexers ( id TEXT PRIMARY KEY, @@ -41,7 +41,7 @@ def m001_initial(db): """ Initial orders table. """ - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS orders ( id TEXT PRIMARY KEY, diff --git a/lnbits/extensions/events/__init__.py b/lnbits/extensions/events/__init__.py index abd48951..4496b2c1 100644 --- a/lnbits/extensions/events/__init__.py +++ b/lnbits/extensions/events/__init__.py @@ -1,4 +1,7 @@ from quart import Blueprint +from lnbits.db import Database + +db = Database("ext_events") events_ext: Blueprint = Blueprint("events", __name__, static_folder="static", template_folder="templates") diff --git a/lnbits/extensions/events/crud.py b/lnbits/extensions/events/crud.py index b7219f28..efc40731 100644 --- a/lnbits/extensions/events/crud.py +++ b/lnbits/extensions/events/crud.py @@ -1,45 +1,46 @@ from typing import List, Optional, Union -from lnbits.db import open_ext_db from lnbits.helpers import urlsafe_short_hash +from . import db from .models import Tickets, Events -#######TICKETS######## +# TICKETS -def create_ticket(payment_hash: str, wallet: str, event: str, name: str, email: str) -> Tickets: - with open_ext_db("events") as db: - db.execute( - """ - INSERT INTO ticket (id, wallet, event, name, email, registered, paid) - VALUES (?, ?, ?, ?, ?, ?, ?) - """, - (payment_hash, wallet, event, name, email, False, False), - ) +async def create_ticket(payment_hash: str, wallet: str, event: str, name: str, email: str) -> Tickets: + await db.execute( + """ + INSERT INTO ticket (id, wallet, event, name, email, registered, paid) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + (payment_hash, wallet, event, name, email, False, False), + ) - return get_ticket(payment_hash) + ticket = await get_ticket(payment_hash) + assert ticket, "Newly created ticket couldn't be retrieved" + return ticket -def update_ticket(paid: bool, payment_hash: str) -> Tickets: - with open_ext_db("events") as db: - row = db.fetchone("SELECT * FROM ticket WHERE id = ?", (payment_hash,)) - if row[6] == True: - return get_ticket(payment_hash) - db.execute( +async def set_ticket_paid(payment_hash: str) -> Tickets: + row = await db.fetchone("SELECT * FROM ticket WHERE id = ?", (payment_hash,)) + if row[6] != True: + await db.execute( """ UPDATE ticket - SET paid = ? + SET paid = true WHERE id = ? """, - (paid, payment_hash), + (payment_hash,), ) - eventdata = get_event(row[2]) + eventdata = await get_event(row[2]) + assert eventdata, "Couldn't get event from ticket being paid" + sold = eventdata.sold + 1 amount_tickets = eventdata.amount_tickets - 1 - db.execute( + await db.execute( """ UPDATE events SET sold = ?, amount_tickets = ? @@ -47,36 +48,34 @@ def update_ticket(paid: bool, payment_hash: str) -> Tickets: """, (sold, amount_tickets, row[2]), ) - return get_ticket(payment_hash) + + ticket = await get_ticket(payment_hash) + assert ticket, "Newly updated ticket couldn't be retrieved" + return ticket -def get_ticket(payment_hash: str) -> Optional[Tickets]: - with open_ext_db("events") as db: - row = db.fetchone("SELECT * FROM ticket WHERE id = ?", (payment_hash,)) - +async def get_ticket(payment_hash: str) -> Optional[Tickets]: + row = await db.fetchone("SELECT * FROM ticket WHERE id = ?", (payment_hash,)) return Tickets(**row) if row else None -def get_tickets(wallet_ids: Union[str, List[str]]) -> List[Tickets]: +async def get_tickets(wallet_ids: Union[str, List[str]]) -> List[Tickets]: if isinstance(wallet_ids, str): wallet_ids = [wallet_ids] - with open_ext_db("events") as db: - q = ",".join(["?"] * len(wallet_ids)) - rows = db.fetchall(f"SELECT * FROM ticket WHERE wallet IN ({q})", (*wallet_ids,)) - print("scrum") + q = ",".join(["?"] * len(wallet_ids)) + rows = await db.fetchall(f"SELECT * FROM ticket WHERE wallet IN ({q})", (*wallet_ids,)) return [Tickets(**row) for row in rows] -def delete_ticket(payment_hash: str) -> None: - with open_ext_db("events") as db: - db.execute("DELETE FROM ticket WHERE id = ?", (payment_hash,)) +async def delete_ticket(payment_hash: str) -> None: + await db.execute("DELETE FROM ticket WHERE id = ?", (payment_hash,)) -########EVENTS######### +# EVENTS -def create_event( +async def create_event( *, wallet: str, name: str, @@ -87,81 +86,68 @@ def create_event( amount_tickets: int, price_per_ticket: int, ) -> Events: - with open_ext_db("events") as db: - event_id = urlsafe_short_hash() - db.execute( - """ - INSERT INTO events (id, wallet, name, info, closing_date, event_start_date, event_end_date, amount_tickets, price_per_ticket, sold) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - event_id, - wallet, - name, - info, - closing_date, - event_start_date, - event_end_date, - amount_tickets, - price_per_ticket, - 0, - ), - ) - print(event_id) + event_id = urlsafe_short_hash() + await db.execute( + """ + INSERT INTO events (id, wallet, name, info, closing_date, event_start_date, event_end_date, amount_tickets, price_per_ticket, sold) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + event_id, + wallet, + name, + info, + closing_date, + event_start_date, + event_end_date, + amount_tickets, + price_per_ticket, + 0, + ), + ) - return get_event(event_id) + event = await get_event(event_id) + assert event, "Newly created event couldn't be retrieved" + return event -def update_event(event_id: str, **kwargs) -> Events: +async def update_event(event_id: str, **kwargs) -> Events: q = ", ".join([f"{field[0]} = ?" for field in kwargs.items()]) - with open_ext_db("events") as db: - db.execute(f"UPDATE events SET {q} WHERE id = ?", (*kwargs.values(), event_id)) + await db.execute(f"UPDATE events SET {q} WHERE id = ?", (*kwargs.values(), event_id)) + event = await get_event(event_id) + assert event, "Newly updated event couldn't be retrieved" + return event - row = db.fetchone("SELECT * FROM events WHERE id = ?", (event_id,)) +async def get_event(event_id: str) -> Optional[Events]: + row = await db.fetchone("SELECT * FROM events WHERE id = ?", (event_id,)) return Events(**row) if row else None -def get_event(event_id: str) -> Optional[Events]: - with open_ext_db("events") as db: - row = db.fetchone("SELECT * FROM events WHERE id = ?", (event_id,)) - - return Events(**row) if row else None - - -def get_events(wallet_ids: Union[str, List[str]]) -> List[Events]: +async def get_events(wallet_ids: Union[str, List[str]]) -> List[Events]: if isinstance(wallet_ids, str): wallet_ids = [wallet_ids] - with open_ext_db("events") as db: - q = ",".join(["?"] * len(wallet_ids)) - rows = db.fetchall(f"SELECT * FROM events WHERE wallet IN ({q})", (*wallet_ids,)) + q = ",".join(["?"] * len(wallet_ids)) + rows = await db.fetchall(f"SELECT * FROM events WHERE wallet IN ({q})", (*wallet_ids,)) return [Events(**row) for row in rows] -def delete_event(event_id: str) -> None: - with open_ext_db("events") as db: - db.execute("DELETE FROM events WHERE id = ?", (event_id,)) +async def delete_event(event_id: str) -> None: + await db.execute("DELETE FROM events WHERE id = ?", (event_id,)) -########EVENTTICKETS######### +# EVENTTICKETS -def get_event_tickets(event_id: str, wallet_id: str) -> Tickets: - - with open_ext_db("events") as db: - rows = db.fetchall("SELECT * FROM ticket WHERE wallet = ? AND event = ?", (wallet_id, event_id)) - print(rows) - +async def get_event_tickets(event_id: str, wallet_id: str) -> List[Tickets]: + rows = await db.fetchall("SELECT * FROM ticket WHERE wallet = ? AND event = ?", (wallet_id, event_id)) return [Tickets(**row) for row in rows] -def reg_ticket(ticket_id: str) -> Tickets: - with open_ext_db("events") as db: - db.execute("UPDATE ticket SET registered = ? WHERE id = ?", (True, ticket_id)) - ticket = db.fetchone("SELECT * FROM ticket WHERE id = ?", (ticket_id,)) - print(ticket[1]) - rows = db.fetchall("SELECT * FROM ticket WHERE event = ?", (ticket[1],)) - +async def reg_ticket(ticket_id: str) -> List[Tickets]: + await db.execute("UPDATE ticket SET registered = ? WHERE id = ?", (True, ticket_id)) + ticket = await db.fetchone("SELECT * FROM ticket WHERE id = ?", (ticket_id,)) + rows = await db.fetchall("SELECT * FROM ticket WHERE event = ?", (ticket[1],)) return [Tickets(**row) for row in rows] diff --git a/lnbits/extensions/events/migrations.py b/lnbits/extensions/events/migrations.py index 95e361b0..52a7658c 100644 --- a/lnbits/extensions/events/migrations.py +++ b/lnbits/extensions/events/migrations.py @@ -1,6 +1,6 @@ -def m001_initial(db): +async def m001_initial(db): - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS events ( id TEXT PRIMARY KEY, @@ -18,7 +18,7 @@ def m001_initial(db): """ ) - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS tickets ( id TEXT PRIMARY KEY, @@ -33,9 +33,9 @@ def m001_initial(db): ) -def m002_changed(db): +async def m002_changed(db): - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS ticket ( id TEXT PRIMARY KEY, @@ -50,7 +50,7 @@ def m002_changed(db): """ ) - for row in [list(row) for row in db.fetchall("SELECT * FROM tickets")]: + for row in [list(row) for row in await db.fetchall("SELECT * FROM tickets")]: usescsv = "" for i in range(row[5]): @@ -59,7 +59,7 @@ def m002_changed(db): else: usescsv += "," + str(1) usescsv = usescsv[1:] - db.execute( + await db.execute( """ INSERT INTO ticket ( id, @@ -72,14 +72,6 @@ def m002_changed(db): ) VALUES (?, ?, ?, ?, ?, ?, ?) """, - ( - row[0], - row[1], - row[2], - row[3], - row[4], - row[5], - True, - ), + (row[0], row[1], row[2], row[3], row[4], row[5], True,), ) - db.execute("DROP TABLE tickets") + await db.execute("DROP TABLE tickets") diff --git a/lnbits/extensions/events/views.py b/lnbits/extensions/events/views.py index 56c01343..86426acb 100644 --- a/lnbits/extensions/events/views.py +++ b/lnbits/extensions/events/views.py @@ -1,10 +1,10 @@ from quart import g, abort, render_template from datetime import date, datetime - -from lnbits.decorators import check_user_exists, validate_uuids from http import HTTPStatus -from lnbits.extensions.events import events_ext +from lnbits.decorators import check_user_exists, validate_uuids + +from . import events_ext from .crud import get_ticket, get_event @@ -17,7 +17,10 @@ async def index(): @events_ext.route("/") async def display(event_id): - event = get_event(event_id) or abort(HTTPStatus.NOT_FOUND, "Event does not exist.") + event = await get_event(event_id) + if not event: + abort(HTTPStatus.NOT_FOUND, "Event does not exist.") + if event.amount_tickets < 1: return await render_template( "events/error.html", event_name=event.name, event_error="Sorry, tickets are sold out :(" @@ -39,8 +42,14 @@ async def display(event_id): @events_ext.route("/ticket/") async def ticket(ticket_id): - ticket = get_ticket(ticket_id) or abort(HTTPStatus.NOT_FOUND, "Ticket does not exist.") - event = get_event(ticket.event) or abort(HTTPStatus.NOT_FOUND, "Event does not exist.") + ticket = await get_ticket(ticket_id) + if not ticket: + abort(HTTPStatus.NOT_FOUND, "Ticket does not exist.") + + event = await get_event(ticket.event) + if not event: + abort(HTTPStatus.NOT_FOUND, "Event does not exist.") + return await render_template( "events/ticket.html", ticket_id=ticket_id, ticket_name=event.name, ticket_info=event.info ) @@ -48,7 +57,9 @@ async def ticket(ticket_id): @events_ext.route("/register/") async def register(event_id): - event = get_event(event_id) or abort(HTTPStatus.NOT_FOUND, "Event does not exist.") + event = await get_event(event_id) + if not event: + abort(HTTPStatus.NOT_FOUND, "Event does not exist.") return await render_template( "events/register.html", event_id=event_id, event_name=event.name, wallet_id=event.wallet diff --git a/lnbits/extensions/events/views_api.py b/lnbits/extensions/events/views_api.py index 49515b67..3467cc2d 100644 --- a/lnbits/extensions/events/views_api.py +++ b/lnbits/extensions/events/views_api.py @@ -5,10 +5,10 @@ from lnbits.core.crud import get_user, get_wallet from lnbits.core.services import create_invoice, check_invoice_status from lnbits.decorators import api_check_wallet_key, api_validate_post_request -from lnbits.extensions.events import events_ext +from . import events_ext from .crud import ( create_ticket, - update_ticket, + set_ticket_paid, get_ticket, get_tickets, delete_ticket, @@ -22,7 +22,7 @@ from .crud import ( ) -#########Events########## +# Events @events_ext.route("/api/v1/events", methods=["GET"]) @@ -31,9 +31,9 @@ async def api_events(): wallet_ids = [g.wallet.id] if "all_wallets" in request.args: - wallet_ids = get_user(g.wallet.user).wallet_ids + wallet_ids = (await get_user(g.wallet.user)).wallet_ids - return jsonify([event._asdict() for event in get_events(wallet_ids)]), HTTPStatus.OK + return jsonify([event._asdict() for event in await get_events(wallet_ids)]), HTTPStatus.OK @events_ext.route("/api/v1/events", methods=["POST"]) @@ -53,35 +53,31 @@ async def api_events(): ) async def api_event_create(event_id=None): if event_id: - event = get_event(event_id) - print(g.data) - + event = await get_event(event_id) if not event: return jsonify({"message": "Form does not exist."}), HTTPStatus.NOT_FOUND if event.wallet != g.wallet.id: return jsonify({"message": "Not your event."}), HTTPStatus.FORBIDDEN - event = update_event(event_id, **g.data) + event = await update_event(event_id, **g.data) else: - event = create_event(**g.data) - print(event) + event = await create_event(**g.data) + return jsonify(event._asdict()), HTTPStatus.CREATED @events_ext.route("/api/v1/events/", methods=["DELETE"]) @api_check_wallet_key("invoice") async def api_form_delete(event_id): - event = get_event(event_id) - + event = await get_event(event_id) if not event: return jsonify({"message": "Event does not exist."}), HTTPStatus.NOT_FOUND if event.wallet != g.wallet.id: return jsonify({"message": "Not your event."}), HTTPStatus.FORBIDDEN - delete_event(event_id) - + await delete_event(event_id) return "", HTTPStatus.NO_CONTENT @@ -94,9 +90,9 @@ async def api_tickets(): wallet_ids = [g.wallet.id] if "all_wallets" in request.args: - wallet_ids = get_user(g.wallet.user).wallet_ids + wallet_ids = (await get_user(g.wallet.user)).wallet_ids - return jsonify([ticket._asdict() for ticket in get_tickets(wallet_ids)]), HTTPStatus.OK + return jsonify([ticket._asdict() for ticket in await get_tickets(wallet_ids)]), HTTPStatus.OK @events_ext.route("/api/v1/tickets//", methods=["POST"]) @@ -107,17 +103,17 @@ async def api_tickets(): } ) async def api_ticket_make_ticket(event_id, sats): - event = get_event(event_id) + event = await get_event(event_id) if not event: return jsonify({"message": "Event does not exist."}), HTTPStatus.NOT_FOUND try: - payment_hash, payment_request = create_invoice( + payment_hash, payment_request = await create_invoice( wallet_id=event.wallet, amount=int(sats), memo=f"{event_id}", extra={"tag": "events"} ) except Exception as e: return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR - ticket = create_ticket(payment_hash=payment_hash, wallet=event.wallet, event=event_id, **g.data) + ticket = await create_ticket(payment_hash=payment_hash, wallet=event.wallet, event=event_id, **g.data) if not ticket: return jsonify({"message": "Event could not be fetched."}), HTTPStatus.NOT_FOUND @@ -127,17 +123,19 @@ async def api_ticket_make_ticket(event_id, sats): @events_ext.route("/api/v1/tickets/", methods=["GET"]) async def api_ticket_send_ticket(payment_hash): - ticket = get_ticket(payment_hash) + ticket = await get_ticket(payment_hash) + try: - is_paid = not check_invoice_status(ticket.wallet, payment_hash).pending + status = await check_invoice_status(ticket.wallet, payment_hash) + is_paid = not status.pending except Exception: return jsonify({"message": "Not paid."}), HTTPStatus.NOT_FOUND if is_paid: - wallet = get_wallet(ticket.wallet) - payment = wallet.get_payment(payment_hash) - payment.set_pending(False) - ticket = update_ticket(paid=True, payment_hash=payment_hash) + wallet = await get_wallet(ticket.wallet) + payment = await wallet.get_payment(payment_hash) + await payment.set_pending(False) + ticket = await set_ticket_paid(payment_hash=payment_hash) return jsonify({"paid": True, "ticket_id": ticket.id}), HTTPStatus.OK @@ -147,7 +145,7 @@ async def api_ticket_send_ticket(payment_hash): @events_ext.route("/api/v1/tickets/", methods=["DELETE"]) @api_check_wallet_key("invoice") async def api_ticket_delete(ticket_id): - ticket = get_ticket(ticket_id) + ticket = await get_ticket(ticket_id) if not ticket: return jsonify({"message": "Ticket does not exist."}), HTTPStatus.NOT_FOUND @@ -155,32 +153,28 @@ async def api_ticket_delete(ticket_id): if ticket.wallet != g.wallet.id: return jsonify({"message": "Not your ticket."}), HTTPStatus.FORBIDDEN - delete_ticket(ticket_id) - + await delete_ticket(ticket_id) return "", HTTPStatus.NO_CONTENT -#########EventTickets########## +# Event Tickets @events_ext.route("/api/v1/eventtickets//", methods=["GET"]) async def api_event_tickets(wallet_id, event_id): - return ( - jsonify([ticket._asdict() for ticket in get_event_tickets(wallet_id=wallet_id, event_id=event_id)]), + jsonify([ticket._asdict() for ticket in await get_event_tickets(wallet_id=wallet_id, event_id=event_id)]), HTTPStatus.OK, ) @events_ext.route("/api/v1/register/ticket/", methods=["GET"]) async def api_event_register_ticket(ticket_id): - - ticket = get_ticket(ticket_id) - + ticket = await get_ticket(ticket_id) if not ticket: return jsonify({"message": "Ticket does not exist."}), HTTPStatus.FORBIDDEN if ticket.registered == True: return jsonify({"message": "Ticket already registered"}), HTTPStatus.FORBIDDEN - return jsonify([ticket._asdict() for ticket in reg_ticket(ticket_id)]), HTTPStatus.OK + return jsonify([ticket._asdict() for ticket in await reg_ticket(ticket_id)]), HTTPStatus.OK diff --git a/lnbits/extensions/example/__init__.py b/lnbits/extensions/example/__init__.py index 1d6055ae..43a8223c 100644 --- a/lnbits/extensions/example/__init__.py +++ b/lnbits/extensions/example/__init__.py @@ -1,5 +1,7 @@ from quart import Blueprint +from lnbits.db import Database +db = Database("ext_example") example_ext: Blueprint = Blueprint("example", __name__, static_folder="static", template_folder="templates") diff --git a/lnbits/extensions/example/views.py b/lnbits/extensions/example/views.py index 0fe68a85..99e58f62 100644 --- a/lnbits/extensions/example/views.py +++ b/lnbits/extensions/example/views.py @@ -1,7 +1,8 @@ from quart import g, render_template from lnbits.decorators import check_user_exists, validate_uuids -from lnbits.extensions.example import example_ext + +from . import example_ext @example_ext.route("/") diff --git a/lnbits/extensions/example/views_api.py b/lnbits/extensions/example/views_api.py index 095ddce2..29814a78 100644 --- a/lnbits/extensions/example/views_api.py +++ b/lnbits/extensions/example/views_api.py @@ -10,7 +10,7 @@ from quart import jsonify from http import HTTPStatus -from lnbits.extensions.example import example_ext +from . import example_ext # add your endpoints here @@ -20,21 +20,9 @@ from lnbits.extensions.example import example_ext async def api_example(): """Try to add descriptions for others.""" tools = [ - { - "name": "Flask", - "url": "https://flask.palletsprojects.com/", - "language": "Python", - }, - { - "name": "Vue.js", - "url": "https://vuejs.org/", - "language": "JavaScript", - }, - { - "name": "Quasar Framework", - "url": "https://quasar.dev/", - "language": "JavaScript", - }, + {"name": "Flask", "url": "https://flask.palletsprojects.com/", "language": "Python",}, + {"name": "Vue.js", "url": "https://vuejs.org/", "language": "JavaScript",}, + {"name": "Quasar Framework", "url": "https://quasar.dev/", "language": "JavaScript",}, ] return jsonify(tools), HTTPStatus.OK diff --git a/lnbits/extensions/lndhub/__init__.py b/lnbits/extensions/lndhub/__init__.py index ae0cc403..ed368c04 100644 --- a/lnbits/extensions/lndhub/__init__.py +++ b/lnbits/extensions/lndhub/__init__.py @@ -1,5 +1,7 @@ from quart import Blueprint +from lnbits.db import Database +db = Database("ext_lndhub") lndhub_ext: Blueprint = Blueprint("lndhub", __name__, static_folder="static", template_folder="templates") diff --git a/lnbits/extensions/lndhub/decorators.py b/lnbits/extensions/lndhub/decorators.py index d44fb255..07a109bc 100644 --- a/lnbits/extensions/lndhub/decorators.py +++ b/lnbits/extensions/lndhub/decorators.py @@ -15,7 +15,7 @@ def check_wallet(requires_admin=False): if requires_admin and key_type != "admin": return jsonify({"error": True, "code": 2, "message": "insufficient permissions"}) - g.wallet = get_wallet_for_key(key, key_type) + g.wallet = await get_wallet_for_key(key, key_type) if not g.wallet: return jsonify({"error": True, "code": 2, "message": "insufficient permissions"}) return await view(**kwargs) diff --git a/lnbits/extensions/lndhub/migrations.py b/lnbits/extensions/lndhub/migrations.py index 3c90d947..d6ea5fde 100644 --- a/lnbits/extensions/lndhub/migrations.py +++ b/lnbits/extensions/lndhub/migrations.py @@ -1,2 +1,2 @@ -def migrate(): +async def migrate(): pass diff --git a/lnbits/extensions/lndhub/views.py b/lnbits/extensions/lndhub/views.py index e9478ff1..2bc01fc1 100644 --- a/lnbits/extensions/lndhub/views.py +++ b/lnbits/extensions/lndhub/views.py @@ -1,7 +1,7 @@ from quart import render_template, g from lnbits.decorators import check_user_exists, validate_uuids -from lnbits.extensions.lndhub import lndhub_ext +from . import lndhub_ext @lndhub_ext.route("/") diff --git a/lnbits/extensions/lndhub/views_api.py b/lnbits/extensions/lndhub/views_api.py index acfacfbc..7c19c76e 100644 --- a/lnbits/extensions/lndhub/views_api.py +++ b/lnbits/extensions/lndhub/views_api.py @@ -8,7 +8,7 @@ from lnbits.decorators import api_validate_post_request from lnbits.settings import WALLET from lnbits import bolt11 -from lnbits.extensions.lndhub import lndhub_ext +from . import lndhub_ext from .decorators import check_wallet from .utils import to_buffer, decoded_as_lndhub @@ -46,20 +46,11 @@ async def lndhub_auth(): ) async def lndhub_addinvoice(): try: - _, pr = create_invoice( - wallet_id=g.wallet.id, - amount=int(g.data["amt"]), - memo=g.data["memo"], - extra={"tag": "lndhub"}, + _, pr = await create_invoice( + wallet_id=g.wallet.id, amount=int(g.data["amt"]), memo=g.data["memo"], extra={"tag": "lndhub"}, ) except Exception as e: - return jsonify( - { - "error": True, - "code": 7, - "message": "Failed to create invoice: " + str(e), - } - ) + return jsonify({"error": True, "code": 7, "message": "Failed to create invoice: " + str(e),}) invoice = bolt11.decode(pr) return jsonify( @@ -78,19 +69,11 @@ async def lndhub_addinvoice(): @api_validate_post_request(schema={"invoice": {"type": "string", "required": True}}) async def lndhub_payinvoice(): try: - pay_invoice( - wallet_id=g.wallet.id, - payment_request=g.data["invoice"], - extra={"tag": "lndhub"}, + await pay_invoice( + wallet_id=g.wallet.id, payment_request=g.data["invoice"], extra={"tag": "lndhub"}, ) except Exception as e: - return jsonify( - { - "error": True, - "code": 10, - "message": "Payment failed: " + str(e), - } - ) + return jsonify({"error": True, "code": 10, "message": "Payment failed: " + str(e),}) invoice: bolt11.Invoice = bolt11.decode(g.data["invoice"]) return jsonify( @@ -119,10 +102,10 @@ async def lndhub_balance(): @lndhub_ext.route("/ext/gettxs", methods=["GET"]) @check_wallet() async def lndhub_gettxs(): - for payment in g.wallet.get_payments( + for payment in await g.wallet.get_payments( complete=False, pending=True, outgoing=True, incoming=False, exclude_uncheckable=True ): - payment.set_pending(WALLET.get_payment_status(payment.checking_id).pending) + await payment.set_pending(WALLET.get_payment_status(payment.checking_id).pending) limit = int(request.args.get("limit", 200)) return jsonify( @@ -138,7 +121,7 @@ async def lndhub_gettxs(): "memo": payment.memo if not payment.pending else "Payment in transition", } for payment in reversed( - g.wallet.get_payments(pending=True, complete=True, outgoing=True, incoming=False)[:limit] + (await g.wallet.get_payments(pending=True, complete=True, outgoing=True, incoming=False))[:limit] ) ] ) @@ -147,11 +130,11 @@ async def lndhub_gettxs(): @lndhub_ext.route("/ext/getuserinvoices", methods=["GET"]) @check_wallet() async def lndhub_getuserinvoices(): - delete_expired_invoices() - for invoice in g.wallet.get_payments( + await delete_expired_invoices() + for invoice in await g.wallet.get_payments( complete=False, pending=True, outgoing=False, incoming=True, exclude_uncheckable=True ): - invoice.set_pending(WALLET.get_invoice_status(invoice.checking_id).pending) + await invoice.set_pending(WALLET.get_invoice_status(invoice.checking_id).pending) limit = int(request.args.get("limit", 200)) return jsonify( @@ -169,7 +152,7 @@ async def lndhub_getuserinvoices(): "type": "user_invoice", } for invoice in reversed( - g.wallet.get_payments(pending=True, complete=True, incoming=True, outgoing=False)[:limit] + (await g.wallet.get_payments(pending=True, complete=True, incoming=True, outgoing=False))[:limit] ) ] ) diff --git a/lnbits/extensions/lnticket/__init__.py b/lnbits/extensions/lnticket/__init__.py index 7a91b6b6..21ef19a1 100644 --- a/lnbits/extensions/lnticket/__init__.py +++ b/lnbits/extensions/lnticket/__init__.py @@ -1,5 +1,7 @@ from quart import Blueprint +from lnbits.db import Database +db = Database("ext_lnticket") lnticket_ext: Blueprint = Blueprint("lnticket", __name__, static_folder="static", template_folder="templates") diff --git a/lnbits/extensions/lnticket/crud.py b/lnbits/extensions/lnticket/crud.py index 561137d9..06b10f41 100644 --- a/lnbits/extensions/lnticket/crud.py +++ b/lnbits/extensions/lnticket/crud.py @@ -1,44 +1,44 @@ from typing import List, Optional, Union -from lnbits.db import open_ext_db from lnbits.helpers import urlsafe_short_hash +from . import db from .models import Tickets, Forms -#######TICKETS######## +async def create_ticket( + payment_hash: str, wallet: str, form: str, name: str, email: str, ltext: str, sats: int, +) -> Tickets: + await db.execute( + """ + INSERT INTO ticket (id, form, email, ltext, name, wallet, sats, paid) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + (payment_hash, form, email, ltext, name, wallet, sats, False), + ) + + ticket = await get_ticket(payment_hash) + assert ticket, "Newly created ticket couldn't be retrieved" + return ticket -def create_ticket(payment_hash: str, wallet: str, form: str, name: str, email: str, ltext: str, sats: int) -> Tickets: - with open_ext_db("lnticket") as db: - db.execute( - """ - INSERT INTO ticket (id, form, email, ltext, name, wallet, sats, paid) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - (payment_hash, form, email, ltext, name, wallet, sats, False), - ) - - return get_ticket(payment_hash) - - -def update_ticket(paid: bool, payment_hash: str) -> Tickets: - with open_ext_db("lnticket") as db: - row = db.fetchone("SELECT * FROM ticket WHERE id = ?", (payment_hash,)) - if row[7] == True: - return get_ticket(payment_hash) - db.execute( +async def set_ticket_paid(payment_hash: str) -> Tickets: + row = await db.fetchone("SELECT * FROM ticket WHERE id = ?", (payment_hash,)) + if row[7] == False: + await db.execute( """ UPDATE ticket - SET paid = ? + SET paid = true WHERE id = ? """, - (paid, payment_hash), + (payment_hash,), ) - formdata = get_form(row[1]) + formdata = await get_form(row[1]) + assert formdata, "Couldn't get form from paid ticket" + amount = formdata.amountmade + row[7] - db.execute( + await db.execute( """ UPDATE forms SET amountmade = ? @@ -46,76 +46,71 @@ def update_ticket(paid: bool, payment_hash: str) -> Tickets: """, (amount, row[1]), ) - return get_ticket(payment_hash) + + ticket = await get_ticket(payment_hash) + assert ticket, "Newly updated ticket couldn't be retrieved" + return ticket -def get_ticket(ticket_id: str) -> Optional[Tickets]: - with open_ext_db("lnticket") as db: - row = db.fetchone("SELECT * FROM ticket WHERE id = ?", (ticket_id,)) - +async def get_ticket(ticket_id: str) -> Optional[Tickets]: + row = await db.fetchone("SELECT * FROM ticket WHERE id = ?", (ticket_id,)) return Tickets(**row) if row else None -def get_tickets(wallet_ids: Union[str, List[str]]) -> List[Tickets]: +async def get_tickets(wallet_ids: Union[str, List[str]]) -> List[Tickets]: if isinstance(wallet_ids, str): wallet_ids = [wallet_ids] - with open_ext_db("lnticket") as db: - q = ",".join(["?"] * len(wallet_ids)) - rows = db.fetchall(f"SELECT * FROM ticket WHERE wallet IN ({q})", (*wallet_ids,)) + q = ",".join(["?"] * len(wallet_ids)) + rows = await db.fetchall(f"SELECT * FROM ticket WHERE wallet IN ({q})", (*wallet_ids,)) return [Tickets(**row) for row in rows] -def delete_ticket(ticket_id: str) -> None: - with open_ext_db("lnticket") as db: - db.execute("DELETE FROM ticket WHERE id = ?", (ticket_id,)) +async def delete_ticket(ticket_id: str) -> None: + await db.execute("DELETE FROM ticket WHERE id = ?", (ticket_id,)) -########FORMS######### +# FORMS -def create_form(*, wallet: str, name: str, description: str, costpword: int) -> Forms: - with open_ext_db("lnticket") as db: - form_id = urlsafe_short_hash() - db.execute( - """ - INSERT INTO forms (id, wallet, name, description, costpword, amountmade) - VALUES (?, ?, ?, ?, ?, ?) - """, - (form_id, wallet, name, description, costpword, 0), - ) +async def create_form(*, wallet: str, name: str, description: str, costpword: int) -> Forms: + form_id = urlsafe_short_hash() + await db.execute( + """ + INSERT INTO forms (id, wallet, name, description, costpword, amountmade) + VALUES (?, ?, ?, ?, ?, ?) + """, + (form_id, wallet, name, description, costpword, 0), + ) - return get_form(form_id) + form = await get_form(form_id) + assert form, "Newly created form couldn't be retrieved" + return form -def update_form(form_id: str, **kwargs) -> Forms: +async def update_form(form_id: str, **kwargs) -> Forms: q = ", ".join([f"{field[0]} = ?" for field in kwargs.items()]) - with open_ext_db("lnticket") as db: - db.execute(f"UPDATE forms SET {q} WHERE id = ?", (*kwargs.values(), form_id)) - row = db.fetchone("SELECT * FROM forms WHERE id = ?", (form_id,)) + await db.execute(f"UPDATE forms SET {q} WHERE id = ?", (*kwargs.values(), form_id)) + row = await db.fetchone("SELECT * FROM forms WHERE id = ?", (form_id,)) + assert row, "Newly updated form couldn't be retrieved" + return Forms(**row) + +async def get_form(form_id: str) -> Optional[Forms]: + row = await db.fetchone("SELECT * FROM forms WHERE id = ?", (form_id,)) return Forms(**row) if row else None -def get_form(form_id: str) -> Optional[Forms]: - with open_ext_db("lnticket") as db: - row = db.fetchone("SELECT * FROM forms WHERE id = ?", (form_id,)) - - return Forms(**row) if row else None - - -def get_forms(wallet_ids: Union[str, List[str]]) -> List[Forms]: +async def get_forms(wallet_ids: Union[str, List[str]]) -> List[Forms]: if isinstance(wallet_ids, str): wallet_ids = [wallet_ids] - with open_ext_db("lnticket") as db: - q = ",".join(["?"] * len(wallet_ids)) - rows = db.fetchall(f"SELECT * FROM forms WHERE wallet IN ({q})", (*wallet_ids,)) + q = ",".join(["?"] * len(wallet_ids)) + rows = await db.fetchall(f"SELECT * FROM forms WHERE wallet IN ({q})", (*wallet_ids,)) return [Forms(**row) for row in rows] -def delete_form(form_id: str) -> None: - with open_ext_db("lnticket") as db: - db.execute("DELETE FROM forms WHERE id = ?", (form_id,)) +async def delete_form(form_id: str) -> None: + await db.execute("DELETE FROM forms WHERE id = ?", (form_id,)) diff --git a/lnbits/extensions/lnticket/migrations.py b/lnbits/extensions/lnticket/migrations.py index b7b7f6b4..09e26a78 100644 --- a/lnbits/extensions/lnticket/migrations.py +++ b/lnbits/extensions/lnticket/migrations.py @@ -1,6 +1,6 @@ -def m001_initial(db): +async def m001_initial(db): - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS forms ( id TEXT PRIMARY KEY, @@ -14,7 +14,7 @@ def m001_initial(db): """ ) - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS tickets ( id TEXT PRIMARY KEY, @@ -30,9 +30,9 @@ def m001_initial(db): ) -def m002_changed(db): +async def m002_changed(db): - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS ticket ( id TEXT PRIMARY KEY, @@ -48,7 +48,7 @@ def m002_changed(db): """ ) - for row in [list(row) for row in db.fetchall("SELECT * FROM tickets")]: + for row in [list(row) for row in await db.fetchall("SELECT * FROM tickets")]: usescsv = "" for i in range(row[5]): @@ -57,7 +57,7 @@ def m002_changed(db): else: usescsv += "," + str(1) usescsv = usescsv[1:] - db.execute( + await db.execute( """ INSERT INTO ticket ( id, @@ -71,15 +71,6 @@ def m002_changed(db): ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) """, - ( - row[0], - row[1], - row[2], - row[3], - row[4], - row[5], - row[6], - True, - ), + (row[0], row[1], row[2], row[3], row[4], row[5], row[6], True,), ) - db.execute("DROP TABLE tickets") + await db.execute("DROP TABLE tickets") diff --git a/lnbits/extensions/lnticket/views.py b/lnbits/extensions/lnticket/views.py index 18048e02..16f75fbc 100644 --- a/lnbits/extensions/lnticket/views.py +++ b/lnbits/extensions/lnticket/views.py @@ -3,7 +3,7 @@ from quart import g, abort, render_template from lnbits.decorators import check_user_exists, validate_uuids from http import HTTPStatus -from lnbits.extensions.lnticket import lnticket_ext +from . import lnticket_ext from .crud import get_form @@ -16,8 +16,9 @@ async def index(): @lnticket_ext.route("/") async def display(form_id): - form = get_form(form_id) or abort(HTTPStatus.NOT_FOUND, "LNTicket does not exist.") - print(form.id) + form = await get_form(form_id) + if not form: + abort(HTTPStatus.NOT_FOUND, "LNTicket does not exist.") return await render_template( "lnticket/display.html", diff --git a/lnbits/extensions/lnticket/views_api.py b/lnbits/extensions/lnticket/views_api.py index d89892bb..832a4f9e 100644 --- a/lnbits/extensions/lnticket/views_api.py +++ b/lnbits/extensions/lnticket/views_api.py @@ -6,10 +6,10 @@ from lnbits.core.crud import get_user, get_wallet from lnbits.core.services import create_invoice, check_invoice_status from lnbits.decorators import api_check_wallet_key, api_validate_post_request -from lnbits.extensions.lnticket import lnticket_ext +from . import lnticket_ext from .crud import ( create_ticket, - update_ticket, + set_ticket_paid, get_ticket, get_tickets, delete_ticket, @@ -21,7 +21,7 @@ from .crud import ( ) -#########FORMS########## +# FORMS @lnticket_ext.route("/api/v1/forms", methods=["GET"]) @@ -30,9 +30,9 @@ async def api_forms(): wallet_ids = [g.wallet.id] if "all_wallets" in request.args: - wallet_ids = get_user(g.wallet.user).wallet_ids + wallet_ids = (await get_user(g.wallet.user)).wallet_ids - return jsonify([form._asdict() for form in get_forms(wallet_ids)]), HTTPStatus.OK + return jsonify([form._asdict() for form in await get_forms(wallet_ids)]), HTTPStatus.OK @lnticket_ext.route("/api/v1/forms", methods=["POST"]) @@ -48,7 +48,7 @@ async def api_forms(): ) async def api_form_create(form_id=None): if form_id: - form = get_form(form_id) + form = await get_form(form_id) if not form: return jsonify({"message": "Form does not exist."}), HTTPStatus.NOT_FOUND @@ -56,16 +56,16 @@ async def api_form_create(form_id=None): if form.wallet != g.wallet.id: return jsonify({"message": "Not your form."}), HTTPStatus.FORBIDDEN - form = update_form(form_id, **g.data) + form = await update_form(form_id, **g.data) else: - form = create_form(**g.data) + form = await create_form(**g.data) return jsonify(form._asdict()), HTTPStatus.CREATED @lnticket_ext.route("/api/v1/forms/", methods=["DELETE"]) @api_check_wallet_key("invoice") async def api_form_delete(form_id): - form = get_form(form_id) + form = await get_form(form_id) if not form: return jsonify({"message": "Form does not exist."}), HTTPStatus.NOT_FOUND @@ -73,7 +73,7 @@ async def api_form_delete(form_id): if form.wallet != g.wallet.id: return jsonify({"message": "Not your form."}), HTTPStatus.FORBIDDEN - delete_form(form_id) + await delete_form(form_id) return "", HTTPStatus.NO_CONTENT @@ -87,9 +87,9 @@ async def api_tickets(): wallet_ids = [g.wallet.id] if "all_wallets" in request.args: - wallet_ids = get_user(g.wallet.user).wallet_ids + wallet_ids = (await get_user(g.wallet.user)).wallet_ids - return jsonify([form._asdict() for form in get_tickets(wallet_ids)]), HTTPStatus.OK + return jsonify([form._asdict() for form in await get_tickets(wallet_ids)]), HTTPStatus.OK @lnticket_ext.route("/api/v1/tickets/", methods=["POST"]) @@ -102,22 +102,17 @@ async def api_tickets(): } ) async def api_ticket_make_ticket(form_id): - form = get_form(form_id) + form = await get_form(form_id) if not form: return jsonify({"message": "LNTicket does not exist."}), HTTPStatus.NOT_FOUND - try: - nwords = len(re.split(r"\s+", g.data["ltext"])) - sats = nwords * form.costpword - payment_hash, payment_request = create_invoice( - wallet_id=form.wallet, - amount=sats, - memo=f"ticket with {nwords} words on {form_id}", - extra={"tag": "lnticket"}, - ) - except Exception as e: - return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR - ticket = create_ticket(payment_hash=payment_hash, wallet=form.wallet, sats=sats, **g.data) + nwords = len(re.split(r"\s+", g.data["ltext"])) + sats = nwords * form.costpword + payment_hash, payment_request = await create_invoice( + wallet_id=form.wallet, amount=sats, memo=f"ticket with {nwords} words on {form_id}", extra={"tag": "lnticket"}, + ) + + ticket = await create_ticket(payment_hash=payment_hash, wallet=form.wallet, sats=sats, **g.data) if not ticket: return jsonify({"message": "LNTicket could not be fetched."}), HTTPStatus.NOT_FOUND @@ -127,17 +122,18 @@ async def api_ticket_make_ticket(form_id): @lnticket_ext.route("/api/v1/tickets/", methods=["GET"]) async def api_ticket_send_ticket(payment_hash): - ticket = get_ticket(payment_hash) + ticket = await get_ticket(payment_hash) try: - is_paid = not check_invoice_status(ticket.wallet, payment_hash).pending + status = await check_invoice_status(ticket.wallet, payment_hash) + is_paid = not status.pending except Exception: return jsonify({"message": "Not paid."}), HTTPStatus.NOT_FOUND if is_paid: - wallet = get_wallet(ticket.wallet) - payment = wallet.get_payment(payment_hash) - payment.set_pending(False) - ticket = update_ticket(paid=True, payment_hash=payment_hash) + wallet = await get_wallet(ticket.wallet) + payment = await wallet.get_payment(payment_hash) + await payment.set_pending(False) + ticket = await set_ticket_paid(payment_hash=payment_hash) return jsonify({"paid": True, "ticket_id": ticket.id}), HTTPStatus.OK return jsonify({"paid": False}), HTTPStatus.OK @@ -146,7 +142,7 @@ async def api_ticket_send_ticket(payment_hash): @lnticket_ext.route("/api/v1/tickets/", methods=["DELETE"]) @api_check_wallet_key("invoice") async def api_ticket_delete(ticket_id): - ticket = get_ticket(ticket_id) + ticket = await get_ticket(ticket_id) if not ticket: return jsonify({"message": "Paywall does not exist."}), HTTPStatus.NOT_FOUND @@ -154,6 +150,6 @@ async def api_ticket_delete(ticket_id): if ticket.wallet != g.wallet.id: return jsonify({"message": "Not your ticket."}), HTTPStatus.FORBIDDEN - delete_ticket(ticket_id) + await delete_ticket(ticket_id) return "", HTTPStatus.NO_CONTENT diff --git a/lnbits/extensions/lnurlp/__init__.py b/lnbits/extensions/lnurlp/__init__.py index 2c9a3e02..7a44f9d5 100644 --- a/lnbits/extensions/lnurlp/__init__.py +++ b/lnbits/extensions/lnurlp/__init__.py @@ -1,5 +1,7 @@ from quart import Blueprint +from lnbits.db import Database +db = Database("ext_lnurlp") lnurlp_ext: Blueprint = Blueprint("lnurlp", __name__, static_folder="static", template_folder="templates") diff --git a/lnbits/extensions/lnurlp/crud.py b/lnbits/extensions/lnurlp/crud.py index 1a1b84e8..7ceccc54 100644 --- a/lnbits/extensions/lnurlp/crud.py +++ b/lnbits/extensions/lnurlp/crud.py @@ -1,14 +1,10 @@ -import json from typing import List, Optional, Union -from lnbits.db import open_ext_db -from lnbits.core.models import Payment -from quart import g - +from . import db from .models import PayLink -def create_pay_link( +async def create_pay_link( *, wallet_id: str, description: str, @@ -19,96 +15,66 @@ def create_pay_link( webhook_url: Optional[str] = None, success_text: Optional[str] = None, success_url: Optional[str] = None, -) -> Optional[PayLink]: - with open_ext_db("lnurlp") as db: - db.execute( - """ - INSERT INTO pay_links ( - wallet, - description, - min, - max, - served_meta, - served_pr, - webhook_url, - success_text, - success_url, - comment_chars, - currency - ) - VALUES (?, ?, ?, ?, 0, 0, ?, ?, ?, ?, ?) - """, - ( - wallet_id, - description, - min, - max, - webhook_url, - success_text, - success_url, - comment_chars, - currency, - ), +) -> PayLink: + result = await db.execute( + """ + INSERT INTO pay_links ( + wallet, + description, + min, + max, + served_meta, + served_pr, + webhook_url, + success_text, + success_url, + comment_chars, + currency ) - link_id = db.cursor.lastrowid - return get_pay_link(link_id) + VALUES (?, ?, ?, ?, 0, 0, ?, ?, ?, ?, ?) + """, + (wallet_id, description, min, max, webhook_url, success_text, success_url, comment_chars, currency,), + ) + link_id = result._result_proxy.lastrowid + link = await get_pay_link(link_id) + assert link, "Newly created link couldn't be retrieved" + return link -def get_pay_link(link_id: int) -> Optional[PayLink]: - with open_ext_db("lnurlp") as db: - row = db.fetchone("SELECT * FROM pay_links WHERE id = ?", (link_id,)) - +async def get_pay_link(link_id: int) -> Optional[PayLink]: + row = await db.fetchone("SELECT * FROM pay_links WHERE id = ?", (link_id,)) return PayLink.from_row(row) if row else None -def get_pay_links(wallet_ids: Union[str, List[str]]) -> List[PayLink]: +async def get_pay_links(wallet_ids: Union[str, List[str]]) -> List[PayLink]: if isinstance(wallet_ids, str): wallet_ids = [wallet_ids] - with open_ext_db("lnurlp") as db: - q = ",".join(["?"] * len(wallet_ids)) - rows = db.fetchall( - f""" - SELECT * FROM pay_links WHERE wallet IN ({q}) - ORDER BY Id - """, - (*wallet_ids,), - ) + q = ",".join(["?"] * len(wallet_ids)) + rows = await db.fetchall( + f""" + SELECT * FROM pay_links WHERE wallet IN ({q}) + ORDER BY Id + """, + (*wallet_ids,), + ) return [PayLink.from_row(row) for row in rows] -def update_pay_link(link_id: int, **kwargs) -> Optional[PayLink]: +async def update_pay_link(link_id: int, **kwargs) -> Optional[PayLink]: q = ", ".join([f"{field[0]} = ?" for field in kwargs.items()]) - - with open_ext_db("lnurlp") as db: - db.execute(f"UPDATE pay_links SET {q} WHERE id = ?", (*kwargs.values(), link_id)) - row = db.fetchone("SELECT * FROM pay_links WHERE id = ?", (link_id,)) - + await db.execute(f"UPDATE pay_links SET {q} WHERE id = ?", (*kwargs.values(), link_id)) + row = await db.fetchone("SELECT * FROM pay_links WHERE id = ?", (link_id,)) return PayLink.from_row(row) if row else None -def increment_pay_link(link_id: int, **kwargs) -> Optional[PayLink]: +async def increment_pay_link(link_id: int, **kwargs) -> Optional[PayLink]: q = ", ".join([f"{field[0]} = {field[0]} + ?" for field in kwargs.items()]) - - with open_ext_db("lnurlp") as db: - db.execute(f"UPDATE pay_links SET {q} WHERE id = ?", (*kwargs.values(), link_id)) - row = db.fetchone("SELECT * FROM pay_links WHERE id = ?", (link_id,)) - + await db.execute(f"UPDATE pay_links SET {q} WHERE id = ?", (*kwargs.values(), link_id)) + row = await db.fetchone("SELECT * FROM pay_links WHERE id = ?", (link_id,)) return PayLink.from_row(row) if row else None -def delete_pay_link(link_id: int) -> None: - with open_ext_db("lnurlp") as db: - db.execute("DELETE FROM pay_links WHERE id = ?", (link_id,)) - - -def mark_webhook_sent(payment: Payment, status: int) -> None: - payment.extra["wh_status"] = status - g.db.execute( - """ - UPDATE apipayments SET extra = ? - WHERE hash = ? - """, - (json.dumps(payment.extra), payment.payment_hash), - ) +async def delete_pay_link(link_id: int) -> None: + await db.execute("DELETE FROM pay_links WHERE id = ?", (link_id,)) diff --git a/lnbits/extensions/lnurlp/lnurl.py b/lnbits/extensions/lnurlp/lnurl.py index 331eb047..b505e494 100644 --- a/lnbits/extensions/lnurlp/lnurl.py +++ b/lnbits/extensions/lnurlp/lnurl.py @@ -13,7 +13,7 @@ from .helpers import get_fiat_rate @lnurlp_ext.route("/api/v1/lnurl/", methods=["GET"]) async def api_lnurl_response(link_id): - link = increment_pay_link(link_id, served_meta=1) + link = await increment_pay_link(link_id, served_meta=1) if not link: return jsonify({"status": "ERROR", "reason": "LNURL-pay not found."}), HTTPStatus.OK @@ -34,7 +34,7 @@ async def api_lnurl_response(link_id): @lnurlp_ext.route("/api/v1/lnurl/cb/", methods=["GET"]) async def api_lnurl_callback(link_id): - link = increment_pay_link(link_id, served_pr=1) + link = await increment_pay_link(link_id, served_pr=1) if not link: return jsonify({"status": "ERROR", "reason": "LNURL-pay not found."}), HTTPStatus.OK @@ -71,7 +71,7 @@ async def api_lnurl_callback(link_id): HTTPStatus.OK, ) - payment_hash, payment_request = create_invoice( + payment_hash, payment_request = await create_invoice( wallet_id=link.wallet, amount=int(amount_received / 1000), memo=link.description, @@ -79,10 +79,6 @@ async def api_lnurl_callback(link_id): extra={"tag": "lnurlp", "link": link.id, "comment": comment}, ) - resp = LnurlPayActionResponse( - pr=payment_request, - success_action=link.success_action(payment_hash), - routes=[], - ) + resp = LnurlPayActionResponse(pr=payment_request, success_action=link.success_action(payment_hash), routes=[],) return jsonify(resp.dict()), HTTPStatus.OK diff --git a/lnbits/extensions/lnurlp/migrations.py b/lnbits/extensions/lnurlp/migrations.py index f20cd684..13dbc959 100644 --- a/lnbits/extensions/lnurlp/migrations.py +++ b/lnbits/extensions/lnurlp/migrations.py @@ -1,8 +1,8 @@ -def m001_initial(db): +async def m001_initial(db): """ Initial pay table. """ - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS pay_links ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -16,14 +16,14 @@ def m001_initial(db): ) -def m002_webhooks_and_success_actions(db): +async def m002_webhooks_and_success_actions(db): """ Webhooks and success actions. """ - db.execute("ALTER TABLE pay_links ADD COLUMN webhook_url TEXT;") - db.execute("ALTER TABLE pay_links ADD COLUMN success_text TEXT;") - db.execute("ALTER TABLE pay_links ADD COLUMN success_url TEXT;") - db.execute( + await db.execute("ALTER TABLE pay_links ADD COLUMN webhook_url TEXT;") + await db.execute("ALTER TABLE pay_links ADD COLUMN success_text TEXT;") + await db.execute("ALTER TABLE pay_links ADD COLUMN success_url TEXT;") + await db.execute( """ CREATE TABLE invoices ( pay_link INTEGER NOT NULL REFERENCES pay_links (id), @@ -35,14 +35,14 @@ def m002_webhooks_and_success_actions(db): ) -def m003_min_max_comment_fiat(db): +async def m003_min_max_comment_fiat(db): """ Support for min/max amounts, comments and fiat prices that get converted automatically to satoshis based on some API. """ - db.execute("ALTER TABLE pay_links ADD COLUMN currency TEXT;") # null = satoshis - db.execute("ALTER TABLE pay_links ADD COLUMN comment_chars INTEGER DEFAULT 0;") - db.execute("ALTER TABLE pay_links RENAME COLUMN amount TO min;") - db.execute("ALTER TABLE pay_links ADD COLUMN max INTEGER;") - db.execute("UPDATE pay_links SET max = min;") - db.execute("DROP TABLE invoices") + await db.execute("ALTER TABLE pay_links ADD COLUMN currency TEXT;") # null = satoshis + await db.execute("ALTER TABLE pay_links ADD COLUMN comment_chars INTEGER DEFAULT 0;") + await db.execute("ALTER TABLE pay_links RENAME COLUMN amount TO min;") + await db.execute("ALTER TABLE pay_links ADD COLUMN max INTEGER;") + await db.execute("UPDATE pay_links SET max = min;") + await db.execute("DROP TABLE invoices") diff --git a/lnbits/extensions/lnurlp/models.py b/lnbits/extensions/lnurlp/models.py index 93185b6d..9d0e4fd9 100644 --- a/lnbits/extensions/lnurlp/models.py +++ b/lnbits/extensions/lnurlp/models.py @@ -3,9 +3,9 @@ from urllib.parse import urlparse, urlunparse, parse_qs, urlencode, ParseResult from quart import url_for from typing import NamedTuple, Optional, Dict from sqlite3 import Row -from lnurl import Lnurl, encode as lnurl_encode -from lnurl.types import LnurlPayMetadata -from lnurl.models import LnurlPaySuccessAction, MessageAction, UrlAction +from lnurl import Lnurl, encode as lnurl_encode # type: ignore +from lnurl.types import LnurlPayMetadata # type: ignore +from lnurl.models import LnurlPaySuccessAction, MessageAction, UrlAction # type: ignore class PayLink(NamedTuple): diff --git a/lnbits/extensions/lnurlp/tasks.py b/lnbits/extensions/lnurlp/tasks.py index 3be3a98e..10aa9c77 100644 --- a/lnbits/extensions/lnurlp/tasks.py +++ b/lnbits/extensions/lnurlp/tasks.py @@ -1,10 +1,12 @@ import trio # type: ignore +import json import httpx +from lnbits.core import db as core_db from lnbits.core.models import Payment -from lnbits.tasks import run_on_pseudo_request, register_invoice_listener +from lnbits.tasks import register_invoice_listener -from .crud import mark_webhook_sent, get_pay_link +from .crud import get_pay_link async def register_listeners(): @@ -15,7 +17,7 @@ async def register_listeners(): async def wait_for_paid_invoices(invoice_paid_chan: trio.MemoryReceiveChannel): async for payment in invoice_paid_chan: - await run_on_pseudo_request(on_invoice_paid, payment) + await on_invoice_paid(payment) async def on_invoice_paid(payment: Payment) -> None: @@ -27,7 +29,7 @@ async def on_invoice_paid(payment: Payment) -> None: # this webhook has already been sent return - pay_link = get_pay_link(payment.extra.get("link", -1)) + pay_link = await get_pay_link(payment.extra.get("link", -1)) if pay_link and pay_link.webhook_url: async with httpx.AsyncClient() as client: try: @@ -42,6 +44,18 @@ async def on_invoice_paid(payment: Payment) -> None: }, timeout=40, ) - mark_webhook_sent(payment, r.status_code) + await mark_webhook_sent(payment, r.status_code) except (httpx.ConnectError, httpx.RequestError): - mark_webhook_sent(payment, -1) + await mark_webhook_sent(payment, -1) + + +async def mark_webhook_sent(payment: Payment, status: int) -> None: + payment.extra["wh_status"] = status + + await core_db.execute( + """ + UPDATE apipayments SET extra = ? + WHERE hash = ? + """, + (json.dumps(payment.extra), payment.payment_hash), + ) diff --git a/lnbits/extensions/lnurlp/templates/lnurlp/_api_docs.html b/lnbits/extensions/lnurlp/templates/lnurlp/_api_docs.html index d5b5b015..ec8985e3 100644 --- a/lnbits/extensions/lnurlp/templates/lnurlp/_api_docs.html +++ b/lnbits/extensions/lnurlp/templates/lnurlp/_api_docs.html @@ -17,7 +17,7 @@ [<pay_link_object>, ...]
Curl example
curl -X GET {{ request.url_root }}pay/api/v1/links -H "X-Api-Key: {{ + >curl -X GET {{ request.url_root }}lnurlp/api/v1/links -H "X-Api-Key: {{ g.user.wallets[0].inkey }}" @@ -38,7 +38,7 @@ {"lnurl": <string>}
Curl example
curl -X GET {{ request.url_root }}pay/api/v1/links/<pay_id> -H + >curl -X GET {{ request.url_root }}lnurlp/api/v1/links/<pay_id> -H "X-Api-Key: {{ g.user.wallets[0].inkey }}" @@ -63,7 +63,7 @@ {"lnurl": <string>}
Curl example
curl -X POST {{ request.url_root }}pay/api/v1/links -d + >curl -X POST {{ request.url_root }}lnurlp/api/v1/links -d '{"description": <string>, "amount": <integer>}' -H "Content-type: application/json" -H "X-Api-Key: {{ g.user.wallets[0].adminkey }}" @@ -93,7 +93,7 @@ {"lnurl": <string>}
Curl example
curl -X PUT {{ request.url_root }}pay/api/v1/links/<pay_id> -d + >curl -X PUT {{ request.url_root }}lnurlp/api/v1/links/<pay_id> -d '{"description": <string>, "amount": <integer>}' -H "Content-type: application/json" -H "X-Api-Key: {{ g.user.wallets[0].adminkey }}" @@ -120,7 +120,7 @@
Curl example
curl -X DELETE {{ request.url_root }}pay/api/v1/links/<pay_id> + >curl -X DELETE {{ request.url_root }}lnurlp/api/v1/links/<pay_id> -H "X-Api-Key: {{ g.user.wallets[0].adminkey }}" diff --git a/lnbits/extensions/lnurlp/views.py b/lnbits/extensions/lnurlp/views.py index 25d02e94..72f30c13 100644 --- a/lnbits/extensions/lnurlp/views.py +++ b/lnbits/extensions/lnurlp/views.py @@ -3,7 +3,7 @@ from http import HTTPStatus from lnbits.decorators import check_user_exists, validate_uuids -from lnbits.extensions.lnurlp import lnurlp_ext +from . import lnurlp_ext from .crud import get_pay_link @@ -16,11 +16,17 @@ async def index(): @lnurlp_ext.route("/") async def display(link_id): - link = get_pay_link(link_id) or abort(HTTPStatus.NOT_FOUND, "Pay link does not exist.") + link = await get_pay_link(link_id) + if not link: + abort(HTTPStatus.NOT_FOUND, "Pay link does not exist.") + return await render_template("lnurlp/display.html", link=link) @lnurlp_ext.route("/print/") async def print_qr(link_id): - link = get_pay_link(link_id) or abort(HTTPStatus.NOT_FOUND, "Pay link does not exist.") + link = await get_pay_link(link_id) + if not link: + abort(HTTPStatus.NOT_FOUND, "Pay link does not exist.") + return await render_template("lnurlp/print_qr.html", link=link) diff --git a/lnbits/extensions/lnurlp/views_api.py b/lnbits/extensions/lnurlp/views_api.py index c211ba06..f68bc5b3 100644 --- a/lnbits/extensions/lnurlp/views_api.py +++ b/lnbits/extensions/lnurlp/views_api.py @@ -5,7 +5,7 @@ from lnurl.exceptions import InvalidUrl as LnurlInvalidUrl # type: ignore from lnbits.core.crud import get_user from lnbits.decorators import api_check_wallet_key, api_validate_post_request -from lnbits.extensions.lnurlp import lnurlp_ext # type: ignore +from . import lnurlp_ext from .crud import ( create_pay_link, get_pay_link, @@ -22,11 +22,11 @@ async def api_links(): wallet_ids = [g.wallet.id] if "all_wallets" in request.args: - wallet_ids = get_user(g.wallet.user).wallet_ids + wallet_ids = (await get_user(g.wallet.user)).wallet_ids try: return ( - jsonify([{**link._asdict(), **{"lnurl": link.lnurl}} for link in get_pay_links(wallet_ids)]), + jsonify([{**link._asdict(), **{"lnurl": link.lnurl}} for link in await get_pay_links(wallet_ids)]), HTTPStatus.OK, ) except LnurlInvalidUrl: @@ -39,7 +39,7 @@ async def api_links(): @lnurlp_ext.route("/api/v1/links/", methods=["GET"]) @api_check_wallet_key("invoice") async def api_link_retrieve(link_id): - link = get_pay_link(link_id) + link = await get_pay_link(link_id) if not link: return jsonify({"message": "Pay link does not exist."}), HTTPStatus.NOT_FOUND @@ -75,7 +75,7 @@ async def api_link_create_or_update(link_id=None): return jsonify({"message": "Must use full satoshis."}), HTTPStatus.BAD_REQUEST if link_id: - link = get_pay_link(link_id) + link = await get_pay_link(link_id) if not link: return jsonify({"message": "Pay link does not exist."}), HTTPStatus.NOT_FOUND @@ -83,9 +83,9 @@ async def api_link_create_or_update(link_id=None): if link.wallet != g.wallet.id: return jsonify({"message": "Not your pay link."}), HTTPStatus.FORBIDDEN - link = update_pay_link(link_id, **g.data) + link = await update_pay_link(link_id, **g.data) else: - link = create_pay_link(wallet_id=g.wallet.id, **g.data) + link = await create_pay_link(wallet_id=g.wallet.id, **g.data) return jsonify({**link._asdict(), **{"lnurl": link.lnurl}}), HTTPStatus.OK if link_id else HTTPStatus.CREATED @@ -93,7 +93,7 @@ async def api_link_create_or_update(link_id=None): @lnurlp_ext.route("/api/v1/links/", methods=["DELETE"]) @api_check_wallet_key("invoice") async def api_link_delete(link_id): - link = get_pay_link(link_id) + link = await get_pay_link(link_id) if not link: return jsonify({"message": "Pay link does not exist."}), HTTPStatus.NOT_FOUND @@ -101,7 +101,7 @@ async def api_link_delete(link_id): if link.wallet != g.wallet.id: return jsonify({"message": "Not your pay link."}), HTTPStatus.FORBIDDEN - delete_pay_link(link_id) + await delete_pay_link(link_id) return "", HTTPStatus.NO_CONTENT diff --git a/lnbits/extensions/paywall/__init__.py b/lnbits/extensions/paywall/__init__.py index a83e4f89..8ab130cf 100644 --- a/lnbits/extensions/paywall/__init__.py +++ b/lnbits/extensions/paywall/__init__.py @@ -1,5 +1,7 @@ from quart import Blueprint +from lnbits.db import Database +db = Database("ext_paywall") paywall_ext: Blueprint = Blueprint("paywall", __name__, static_folder="static", template_folder="templates") diff --git a/lnbits/extensions/paywall/crud.py b/lnbits/extensions/paywall/crud.py index 532f6438..6c2338e2 100644 --- a/lnbits/extensions/paywall/crud.py +++ b/lnbits/extensions/paywall/crud.py @@ -1,45 +1,43 @@ from typing import List, Optional, Union -from lnbits.db import open_ext_db from lnbits.helpers import urlsafe_short_hash +from . import db from .models import Paywall -def create_paywall( +async def create_paywall( *, wallet_id: str, url: str, memo: str, description: Optional[str] = None, amount: int = 0, remembers: bool = True ) -> Paywall: - with open_ext_db("paywall") as db: - paywall_id = urlsafe_short_hash() - db.execute( - """ - INSERT INTO paywalls (id, wallet, url, memo, description, amount, remembers) - VALUES (?, ?, ?, ?, ?, ?, ?) - """, - (paywall_id, wallet_id, url, memo, description, amount, int(remembers)), - ) + paywall_id = urlsafe_short_hash() + await db.execute( + """ + INSERT INTO paywalls (id, wallet, url, memo, description, amount, remembers) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + (paywall_id, wallet_id, url, memo, description, amount, int(remembers)), + ) - return get_paywall(paywall_id) + paywall = await get_paywall(paywall_id) + assert paywall, "Newly created paywall couldn't be retrieved" + return paywall -def get_paywall(paywall_id: str) -> Optional[Paywall]: - with open_ext_db("paywall") as db: - row = db.fetchone("SELECT * FROM paywalls WHERE id = ?", (paywall_id,)) +async def get_paywall(paywall_id: str) -> Optional[Paywall]: + row = await db.fetchone("SELECT * FROM paywalls WHERE id = ?", (paywall_id,)) return Paywall.from_row(row) if row else None -def get_paywalls(wallet_ids: Union[str, List[str]]) -> List[Paywall]: +async def get_paywalls(wallet_ids: Union[str, List[str]]) -> List[Paywall]: if isinstance(wallet_ids, str): wallet_ids = [wallet_ids] - with open_ext_db("paywall") as db: - q = ",".join(["?"] * len(wallet_ids)) - rows = db.fetchall(f"SELECT * FROM paywalls WHERE wallet IN ({q})", (*wallet_ids,)) + q = ",".join(["?"] * len(wallet_ids)) + rows = await db.fetchall(f"SELECT * FROM paywalls WHERE wallet IN ({q})", (*wallet_ids,)) return [Paywall.from_row(row) for row in rows] -def delete_paywall(paywall_id: str) -> None: - with open_ext_db("paywall") as db: - db.execute("DELETE FROM paywalls WHERE id = ?", (paywall_id,)) +async def delete_paywall(paywall_id: str) -> None: + await db.execute("DELETE FROM paywalls WHERE id = ?", (paywall_id,)) diff --git a/lnbits/extensions/paywall/migrations.py b/lnbits/extensions/paywall/migrations.py index d1b6a3a9..fd1dd5ec 100644 --- a/lnbits/extensions/paywall/migrations.py +++ b/lnbits/extensions/paywall/migrations.py @@ -1,11 +1,11 @@ -from sqlite3 import OperationalError +from sqlalchemy.exc import OperationalError # type: ignore -def m001_initial(db): +async def m001_initial(db): """ Initial paywalls table. """ - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS paywalls ( id TEXT PRIMARY KEY, @@ -20,16 +20,16 @@ def m001_initial(db): ) -def m002_redux(db): +async def m002_redux(db): """ Creates an improved paywalls table and migrates the existing data. """ try: - db.execute("SELECT remembers FROM paywalls") + await db.execute("SELECT remembers FROM paywalls") except OperationalError: - db.execute("ALTER TABLE paywalls RENAME TO paywalls_old") - db.execute( + await db.execute("ALTER TABLE paywalls RENAME TO paywalls_old") + await db.execute( """ CREATE TABLE IF NOT EXISTS paywalls ( id TEXT PRIMARY KEY, @@ -44,10 +44,10 @@ def m002_redux(db): ); """ ) - db.execute("CREATE INDEX IF NOT EXISTS wallet_idx ON paywalls (wallet)") + await db.execute("CREATE INDEX IF NOT EXISTS wallet_idx ON paywalls (wallet)") - for row in [list(row) for row in db.fetchall("SELECT * FROM paywalls_old")]: - db.execute( + for row in [list(row) for row in await db.fetchall("SELECT * FROM paywalls_old")]: + await db.execute( """ INSERT INTO paywalls ( id, @@ -62,4 +62,4 @@ def m002_redux(db): (row[0], row[1], row[3], row[4], row[5], row[6]), ) - db.execute("DROP TABLE paywalls_old") + await db.execute("DROP TABLE paywalls_old") diff --git a/lnbits/extensions/paywall/views.py b/lnbits/extensions/paywall/views.py index bd5c388c..7373d5c4 100644 --- a/lnbits/extensions/paywall/views.py +++ b/lnbits/extensions/paywall/views.py @@ -3,7 +3,7 @@ from http import HTTPStatus from lnbits.decorators import check_user_exists, validate_uuids -from lnbits.extensions.paywall import paywall_ext +from . import paywall_ext from .crud import get_paywall @@ -16,5 +16,5 @@ async def index(): @paywall_ext.route("/") async def display(paywall_id): - paywall = get_paywall(paywall_id) or abort(HTTPStatus.NOT_FOUND, "Paywall does not exist.") + paywall = await get_paywall(paywall_id) or abort(HTTPStatus.NOT_FOUND, "Paywall does not exist.") return await render_template("paywall/display.html", paywall=paywall) diff --git a/lnbits/extensions/paywall/views_api.py b/lnbits/extensions/paywall/views_api.py index 0a00fd9e..c2a2c62a 100644 --- a/lnbits/extensions/paywall/views_api.py +++ b/lnbits/extensions/paywall/views_api.py @@ -5,7 +5,7 @@ from lnbits.core.crud import get_user, get_wallet from lnbits.core.services import create_invoice, check_invoice_status from lnbits.decorators import api_check_wallet_key, api_validate_post_request -from lnbits.extensions.paywall import paywall_ext +from . import paywall_ext from .crud import create_paywall, get_paywall, get_paywalls, delete_paywall @@ -15,9 +15,9 @@ async def api_paywalls(): wallet_ids = [g.wallet.id] if "all_wallets" in request.args: - wallet_ids = get_user(g.wallet.user).wallet_ids + wallet_ids = (await get_user(g.wallet.user)).wallet_ids - return jsonify([paywall._asdict() for paywall in get_paywalls(wallet_ids)]), HTTPStatus.OK + return jsonify([paywall._asdict() for paywall in await get_paywalls(wallet_ids)]), HTTPStatus.OK @paywall_ext.route("/api/v1/paywalls", methods=["POST"]) @@ -32,15 +32,14 @@ async def api_paywalls(): } ) async def api_paywall_create(): - paywall = create_paywall(wallet_id=g.wallet.id, **g.data) - + paywall = await create_paywall(wallet_id=g.wallet.id, **g.data) return jsonify(paywall._asdict()), HTTPStatus.CREATED @paywall_ext.route("/api/v1/paywalls/", methods=["DELETE"]) @api_check_wallet_key("invoice") async def api_paywall_delete(paywall_id): - paywall = get_paywall(paywall_id) + paywall = await get_paywall(paywall_id) if not paywall: return jsonify({"message": "Paywall does not exist."}), HTTPStatus.NOT_FOUND @@ -48,7 +47,7 @@ async def api_paywall_delete(paywall_id): if paywall.wallet != g.wallet.id: return jsonify({"message": "Not your paywall."}), HTTPStatus.FORBIDDEN - delete_paywall(paywall_id) + await delete_paywall(paywall_id) return "", HTTPStatus.NO_CONTENT @@ -56,14 +55,14 @@ async def api_paywall_delete(paywall_id): @paywall_ext.route("/api/v1/paywalls//invoice", methods=["POST"]) @api_validate_post_request(schema={"amount": {"type": "integer", "min": 1, "required": True}}) async def api_paywall_create_invoice(paywall_id): - paywall = get_paywall(paywall_id) + paywall = await get_paywall(paywall_id) if g.data["amount"] < paywall.amount: return jsonify({"message": f"Minimum amount is {paywall.amount} sat."}), HTTPStatus.BAD_REQUEST try: amount = g.data["amount"] if g.data["amount"] > paywall.amount else paywall.amount - payment_hash, payment_request = create_invoice( + payment_hash, payment_request = await create_invoice( wallet_id=paywall.wallet, amount=amount, memo=f"{paywall.memo}", extra={"tag": "paywall"} ) except Exception as e: @@ -75,20 +74,21 @@ async def api_paywall_create_invoice(paywall_id): @paywall_ext.route("/api/v1/paywalls//check_invoice", methods=["POST"]) @api_validate_post_request(schema={"payment_hash": {"type": "string", "empty": False, "required": True}}) async def api_paywal_check_invoice(paywall_id): - paywall = get_paywall(paywall_id) + paywall = await get_paywall(paywall_id) if not paywall: return jsonify({"message": "Paywall does not exist."}), HTTPStatus.NOT_FOUND try: - is_paid = not check_invoice_status(paywall.wallet, g.data["payment_hash"]).pending + status = await check_invoice_status(paywall.wallet, g.data["payment_hash"]) + is_paid = not status.pending except Exception: return jsonify({"paid": False}), HTTPStatus.OK if is_paid: - wallet = get_wallet(paywall.wallet) - payment = wallet.get_payment(g.data["payment_hash"]) - payment.set_pending(False) + wallet = await get_wallet(paywall.wallet) + payment = await wallet.get_payment(g.data["payment_hash"]) + await payment.set_pending(False) return jsonify({"paid": True, "url": paywall.url, "remembers": paywall.remembers}), HTTPStatus.OK diff --git a/lnbits/extensions/tpos/__init__.py b/lnbits/extensions/tpos/__init__.py index 1b645bdf..78732d86 100644 --- a/lnbits/extensions/tpos/__init__.py +++ b/lnbits/extensions/tpos/__init__.py @@ -1,5 +1,7 @@ from quart import Blueprint +from lnbits.db import Database +db = Database("ext_tpos") tpos_ext: Blueprint = Blueprint("tpos", __name__, static_folder="static", template_folder="templates") diff --git a/lnbits/extensions/tpos/crud.py b/lnbits/extensions/tpos/crud.py index 6d11576b..69f70730 100644 --- a/lnbits/extensions/tpos/crud.py +++ b/lnbits/extensions/tpos/crud.py @@ -1,43 +1,40 @@ from typing import List, Optional, Union -from lnbits.db import open_ext_db from lnbits.helpers import urlsafe_short_hash +from . import db from .models import TPoS -def create_tpos(*, wallet_id: str, name: str, currency: str) -> TPoS: - with open_ext_db("tpos") as db: - tpos_id = urlsafe_short_hash() - db.execute( - """ - INSERT INTO tposs (id, wallet, name, currency) - VALUES (?, ?, ?, ?) - """, - (tpos_id, wallet_id, name, currency), - ) +async def create_tpos(*, wallet_id: str, name: str, currency: str) -> TPoS: + tpos_id = urlsafe_short_hash() + await db.execute( + """ + INSERT INTO tposs (id, wallet, name, currency) + VALUES (?, ?, ?, ?) + """, + (tpos_id, wallet_id, name, currency), + ) - return get_tpos(tpos_id) + tpos = await get_tpos(tpos_id) + assert tpos, "Newly created tpos couldn't be retrieved" + return tpos -def get_tpos(tpos_id: str) -> Optional[TPoS]: - with open_ext_db("tpos") as db: - row = db.fetchone("SELECT * FROM tposs WHERE id = ?", (tpos_id,)) - +async def get_tpos(tpos_id: str) -> Optional[TPoS]: + row = await db.fetchone("SELECT * FROM tposs WHERE id = ?", (tpos_id,)) return TPoS.from_row(row) if row else None -def get_tposs(wallet_ids: Union[str, List[str]]) -> List[TPoS]: +async def get_tposs(wallet_ids: Union[str, List[str]]) -> List[TPoS]: if isinstance(wallet_ids, str): wallet_ids = [wallet_ids] - with open_ext_db("tpos") as db: - q = ",".join(["?"] * len(wallet_ids)) - rows = db.fetchall(f"SELECT * FROM tposs WHERE wallet IN ({q})", (*wallet_ids,)) + q = ",".join(["?"] * len(wallet_ids)) + rows = await db.fetchall(f"SELECT * FROM tposs WHERE wallet IN ({q})", (*wallet_ids,)) return [TPoS.from_row(row) for row in rows] -def delete_tpos(tpos_id: str) -> None: - with open_ext_db("tpos") as db: - db.execute("DELETE FROM tposs WHERE id = ?", (tpos_id,)) +async def delete_tpos(tpos_id: str) -> None: + await db.execute("DELETE FROM tposs WHERE id = ?", (tpos_id,)) diff --git a/lnbits/extensions/tpos/migrations.py b/lnbits/extensions/tpos/migrations.py index 1a03ed25..243ebe0b 100644 --- a/lnbits/extensions/tpos/migrations.py +++ b/lnbits/extensions/tpos/migrations.py @@ -1,8 +1,8 @@ -def m001_initial(db): +async def m001_initial(db): """ Initial tposs table. """ - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS tposs ( id TEXT PRIMARY KEY, diff --git a/lnbits/extensions/tpos/views.py b/lnbits/extensions/tpos/views.py index 13d0499e..ce842295 100644 --- a/lnbits/extensions/tpos/views.py +++ b/lnbits/extensions/tpos/views.py @@ -3,7 +3,7 @@ from http import HTTPStatus from lnbits.decorators import check_user_exists, validate_uuids -from lnbits.extensions.tpos import tpos_ext +from . import tpos_ext from .crud import get_tpos @@ -16,6 +16,8 @@ async def index(): @tpos_ext.route("/") async def tpos(tpos_id): - tpos = get_tpos(tpos_id) or abort(HTTPStatus.NOT_FOUND, "TPoS does not exist.") + tpos = await get_tpos(tpos_id) + if not tpos: + abort(HTTPStatus.NOT_FOUND, "TPoS does not exist.") return await render_template("tpos/tpos.html", tpos=tpos) diff --git a/lnbits/extensions/tpos/views_api.py b/lnbits/extensions/tpos/views_api.py index 6041f6cf..717beaf3 100644 --- a/lnbits/extensions/tpos/views_api.py +++ b/lnbits/extensions/tpos/views_api.py @@ -5,7 +5,7 @@ from lnbits.core.crud import get_user, get_wallet from lnbits.core.services import create_invoice, check_invoice_status from lnbits.decorators import api_check_wallet_key, api_validate_post_request -from lnbits.extensions.tpos import tpos_ext +from . import tpos_ext from .crud import create_tpos, get_tpos, get_tposs, delete_tpos @@ -15,9 +15,9 @@ async def api_tposs(): wallet_ids = [g.wallet.id] if "all_wallets" in request.args: - wallet_ids = get_user(g.wallet.user).wallet_ids + wallet_ids = await get_user(g.wallet.user).wallet_ids - return jsonify([tpos._asdict() for tpos in get_tposs(wallet_ids)]), HTTPStatus.OK + return jsonify([tpos._asdict() for tpos in await get_tposs(wallet_ids)]), HTTPStatus.OK @tpos_ext.route("/api/v1/tposs", methods=["POST"]) @@ -29,15 +29,14 @@ async def api_tposs(): } ) async def api_tpos_create(): - tpos = create_tpos(wallet_id=g.wallet.id, **g.data) - + tpos = await create_tpos(wallet_id=g.wallet.id, **g.data) return jsonify(tpos._asdict()), HTTPStatus.CREATED @tpos_ext.route("/api/v1/tposs/", methods=["DELETE"]) @api_check_wallet_key("admin") async def api_tpos_delete(tpos_id): - tpos = get_tpos(tpos_id) + tpos = await get_tpos(tpos_id) if not tpos: return jsonify({"message": "TPoS does not exist."}), HTTPStatus.NOT_FOUND @@ -45,7 +44,7 @@ async def api_tpos_delete(tpos_id): if tpos.wallet != g.wallet.id: return jsonify({"message": "Not your TPoS."}), HTTPStatus.FORBIDDEN - delete_tpos(tpos_id) + await delete_tpos(tpos_id) return "", HTTPStatus.NO_CONTENT @@ -53,13 +52,13 @@ async def api_tpos_delete(tpos_id): @tpos_ext.route("/api/v1/tposs//invoices/", methods=["POST"]) @api_validate_post_request(schema={"amount": {"type": "integer", "min": 1, "required": True}}) async def api_tpos_create_invoice(tpos_id): - tpos = get_tpos(tpos_id) + tpos = await get_tpos(tpos_id) if not tpos: return jsonify({"message": "TPoS does not exist."}), HTTPStatus.NOT_FOUND try: - payment_hash, payment_request = create_invoice( + payment_hash, payment_request = await create_invoice( wallet_id=tpos.wallet, amount=g.data["amount"], memo=f"{tpos.name}", extra={"tag": "tpos"} ) except Exception as e: @@ -70,21 +69,22 @@ async def api_tpos_create_invoice(tpos_id): @tpos_ext.route("/api/v1/tposs//invoices/", methods=["GET"]) async def api_tpos_check_invoice(tpos_id, payment_hash): - tpos = get_tpos(tpos_id) + tpos = await get_tpos(tpos_id) if not tpos: return jsonify({"message": "TPoS does not exist."}), HTTPStatus.NOT_FOUND try: - is_paid = not check_invoice_status(tpos.wallet, payment_hash).pending + status = await check_invoice_status(tpos.wallet, payment_hash) + is_paid = not status.pending except Exception as exc: print(exc) return jsonify({"paid": False}), HTTPStatus.OK if is_paid: - wallet = get_wallet(tpos.wallet) - payment = wallet.get_payment(payment_hash) - payment.set_pending(False) + wallet = await get_wallet(tpos.wallet) + payment = await wallet.get_payment(payment_hash) + await payment.set_pending(False) return jsonify({"paid": True}), HTTPStatus.OK diff --git a/lnbits/extensions/usermanager/__init__.py b/lnbits/extensions/usermanager/__init__.py index 60adfb94..2bdbf0b5 100644 --- a/lnbits/extensions/usermanager/__init__.py +++ b/lnbits/extensions/usermanager/__init__.py @@ -1,5 +1,7 @@ from quart import Blueprint +from lnbits.db import Database +db = Database("ext_usermanager") usermanager_ext: Blueprint = Blueprint("usermanager", __name__, static_folder="static", template_folder="templates") diff --git a/lnbits/extensions/usermanager/crud.py b/lnbits/extensions/usermanager/crud.py index db9556b7..db41b8dd 100644 --- a/lnbits/extensions/usermanager/crud.py +++ b/lnbits/extensions/usermanager/crud.py @@ -1,8 +1,7 @@ -from lnbits.db import open_ext_db -from .models import Users, Wallets -from typing import Optional +from typing import Optional, List -from ...core.crud import ( +from lnbits.core.models import Payment +from lnbits.core.crud import ( create_account, get_user, get_wallet_payments, @@ -10,106 +9,91 @@ from ...core.crud import ( delete_wallet, ) - -###Users +from . import db +from .models import Users, Wallets -def create_usermanager_user(user_name: str, wallet_name: str, admin_id: str) -> Users: - user = get_user(create_account().id) - - wallet = create_wallet(user_id=user.id, wallet_name=wallet_name) - - with open_ext_db("usermanager") as db: - db.execute( - """ - INSERT INTO users (id, name, admin) - VALUES (?, ?, ?) - """, - (user.id, user_name, admin_id), - ) - - db.execute( - """ - INSERT INTO wallets (id, admin, name, user, adminkey, inkey) - VALUES (?, ?, ?, ?, ?, ?) - """, - (wallet.id, admin_id, wallet_name, user.id, wallet.adminkey, wallet.inkey), - ) - - return get_usermanager_user(user.id) +### Users -def get_usermanager_user(user_id: str) -> Users: - with open_ext_db("usermanager") as db: +async def create_usermanager_user(user_name: str, wallet_name: str, admin_id: str) -> Users: + account = await create_account() + user = await get_user(account.id) + assert user, "Newly created user couldn't be retrieved" - row = db.fetchone("SELECT * FROM users WHERE id = ?", (user_id,)) + wallet = await create_wallet(user_id=user.id, wallet_name=wallet_name) + await db.execute( + """ + INSERT INTO users (id, name, admin) + VALUES (?, ?, ?) + """, + (user.id, user_name, admin_id), + ) + + await db.execute( + """ + INSERT INTO wallets (id, admin, name, user, adminkey, inkey) + VALUES (?, ?, ?, ?, ?, ?) + """, + (wallet.id, admin_id, wallet_name, user.id, wallet.adminkey, wallet.inkey), + ) + + user_created = await get_usermanager_user(user.id) + assert user_created, "Newly created user couldn't be retrieved" + return user_created + + +async def get_usermanager_user(user_id: str) -> Optional[Users]: + row = await db.fetchone("SELECT * FROM users WHERE id = ?", (user_id,)) return Users(**row) if row else None -def get_usermanager_users(user_id: str) -> Users: - - with open_ext_db("usermanager") as db: - rows = db.fetchall("SELECT * FROM users WHERE admin = ?", (user_id,)) - +async def get_usermanager_users(user_id: str) -> List[Users]: + rows = await db.fetchall("SELECT * FROM users WHERE admin = ?", (user_id,)) return [Users(**row) for row in rows] -def delete_usermanager_user(user_id: str) -> None: - row = get_usermanager_wallets(user_id) - print("test") - with open_ext_db("usermanager") as db: - db.execute("DELETE FROM users WHERE id = ?", (user_id,)) - row - for r in row: - delete_wallet(user_id=user_id, wallet_id=r.id) - with open_ext_db("usermanager") as dbb: - dbb.execute("DELETE FROM wallets WHERE user = ?", (user_id,)) +async def delete_usermanager_user(user_id: str) -> None: + wallets = await get_usermanager_wallets(user_id) + for wallet in wallets: + await delete_wallet(user_id=user_id, wallet_id=wallet.id) + + await db.execute("DELETE FROM users WHERE id = ?", (user_id,)) + await db.execute("DELETE FROM wallets WHERE user = ?", (user_id,)) -###Wallets +### Wallets -def create_usermanager_wallet(user_id: str, wallet_name: str, admin_id: str) -> Wallets: - wallet = create_wallet(user_id=user_id, wallet_name=wallet_name) - with open_ext_db("usermanager") as db: - - db.execute( - """ - INSERT INTO wallets (id, admin, name, user, adminkey, inkey) - VALUES (?, ?, ?, ?, ?, ?) - """, - (wallet.id, admin_id, wallet_name, user_id, wallet.adminkey, wallet.inkey), - ) - - return get_usermanager_wallet(wallet.id) +async def create_usermanager_wallet(user_id: str, wallet_name: str, admin_id: str) -> Wallets: + wallet = await create_wallet(user_id=user_id, wallet_name=wallet_name) + await db.execute( + """ + INSERT INTO wallets (id, admin, name, user, adminkey, inkey) + VALUES (?, ?, ?, ?, ?, ?) + """, + (wallet.id, admin_id, wallet_name, user_id, wallet.adminkey, wallet.inkey), + ) + wallet_created = await get_usermanager_wallet(wallet.id) + assert wallet_created, "Newly created wallet couldn't be retrieved" + return wallet_created -def get_usermanager_wallet(wallet_id: str) -> Optional[Wallets]: - with open_ext_db("usermanager") as db: - row = db.fetchone("SELECT * FROM wallets WHERE id = ?", (wallet_id,)) - +async def get_usermanager_wallet(wallet_id: str) -> Optional[Wallets]: + row = await db.fetchone("SELECT * FROM wallets WHERE id = ?", (wallet_id,)) return Wallets(**row) if row else None -def get_usermanager_wallets(user_id: str) -> Wallets: - - with open_ext_db("usermanager") as db: - rows = db.fetchall("SELECT * FROM wallets WHERE admin = ?", (user_id,)) - +async def get_usermanager_wallets(user_id: str) -> List[Wallets]: + rows = await db.fetchall("SELECT * FROM wallets WHERE admin = ?", (user_id,)) return [Wallets(**row) for row in rows] -def get_usermanager_wallet_transactions(wallet_id: str) -> Users: - return get_wallet_payments(wallet_id=wallet_id, complete=True, pending=False, outgoing=True, incoming=True) +async def get_usermanager_wallet_transactions(wallet_id: str) -> List[Payment]: + return await get_wallet_payments(wallet_id=wallet_id, complete=True, pending=False, outgoing=True, incoming=True) -def get_usermanager_wallet_balances(user_id: str) -> Users: - user = get_user(user_id) - return user.wallets - - -def delete_usermanager_wallet(wallet_id: str, user_id: str) -> None: - delete_wallet(user_id=user_id, wallet_id=wallet_id) - with open_ext_db("usermanager") as db: - db.execute("DELETE FROM wallets WHERE id = ?", (wallet_id,)) +async def delete_usermanager_wallet(wallet_id: str, user_id: str) -> None: + await delete_wallet(user_id=user_id, wallet_id=wallet_id) + await db.execute("DELETE FROM wallets WHERE id = ?", (wallet_id,)) diff --git a/lnbits/extensions/usermanager/migrations.py b/lnbits/extensions/usermanager/migrations.py index faff3b88..9b60fa66 100644 --- a/lnbits/extensions/usermanager/migrations.py +++ b/lnbits/extensions/usermanager/migrations.py @@ -1,8 +1,8 @@ -def m001_initial(db): +async def m001_initial(db): """ Initial users table. """ - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS users ( id TEXT PRIMARY KEY, @@ -17,7 +17,7 @@ def m001_initial(db): """ Initial wallets table. """ - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS wallets ( id TEXT PRIMARY KEY, diff --git a/lnbits/extensions/usermanager/views.py b/lnbits/extensions/usermanager/views.py index 713fb32f..df6949c6 100644 --- a/lnbits/extensions/usermanager/views.py +++ b/lnbits/extensions/usermanager/views.py @@ -1,6 +1,8 @@ from quart import g, render_template + from lnbits.decorators import check_user_exists, validate_uuids -from lnbits.extensions.usermanager import usermanager_ext + +from . import usermanager_ext @usermanager_ext.route("/") diff --git a/lnbits/extensions/usermanager/views_api.py b/lnbits/extensions/usermanager/views_api.py index 288fc04d..557aa8b9 100644 --- a/lnbits/extensions/usermanager/views_api.py +++ b/lnbits/extensions/usermanager/views_api.py @@ -4,13 +4,12 @@ from http import HTTPStatus from lnbits.core.crud import get_user from lnbits.decorators import api_check_wallet_key, api_validate_post_request -from lnbits.extensions.usermanager import usermanager_ext +from . import usermanager_ext from .crud import ( create_usermanager_user, get_usermanager_user, get_usermanager_users, get_usermanager_wallet_transactions, - get_usermanager_wallet_balances, delete_usermanager_user, create_usermanager_wallet, get_usermanager_wallet, @@ -27,7 +26,7 @@ from lnbits.core import update_user_extension @api_check_wallet_key(key_type="invoice") async def api_usermanager_users(): user_id = g.wallet.user - return jsonify([user._asdict() for user in get_usermanager_users(user_id)]), HTTPStatus.OK + return jsonify([user._asdict() for user in await get_usermanager_users(user_id)]), HTTPStatus.OK @usermanager_ext.route("/api/v1/users", methods=["POST"]) @@ -40,17 +39,17 @@ async def api_usermanager_users(): } ) async def api_usermanager_users_create(): - user = create_usermanager_user(g.data["user_name"], g.data["wallet_name"], g.data["admin_id"]) + user = await create_usermanager_user(g.data["user_name"], g.data["wallet_name"], g.data["admin_id"]) return jsonify(user._asdict()), HTTPStatus.CREATED @usermanager_ext.route("/api/v1/users/", methods=["DELETE"]) @api_check_wallet_key(key_type="invoice") async def api_usermanager_users_delete(user_id): - user = get_usermanager_user(user_id) + user = await get_usermanager_user(user_id) if not user: return jsonify({"message": "User does not exist."}), HTTPStatus.NOT_FOUND - delete_usermanager_user(user_id) + await delete_usermanager_user(user_id) return "", HTTPStatus.NO_CONTENT @@ -67,7 +66,7 @@ async def api_usermanager_users_delete(user_id): } ) async def api_usermanager_activate_extension(): - user = get_user(g.data["userid"]) + user = await get_user(g.data["userid"]) if not user: return jsonify({"message": "no such user"}), HTTPStatus.NOT_FOUND update_user_extension(user_id=g.data["userid"], extension=g.data["extension"], active=g.data["active"]) @@ -81,7 +80,7 @@ async def api_usermanager_activate_extension(): @api_check_wallet_key(key_type="invoice") async def api_usermanager_wallets(): user_id = g.wallet.user - return jsonify([wallet._asdict() for wallet in get_usermanager_wallets(user_id)]), HTTPStatus.OK + return jsonify([wallet._asdict() for wallet in await get_usermanager_wallets(user_id)]), HTTPStatus.OK @usermanager_ext.route("/api/v1/wallets", methods=["POST"]) @@ -94,31 +93,28 @@ async def api_usermanager_wallets(): } ) async def api_usermanager_wallets_create(): - user = create_usermanager_wallet(g.data["user_id"], g.data["wallet_name"], g.data["admin_id"]) + user = await create_usermanager_wallet(g.data["user_id"], g.data["wallet_name"], g.data["admin_id"]) return jsonify(user._asdict()), HTTPStatus.CREATED @usermanager_ext.route("/api/v1/wallets", methods=["GET"]) @api_check_wallet_key(key_type="invoice") async def api_usermanager_wallet_transactions(wallet_id): - - return jsonify(get_usermanager_wallet_transactions(wallet_id)), HTTPStatus.OK + return jsonify(await get_usermanager_wallet_transactions(wallet_id)), HTTPStatus.OK @usermanager_ext.route("/api/v1/wallets/", methods=["GET"]) @api_check_wallet_key(key_type="invoice") -async def api_usermanager_wallet_balances(user_id): - return jsonify(get_usermanager_wallet_balances(user_id)), HTTPStatus.OK +async def api_usermanager_wallet(user_id): + return jsonify(await get_usermanager_wallets(user_id)), HTTPStatus.OK @usermanager_ext.route("/api/v1/wallets/", methods=["DELETE"]) @api_check_wallet_key(key_type="invoice") async def api_usermanager_wallets_delete(wallet_id): - wallet = get_usermanager_wallet(wallet_id) - print(wallet.id) + wallet = await get_usermanager_wallet(wallet_id) if not wallet: return jsonify({"message": "Wallet does not exist."}), HTTPStatus.NOT_FOUND - delete_usermanager_wallet(wallet_id, wallet.user) - + await delete_usermanager_wallet(wallet_id, wallet.user) return "", HTTPStatus.NO_CONTENT diff --git a/lnbits/extensions/withdraw/__init__.py b/lnbits/extensions/withdraw/__init__.py index ce5970ee..38e8f1b3 100644 --- a/lnbits/extensions/withdraw/__init__.py +++ b/lnbits/extensions/withdraw/__init__.py @@ -1,4 +1,7 @@ from quart import Blueprint +from lnbits.db import Database + +db = Database("ext_withdraw") withdraw_ext: Blueprint = Blueprint("withdraw", __name__, static_folder="static", template_folder="templates") diff --git a/lnbits/extensions/withdraw/crud.py b/lnbits/extensions/withdraw/crud.py index 85c884e8..78fd7f56 100644 --- a/lnbits/extensions/withdraw/crud.py +++ b/lnbits/extensions/withdraw/crud.py @@ -1,12 +1,12 @@ from datetime import datetime from typing import List, Optional, Union -from lnbits.db import open_ext_db from lnbits.helpers import urlsafe_short_hash +from . import db from .models import WithdrawLink -def create_withdraw_link( +async def create_withdraw_link( *, wallet_id: str, title: str, @@ -17,49 +17,47 @@ def create_withdraw_link( is_unique: bool, usescsv: str, ) -> WithdrawLink: - - with open_ext_db("withdraw") as db: - - link_id = urlsafe_short_hash() - db.execute( - """ - INSERT INTO withdraw_link ( - id, - wallet, - title, - min_withdrawable, - max_withdrawable, - uses, - wait_time, - is_unique, - unique_hash, - k1, - open_time, - usescsv - ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - link_id, - wallet_id, - title, - min_withdrawable, - max_withdrawable, - uses, - wait_time, - int(is_unique), - urlsafe_short_hash(), - urlsafe_short_hash(), - int(datetime.now().timestamp()) + wait_time, - usescsv, - ), + link_id = urlsafe_short_hash() + await db.execute( + """ + INSERT INTO withdraw_link ( + id, + wallet, + title, + min_withdrawable, + max_withdrawable, + uses, + wait_time, + is_unique, + unique_hash, + k1, + open_time, + usescsv ) - return get_withdraw_link(link_id, 0) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + link_id, + wallet_id, + title, + min_withdrawable, + max_withdrawable, + uses, + wait_time, + int(is_unique), + urlsafe_short_hash(), + urlsafe_short_hash(), + int(datetime.now().timestamp()) + wait_time, + usescsv, + ), + ) + link = await get_withdraw_link(link_id, 0) + assert link, "Newly created link couldn't be retrieved" + return link -def get_withdraw_link(link_id: str, num=0) -> Optional[WithdrawLink]: - with open_ext_db("withdraw") as db: - row = db.fetchone("SELECT * FROM withdraw_link WHERE id = ?", (link_id,)) +async def get_withdraw_link(link_id: str, num=0) -> Optional[WithdrawLink]: + row = await db.fetchone("SELECT * FROM withdraw_link WHERE id = ?", (link_id,)) link = [] for item in row: link.append(item) @@ -67,39 +65,34 @@ def get_withdraw_link(link_id: str, num=0) -> Optional[WithdrawLink]: return WithdrawLink._make(link) -def get_withdraw_link_by_hash(unique_hash: str, num=0) -> Optional[WithdrawLink]: - with open_ext_db("withdraw") as db: - row = db.fetchone("SELECT * FROM withdraw_link WHERE unique_hash = ?", (unique_hash,)) - link = [] - for item in row: - link.append(item) +async def get_withdraw_link_by_hash(unique_hash: str, num=0) -> Optional[WithdrawLink]: + row = await db.fetchone("SELECT * FROM withdraw_link WHERE unique_hash = ?", (unique_hash,)) + link = [] + for item in row: + link.append(item) link.append(num) return WithdrawLink._make(link) -def get_withdraw_links(wallet_ids: Union[str, List[str]]) -> List[WithdrawLink]: +async def get_withdraw_links(wallet_ids: Union[str, List[str]]) -> List[WithdrawLink]: if isinstance(wallet_ids, str): wallet_ids = [wallet_ids] - with open_ext_db("withdraw") as db: - q = ",".join(["?"] * len(wallet_ids)) - rows = db.fetchall(f"SELECT * FROM withdraw_link WHERE wallet IN ({q})", (*wallet_ids,)) + q = ",".join(["?"] * len(wallet_ids)) + rows = await db.fetchall(f"SELECT * FROM withdraw_link WHERE wallet IN ({q})", (*wallet_ids,)) return [WithdrawLink.from_row(row) for row in rows] -def update_withdraw_link(link_id: str, **kwargs) -> Optional[WithdrawLink]: +async def update_withdraw_link(link_id: str, **kwargs) -> Optional[WithdrawLink]: q = ", ".join([f"{field[0]} = ?" for field in kwargs.items()]) - with open_ext_db("withdraw") as db: - db.execute(f"UPDATE withdraw_link SET {q} WHERE id = ?", (*kwargs.values(), link_id)) - row = db.fetchone("SELECT * FROM withdraw_link WHERE id = ?", (link_id,)) - + await db.execute(f"UPDATE withdraw_link SET {q} WHERE id = ?", (*kwargs.values(), link_id)) + row = await db.fetchone("SELECT * FROM withdraw_link WHERE id = ?", (link_id,)) return WithdrawLink.from_row(row) if row else None -def delete_withdraw_link(link_id: str) -> None: - with open_ext_db("withdraw") as db: - db.execute("DELETE FROM withdraw_link WHERE id = ?", (link_id,)) +async def delete_withdraw_link(link_id: str) -> None: + await db.execute("DELETE FROM withdraw_link WHERE id = ?", (link_id,)) def chunks(lst, n): diff --git a/lnbits/extensions/withdraw/migrations.py b/lnbits/extensions/withdraw/migrations.py index 84924ce2..4af24f8f 100644 --- a/lnbits/extensions/withdraw/migrations.py +++ b/lnbits/extensions/withdraw/migrations.py @@ -1,8 +1,8 @@ -def m001_initial(db): +async def m001_initial(db): """ Creates an improved withdraw table and migrates the existing data. """ - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS withdraw_links ( id TEXT PRIMARY KEY, @@ -23,11 +23,11 @@ def m001_initial(db): ) -def m002_change_withdraw_table(db): +async def m002_change_withdraw_table(db): """ Creates an improved withdraw table and migrates the existing data. """ - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS withdraw_link ( id TEXT PRIMARY KEY, @@ -46,10 +46,10 @@ def m002_change_withdraw_table(db): ); """ ) - db.execute("CREATE INDEX IF NOT EXISTS wallet_idx ON withdraw_link (wallet)") - db.execute("CREATE UNIQUE INDEX IF NOT EXISTS unique_hash_idx ON withdraw_link (unique_hash)") + await db.execute("CREATE INDEX IF NOT EXISTS wallet_idx ON withdraw_link (wallet)") + await db.execute("CREATE UNIQUE INDEX IF NOT EXISTS unique_hash_idx ON withdraw_link (unique_hash)") - for row in [list(row) for row in db.fetchall("SELECT * FROM withdraw_links")]: + for row in [list(row) for row in await db.fetchall("SELECT * FROM withdraw_links")]: usescsv = "" for i in range(row[5]): @@ -58,7 +58,7 @@ def m002_change_withdraw_table(db): else: usescsv += "," + str(1) usescsv = usescsv[1:] - db.execute( + await db.execute( """ INSERT INTO withdraw_link ( id, @@ -93,4 +93,4 @@ def m002_change_withdraw_table(db): usescsv, ), ) - db.execute("DROP TABLE withdraw_links") + await db.execute("DROP TABLE withdraw_links") diff --git a/lnbits/extensions/withdraw/models.py b/lnbits/extensions/withdraw/models.py index 3e55fc36..7e80a789 100644 --- a/lnbits/extensions/withdraw/models.py +++ b/lnbits/extensions/withdraw/models.py @@ -1,5 +1,5 @@ from quart import url_for -from lnurl import Lnurl, LnurlWithdrawResponse, encode as lnurl_encode +from lnurl import Lnurl, LnurlWithdrawResponse, encode as lnurl_encode # type: ignore from sqlite3 import Row from typing import NamedTuple import shortuuid # type: ignore diff --git a/lnbits/extensions/withdraw/views.py b/lnbits/extensions/withdraw/views.py index f3c48992..574cb3ad 100644 --- a/lnbits/extensions/withdraw/views.py +++ b/lnbits/extensions/withdraw/views.py @@ -3,7 +3,7 @@ from http import HTTPStatus from lnbits.decorators import check_user_exists, validate_uuids -from lnbits.extensions.withdraw import withdraw_ext +from . import withdraw_ext from .crud import get_withdraw_link, chunks @@ -16,19 +16,19 @@ async def index(): @withdraw_ext.route("/") async def display(link_id): - link = get_withdraw_link(link_id, 0) or abort(HTTPStatus.NOT_FOUND, "Withdraw link does not exist.") + link = await get_withdraw_link(link_id, 0) or abort(HTTPStatus.NOT_FOUND, "Withdraw link does not exist.") return await render_template("withdraw/display.html", link=link, unique=True) @withdraw_ext.route("/print/") async def print_qr(link_id): - link = get_withdraw_link(link_id) or abort(HTTPStatus.NOT_FOUND, "Withdraw link does not exist.") + link = await get_withdraw_link(link_id) or abort(HTTPStatus.NOT_FOUND, "Withdraw link does not exist.") if link.uses == 0: return await render_template("withdraw/print_qr.html", link=link, unique=False) links = [] count = 0 for x in link.usescsv.split(","): - linkk = get_withdraw_link(link_id, count) or abort(HTTPStatus.NOT_FOUND, "Withdraw link does not exist.") + linkk = await get_withdraw_link(link_id, count) or abort(HTTPStatus.NOT_FOUND, "Withdraw link does not exist.") links.append(str(linkk.lnurl)) count = count + 1 page_link = list(chunks(links, 2)) diff --git a/lnbits/extensions/withdraw/views_api.py b/lnbits/extensions/withdraw/views_api.py index c37a318a..8ed9bcac 100644 --- a/lnbits/extensions/withdraw/views_api.py +++ b/lnbits/extensions/withdraw/views_api.py @@ -1,14 +1,14 @@ from datetime import datetime from quart import g, jsonify, request from http import HTTPStatus -from lnurl.exceptions import InvalidUrl as LnurlInvalidUrl +from lnurl.exceptions import InvalidUrl as LnurlInvalidUrl # type: ignore import shortuuid # type: ignore from lnbits.core.crud import get_user from lnbits.core.services import pay_invoice from lnbits.decorators import api_check_wallet_key, api_validate_post_request -from lnbits.extensions.withdraw import withdraw_ext +from . import withdraw_ext from .crud import ( create_withdraw_link, get_withdraw_link, @@ -25,10 +25,18 @@ async def api_links(): wallet_ids = [g.wallet.id] if "all_wallets" in request.args: - wallet_ids = get_user(g.wallet.user).wallet_ids + wallet_ids = (await get_user(g.wallet.user)).wallet_ids try: return ( - jsonify([{**link._asdict(), **{"lnurl": link.lnurl}} for link in get_withdraw_links(wallet_ids)]), + jsonify( + [ + { + **link._asdict(), + **{"lnurl": link.lnurl}, + } + for link in await get_withdraw_links(wallet_ids) + ] + ), HTTPStatus.OK, ) except LnurlInvalidUrl: @@ -41,7 +49,7 @@ async def api_links(): @withdraw_ext.route("/api/v1/links/", methods=["GET"]) @api_check_wallet_key("invoice") async def api_link_retrieve(link_id): - link = get_withdraw_link(link_id, 0) + link = await get_withdraw_link(link_id, 0) if not link: return jsonify({"message": "Withdraw link does not exist."}), HTTPStatus.NOT_FOUND @@ -82,21 +90,22 @@ async def api_link_create_or_update(link_id=None): usescsv = usescsv[1:] if link_id: - link = get_withdraw_link(link_id, 0) + link = await get_withdraw_link(link_id, 0) if not link: return jsonify({"message": "Withdraw link does not exist."}), HTTPStatus.NOT_FOUND if link.wallet != g.wallet.id: return jsonify({"message": "Not your withdraw link."}), HTTPStatus.FORBIDDEN - link = update_withdraw_link(link_id, **g.data, usescsv=usescsv, used=0) + link = await update_withdraw_link(link_id, **g.data, usescsv=usescsv, used=0) else: - link = create_withdraw_link(wallet_id=g.wallet.id, **g.data, usescsv=usescsv) + link = await create_withdraw_link(wallet_id=g.wallet.id, **g.data, usescsv=usescsv) + return jsonify({**link._asdict(), **{"lnurl": link.lnurl}}), HTTPStatus.OK if link_id else HTTPStatus.CREATED @withdraw_ext.route("/api/v1/links/", methods=["DELETE"]) @api_check_wallet_key("admin") async def api_link_delete(link_id): - link = get_withdraw_link(link_id) + link = await get_withdraw_link(link_id) if not link: return jsonify({"message": "Withdraw link does not exist."}), HTTPStatus.NOT_FOUND @@ -104,7 +113,7 @@ async def api_link_delete(link_id): if link.wallet != g.wallet.id: return jsonify({"message": "Not your withdraw link."}), HTTPStatus.FORBIDDEN - delete_withdraw_link(link_id) + await delete_withdraw_link(link_id) return "", HTTPStatus.NO_CONTENT @@ -114,7 +123,7 @@ async def api_link_delete(link_id): @withdraw_ext.route("/api/v1/lnurl/", methods=["GET"]) async def api_lnurl_response(unique_hash): - link = get_withdraw_link_by_hash(unique_hash) + link = await get_withdraw_link_by_hash(unique_hash) if not link: return jsonify({"status": "ERROR", "reason": "LNURL-withdraw not found."}), HTTPStatus.OK @@ -125,7 +134,7 @@ async def api_lnurl_response(unique_hash): for x in range(1, link.uses - link.used): usescsv += "," + str(1) usescsv = usescsv[1:] - link = update_withdraw_link(link.id, used=link.used + 1, usescsv=usescsv) + link = await update_withdraw_link(link.id, used=link.used + 1, usescsv=usescsv) return jsonify(link.lnurl_response.dict()), HTTPStatus.OK @@ -135,7 +144,7 @@ async def api_lnurl_response(unique_hash): @withdraw_ext.route("/api/v1/lnurl//", methods=["GET"]) async def api_lnurl_multi_response(unique_hash, id_unique_hash): - link = get_withdraw_link_by_hash(unique_hash) + link = await get_withdraw_link_by_hash(unique_hash) if not link: return jsonify({"status": "ERROR", "reason": "LNURL-withdraw not found."}), HTTPStatus.OK @@ -156,13 +165,13 @@ async def api_lnurl_multi_response(unique_hash, id_unique_hash): return jsonify({"status": "ERROR", "reason": "LNURL-withdraw not found."}), HTTPStatus.OK usescsv = usescsv[1:] - link = update_withdraw_link(link.id, usescsv=usescsv) + link = await update_withdraw_link(link.id, usescsv=usescsv) return jsonify(link.lnurl_response.dict()), HTTPStatus.OK @withdraw_ext.route("/api/v1/lnurl/cb/", methods=["GET"]) async def api_lnurl_callback(unique_hash): - link = get_withdraw_link_by_hash(unique_hash) + link = await get_withdraw_link_by_hash(unique_hash) k1 = request.args.get("k1", type=str) payment_request = request.args.get("pr", type=str) now = int(datetime.now().timestamp()) @@ -180,7 +189,7 @@ async def api_lnurl_callback(unique_hash): return jsonify({"status": "ERROR", "reason": f"Wait {link.open_time - now} seconds."}), HTTPStatus.OK try: - pay_invoice( + await pay_invoice( wallet_id=link.wallet, payment_request=payment_request, max_sat=link.max_withdrawable, @@ -189,12 +198,10 @@ async def api_lnurl_callback(unique_hash): changes = {"open_time": link.wait_time + now, "used": link.used + 1} - update_withdraw_link(link.id, **changes) + await update_withdraw_link(link.id, **changes) except ValueError as e: return jsonify({"status": "ERROR", "reason": str(e)}), HTTPStatus.OK except PermissionError: return jsonify({"status": "ERROR", "reason": "Withdraw link is empty."}), HTTPStatus.OK - except Exception as e: - return jsonify({"status": "ERROR", "reason": str(e)}), HTTPStatus.OK return jsonify({"status": "OK"}), HTTPStatus.OK diff --git a/lnbits/tasks.py b/lnbits/tasks.py index 285b7126..3acb2a07 100644 --- a/lnbits/tasks.py +++ b/lnbits/tasks.py @@ -1,13 +1,9 @@ import trio # type: ignore from http import HTTPStatus from typing import Optional, List, Callable -from quart import Request, g from quart_trio import QuartTrio -from werkzeug.datastructures import Headers -from lnbits.db import open_db from lnbits.settings import WALLET - from lnbits.core.crud import get_standalone_payment main_app: Optional[QuartTrio] = None @@ -37,28 +33,6 @@ async def send_push_promise(a, b) -> None: pass -async def run_on_pseudo_request(func: Callable, *args): - fk = Request( - "GET", - "http", - "/background/pseudo", - b"", - Headers([("host", "lnbits.background")]), - "", - "1.1", - send_push_promise=send_push_promise, - ) - assert main_app - - async def run(): - async with main_app.request_context(fk): - with open_db() as g.db: # type: ignore - await func(*args) - - async with trio.open_nursery() as nursery: - nursery.start_soon(run) - - invoice_listeners: List[trio.MemorySendChannel] = [] @@ -81,18 +55,20 @@ internal_invoice_paid, internal_invoice_received = trio.open_memory_channel(0) async def internal_invoice_listener(): - async for checking_id in internal_invoice_received: - await run_on_pseudo_request(invoice_callback_dispatcher, checking_id) + async with trio.open_nursery() as nursery: + async for checking_id in internal_invoice_received: + nursery.start_soon(invoice_callback_dispatcher, checking_id) async def invoice_listener(): - async for checking_id in WALLET.paid_invoices_stream(): - await run_on_pseudo_request(invoice_callback_dispatcher, checking_id) + async with trio.open_nursery() as nursery: + async for checking_id in WALLET.paid_invoices_stream(): + nursery.start_soon(invoice_callback_dispatcher, checking_id) async def invoice_callback_dispatcher(checking_id: str): - payment = get_standalone_payment(checking_id) + payment = await get_standalone_payment(checking_id) if payment and payment.is_in: - payment.set_pending(False) + await payment.set_pending(False) for send_chan in invoice_listeners: await send_chan.send(payment)