diff --git a/Makefile b/Makefile index 0f4f83bb..25add4fc 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ .PHONY: test -all: format check requirements.txt +all: format check format: prettier isort black diff --git a/lnbits/bolt11.py b/lnbits/bolt11.py index a918cfba..84bf61e7 100644 --- a/lnbits/bolt11.py +++ b/lnbits/bolt11.py @@ -66,11 +66,12 @@ def decode(pr: str) -> Invoice: invoice.amount_msat = _unshorten_amount(amountstr) # pull out date - invoice.date = data.read(35).uint + date_bin = data.read(35) + invoice.date = date_bin.uint # type: ignore while data.pos != data.len: tag, tagdata, data = _pull_tagged(data) - data_length = len(tagdata) / 5 + data_length = len(tagdata or []) / 5 if tag == "d": invoice.description = _trim_to_bytes(tagdata).decode() @@ -79,7 +80,7 @@ def decode(pr: str) -> Invoice: elif tag == "p" and data_length == 52: invoice.payment_hash = _trim_to_bytes(tagdata).hex() elif tag == "x": - invoice.expiry = tagdata.uint + invoice.expiry = tagdata.uint # type: ignore elif tag == "n": invoice.payee = _trim_to_bytes(tagdata).hex() # this won't work in most cases, we must extract the payee @@ -90,11 +91,11 @@ def decode(pr: str) -> Invoice: s = bitstring.ConstBitStream(tagdata) while s.pos + 264 + 64 + 32 + 32 + 16 < s.len: route = Route( - pubkey=s.read(264).tobytes().hex(), - short_channel_id=_readable_scid(s.read(64).intbe), - base_fee_msat=s.read(32).intbe, - ppm_fee=s.read(32).intbe, - cltv=s.read(16).intbe, + pubkey=s.read(264).tobytes().hex(), # type: ignore + short_channel_id=_readable_scid(s.read(64).intbe), # type: ignore + base_fee_msat=s.read(32).intbe, # type: ignore + ppm_fee=s.read(32).intbe, # type: ignore + cltv=s.read(16).intbe, # type: ignore ) invoice.route_hints.append(route) @@ -202,7 +203,8 @@ def lnencode(addr, privkey): ) data += tagged("r", route) elif k == "f": - data += encode_fallback(v, addr.currency) + # NOTE: there was an error fallback here that's now removed + continue elif k == "d": data += tagged_bytes("d", v.encode()) elif k == "x": @@ -244,7 +246,13 @@ def lnencode(addr, privkey): class LnAddr: def __init__( - self, paymenthash=None, amount=None, currency="bc", tags=None, date=None + self, + paymenthash=None, + amount=None, + currency="bc", + tags=None, + date=None, + fallback=None, ): self.date = int(time.time()) if not date else int(date) self.tags = [] if not tags else tags @@ -252,11 +260,13 @@ class LnAddr: self.paymenthash = paymenthash self.signature = None self.pubkey = None + self.fallback = fallback self.currency = currency self.amount = amount def __str__(self): - pubkey = bytes.hex(self.pubkey.serialize()).decode() + assert self.pubkey, "LnAddr, pubkey must be set" + pubkey = bytes.hex(self.pubkey.serialize()) tags = ", ".join([f"{k}={v}" for k, v in self.tags]) return f"LnAddr[{pubkey}, amount={self.amount}{self.currency} tags=[{tags}]]" @@ -266,6 +276,7 @@ def shorten_amount(amount): # Convert to pico initially amount = int(amount * 10**12) units = ["p", "n", "u", "m", ""] + unit = "" for unit in units: if amount % 1000 == 0: amount //= 1000 @@ -304,14 +315,6 @@ def _pull_tagged(stream): return (CHARSET[tag], stream.read(length * 5), stream) -def is_p2pkh(currency, prefix): - return prefix == base58_prefix_map[currency][0] - - -def is_p2sh(currency, prefix): - return prefix == base58_prefix_map[currency][1] - - # Tagged field containing BitArray def tagged(char, l): # Tagged fields need to be zero-padded to 5 bits. @@ -359,5 +362,5 @@ def bitarray_to_u5(barr): ret = [] s = bitstring.ConstBitStream(barr) while s.pos != s.len: - ret.append(s.read(5).uint) + ret.append(s.read(5).uint) # type: ignore return ret diff --git a/lnbits/commands.py b/lnbits/commands.py index f2252fee..e0f58605 100644 --- a/lnbits/commands.py +++ b/lnbits/commands.py @@ -41,6 +41,7 @@ async def migrate_databases(): """Creates the necessary databases if they don't exist already; or migrates them.""" async with core_db.connect() as conn: + exists = False if conn.type == SQLITE: exists = await conn.fetchone( "SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'" diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py index 1b807b78..b234b5d4 100644 --- a/lnbits/core/crud.py +++ b/lnbits/core/crud.py @@ -206,7 +206,7 @@ async def create_wallet( async def update_wallet( wallet_id: str, new_name: str, conn: Optional[Connection] = None ) -> Optional[Wallet]: - return await (conn or db).execute( + await (conn or db).execute( """ UPDATE wallets SET name = ? @@ -214,6 +214,9 @@ async def update_wallet( """, (new_name, wallet_id), ) + wallet = await get_wallet(wallet_id=wallet_id, conn=conn) + assert wallet, "updated created wallet couldn't be retrieved" + return wallet async def delete_wallet( @@ -393,7 +396,7 @@ async def get_payments( clause.append("checking_id NOT LIKE 'internal_%'") if not filters: - filters = Filters() + filters = Filters(limit=None, offset=None) rows = await (conn or db).fetchall( f""" @@ -712,15 +715,19 @@ async def update_admin_settings(data: EditableSettings): await db.execute("UPDATE settings SET editable_settings = ?", (json.dumps(data),)) -async def update_super_user(super_user: str): +async def update_super_user(super_user: str) -> SuperSettings: await db.execute("UPDATE settings SET super_user = ?", (super_user,)) - return await get_super_settings() + settings = await get_super_settings() + assert settings, "updated super_user settings could not be retrieved" + return settings async def create_admin_settings(super_user: str, new_settings: dict): sql = "INSERT INTO settings (super_user, editable_settings) VALUES (?, ?)" await db.execute(sql, (super_user, json.dumps(new_settings))) - return await get_super_settings() + settings = await get_super_settings() + assert settings, "created admin settings could not be retrieved" + return settings # db versions diff --git a/lnbits/core/services.py b/lnbits/core/services.py index 5322de77..9b699571 100644 --- a/lnbits/core/services.py +++ b/lnbits/core/services.py @@ -1,7 +1,7 @@ import asyncio import json from io import BytesIO -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, TypedDict from urllib.parse import parse_qs, urlparse import httpx @@ -17,6 +17,7 @@ from lnbits.helpers import url_for from lnbits.settings import ( FAKE_WALLET, EditableSettings, + SuperSettings, get_wallet_class, readonly_variables, send_admin_user_to_saas, @@ -43,11 +44,6 @@ from .crud import ( ) from .models import Payment -try: - from typing import TypedDict -except ImportError: # pragma: nocover - from typing_extensions import TypedDict - class PaymentFailure(Exception): pass @@ -188,7 +184,7 @@ async def pay_invoice( # do the balance check wallet = await get_wallet(wallet_id, conn=conn) - assert wallet + assert wallet, "Wallet for balancecheck could not be fetched" if wallet.balance_msat < 0: logger.debug("balance is too low, deleting temporary payment") if not internal_checking_id and wallet.balance_msat > -fee_reserve_msat: @@ -336,19 +332,19 @@ async def perform_lnurlauth( return b - def encode_strict_der(r_int, s_int, order): + def encode_strict_der(r: int, s: int, order: int): # if s > order/2 verification will fail sometimes # so we must fix it here (see https://github.com/indutny/elliptic/blob/e71b2d9359c5fe9437fbf46f1f05096de447de57/lib/elliptic/ec/index.js#L146-L147) - if s_int > order // 2: - s_int = order - s_int + if s > order // 2: + s = order - s # now we do the strict DER encoding copied from # https://github.com/KiriKiri/bip66 (without any checks) - r = int_to_bytes_suitable_der(r_int) - s = int_to_bytes_suitable_der(s_int) + r_temp = int_to_bytes_suitable_der(r) + s_temp = int_to_bytes_suitable_der(s) - r_len = len(r) - s_len = len(s) + r_len = len(r_temp) + s_len = len(s_temp) sign_len = 6 + r_len + s_len signature = BytesIO() @@ -356,16 +352,17 @@ async def perform_lnurlauth( signature.write((sign_len - 2).to_bytes(1, "big", signed=False)) signature.write(0x02.to_bytes(1, "big", signed=False)) signature.write(r_len.to_bytes(1, "big", signed=False)) - signature.write(r) + signature.write(r_temp) signature.write(0x02.to_bytes(1, "big", signed=False)) signature.write(s_len.to_bytes(1, "big", signed=False)) - signature.write(s) + signature.write(s_temp) return signature.getvalue() sig = key.sign_digest_deterministic(k1, sigencode=encode_strict_der) async with httpx.AsyncClient() as client: + assert key.verifying_key, "LNURLauth verifying_key does not exist" r = await client.get( callback, params={ @@ -469,7 +466,7 @@ def update_cached_settings(sets_dict: dict): setattr(settings, "super_user", sets_dict["super_user"]) -async def init_admin_settings(super_user: str = None): +async def init_admin_settings(super_user: Optional[str] = None) -> SuperSettings: account = None if super_user: account = await get_account(super_user) diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index 408bef59..e0555b37 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -411,8 +411,7 @@ async def subscribe_wallet_invoices(request: Request, wallet: Wallet): typ, data = await send_queue.get() if data: jdata = json.dumps(dict(data.dict(), pending=False)) - - yield dict(data=jdata, event=typ) + yield dict(data=jdata, event=typ) except asyncio.CancelledError: logger.debug(f"removing listener for wallet {uid}") api_invoice_listeners.pop(uid) @@ -431,11 +430,12 @@ async def api_payments_sse( ) +# TODO: refactor this route into a public and admin one @core_app.get("/api/v1/payments/{payment_hash}") async def api_payment(payment_hash, X_Api_Key: Optional[str] = Header(None)): # We use X_Api_Key here because we want this call to work with and without keys # If a valid key is given, we also return the field "details", otherwise not - wallet = await get_wallet_for_key(X_Api_Key) if type(X_Api_Key) == str else None + wallet = await get_wallet_for_key(X_Api_Key) if type(X_Api_Key) == str else None # type: ignore # we have to specify the wallet id here, because postgres and sqlite return internal payments in different order # and get_standalone_payment otherwise just fetches the first one, causing unpredictable results @@ -505,6 +505,7 @@ async def api_lnurlscan(code: str, wallet: WalletTypeInfo = Depends(get_key_type params.update(callback=url) # with k1 already in it lnurlauth_key = wallet.wallet.lnurlauth_key(domain) + assert lnurlauth_key.verifying_key params.update(pubkey=lnurlauth_key.verifying_key.to_string("compressed").hex()) else: async with httpx.AsyncClient() as client: @@ -693,7 +694,7 @@ async def api_auditor(): if not error_message: delta = node_balance - total_balance else: - node_balance, delta = None, None + node_balance, delta = 0, 0 return { "node_balance_msats": int(node_balance), @@ -745,6 +746,7 @@ async def api_install_extension( raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail="Release not found" ) + ext_info = InstallableExtension( id=data.ext_id, name=data.ext_id, installed_release=release, icon=release.icon ) @@ -824,8 +826,10 @@ async def api_uninstall_extension(ext_id: str, user: User = Depends(check_admin) ) -@core_app.get("/api/v1/extension/{ext_id}/releases") -async def get_extension_releases(ext_id: str, user: User = Depends(check_admin)): +@core_app.get( + "/api/v1/extension/{ext_id}/releases", dependencies=[Depends(check_admin)] +) +async def get_extension_releases(ext_id: str): try: extension_releases: List[ ExtensionRelease diff --git a/lnbits/core/views/public_api.py b/lnbits/core/views/public_api.py index 303929fe..934fc617 100644 --- a/lnbits/core/views/public_api.py +++ b/lnbits/core/views/public_api.py @@ -40,19 +40,18 @@ async def api_public_payment_longpolling(payment_hash): response = None - async def payment_info_receiver(cancel_scope): - async for payment in payment_queue.get(): + async def payment_info_receiver(): + for payment in await payment_queue.get(): if payment.payment_hash == payment_hash: nonlocal response response = {"status": "paid"} - cancel_scope.cancel() async def timeouter(cancel_scope): await asyncio.sleep(45) cancel_scope.cancel() - asyncio.create_task(payment_info_receiver()) - asyncio.create_task(timeouter()) + cancel_scope = asyncio.create_task(payment_info_receiver()) + asyncio.create_task(timeouter(cancel_scope)) if response: return response diff --git a/lnbits/db.py b/lnbits/db.py index 3af11e36..4a6673a7 100644 --- a/lnbits/db.py +++ b/lnbits/db.py @@ -131,7 +131,7 @@ class Database(Compat): else: self.type = POSTGRES - import psycopg2 + from psycopg2.extensions import DECIMAL, new_type, register_type def _parse_timestamp(value, _): if value is None: @@ -141,15 +141,15 @@ class Database(Compat): f = "%Y-%m-%d %H:%M:%S" return time.mktime(datetime.datetime.strptime(value, f).timetuple()) - psycopg2.extensions.register_type( - psycopg2.extensions.new_type( - psycopg2.extensions.DECIMAL.values, + register_type( + new_type( + DECIMAL.values, "DEC2FLOAT", lambda value, curs: float(value) if value is not None else None, ) ) - psycopg2.extensions.register_type( - psycopg2.extensions.new_type( + register_type( + new_type( (1082, 1083, 1266), "DATE2INT", lambda value, curs: time.mktime(value.timetuple()) @@ -158,11 +158,7 @@ class Database(Compat): ) ) - psycopg2.extensions.register_type( - psycopg2.extensions.new_type( - (1184, 1114), "TIMESTAMP2INT", _parse_timestamp - ) - ) + register_type(new_type((1184, 1114), "TIMESTAMP2INT", _parse_timestamp)) else: if os.path.isdir(settings.lnbits_data_folder): self.path = os.path.join( @@ -189,7 +185,7 @@ class Database(Compat): async def connect(self): await self.lock.acquire() try: - async with self.engine.connect() as conn: + async with self.engine.connect() as conn: # type: ignore async with conn.begin() as txn: wconn = Connection(conn, txn, self.type, self.name, self.schema) diff --git a/lnbits/decorators.py b/lnbits/decorators.py index bd1c0520..3ced881a 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -1,14 +1,12 @@ from http import HTTPStatus from typing import Optional, Type -from fastapi import Security, status -from fastapi.exceptions import HTTPException +from fastapi import HTTPException, Request, Security, status from fastapi.openapi.models import APIKey, APIKeyIn -from fastapi.security.api_key import APIKeyHeader, APIKeyQuery +from fastapi.security import APIKeyHeader, APIKeyQuery from fastapi.security.base import SecurityBase from pydantic import BaseModel from pydantic.types import UUID4 -from starlette.requests import Request from lnbits.core.crud import get_user, get_wallet_for_key from lnbits.core.models import User, Wallet @@ -17,9 +15,13 @@ from lnbits.requestvars import g from lnbits.settings import settings +# TODO: fix type ignores class KeyChecker(SecurityBase): def __init__( - self, scheme_name: str = None, auto_error: bool = True, api_key: str = None + self, + scheme_name: Optional[str] = None, + auto_error: bool = True, + api_key: Optional[str] = None, ): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error @@ -27,13 +29,13 @@ class KeyChecker(SecurityBase): self._api_key = api_key if api_key: key = APIKey( - **{"in": APIKeyIn.query}, + **{"in": APIKeyIn.query}, # type: ignore name="X-API-KEY", description="Wallet API Key - QUERY", ) else: key = APIKey( - **{"in": APIKeyIn.header}, + **{"in": APIKeyIn.header}, # type: ignore name="X-API-KEY", description="Wallet API Key - HEADER", ) @@ -73,7 +75,10 @@ class WalletInvoiceKeyChecker(KeyChecker): """ def __init__( - self, scheme_name: str = None, auto_error: bool = True, api_key: str = None + self, + scheme_name: Optional[str] = None, + auto_error: bool = True, + api_key: Optional[str] = None, ): super().__init__(scheme_name, auto_error, api_key) self._key_type = "invoice" @@ -89,7 +94,10 @@ class WalletAdminKeyChecker(KeyChecker): """ def __init__( - self, scheme_name: str = None, auto_error: bool = True, api_key: str = None + self, + scheme_name: Optional[str] = None, + auto_error: bool = True, + api_key: Optional[str] = None, ): super().__init__(scheme_name, auto_error, api_key) self._key_type = "admin" diff --git a/lnbits/extension_manager.py b/lnbits/extension_manager.py index 81faff10..092f2f97 100644 --- a/lnbits/extension_manager.py +++ b/lnbits/extension_manager.py @@ -3,20 +3,145 @@ import json import os import shutil import sys -import urllib.request import zipfile from http import HTTPStatus from pathlib import Path from typing import Any, List, NamedTuple, Optional, Tuple +from urllib import request import httpx -from fastapi.exceptions import HTTPException +from fastapi import HTTPException from loguru import logger from pydantic import BaseModel from lnbits.settings import settings +class ExplicitRelease(BaseModel): + id: str + name: str + version: str + archive: str + hash: str + dependencies: List[str] = [] + icon: Optional[str] + short_description: Optional[str] + html_url: Optional[str] + details: Optional[str] + info_notification: Optional[str] + critical_notification: Optional[str] + + +class GitHubRelease(BaseModel): + id: str + organisation: str + repository: str + + +class Manifest(BaseModel): + featured: List[str] = [] + extensions: List["ExplicitRelease"] = [] + repos: List["GitHubRelease"] = [] + + +class GitHubRepoRelease(BaseModel): + name: str + tag_name: str + zipball_url: str + html_url: str + + +class GitHubRepo(BaseModel): + stargazers_count: str + html_url: str + default_branch: str + + +class ExtensionConfig(BaseModel): + name: str + short_description: str + tile: str = "" + + +def download_url(url, save_path): + with request.urlopen(url) as dl_file: + with open(save_path, "wb") as out_file: + out_file.write(dl_file.read()) + + +def file_hash(filename): + h = hashlib.sha256() + b = bytearray(128 * 1024) + mv = memoryview(b) + with open(filename, "rb", buffering=0) as f: + while n := f.readinto(mv): + h.update(mv[:n]) + return h.hexdigest() + + +async def fetch_github_repo_info( + org: str, repository: str +) -> Tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]: + repo_url = f"https://api.github.com/repos/{org}/{repository}" + error_msg = "Cannot fetch extension repo" + repo = await gihub_api_get(repo_url, error_msg) + github_repo = GitHubRepo.parse_obj(repo) + + lates_release_url = ( + f"https://api.github.com/repos/{org}/{repository}/releases/latest" + ) + error_msg = "Cannot fetch extension releases" + latest_release: Any = await gihub_api_get(lates_release_url, error_msg) + + config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json" + error_msg = "Cannot fetch config for extension" + config = await gihub_api_get(config_url, error_msg) + + return ( + github_repo, + GitHubRepoRelease.parse_obj(latest_release), + ExtensionConfig.parse_obj(config), + ) + + +async def fetch_manifest(url) -> Manifest: + error_msg = "Cannot fetch extensions manifest" + manifest = await gihub_api_get(url, error_msg) + return Manifest.parse_obj(manifest) + + +async def fetch_github_releases(org: str, repo: str) -> List[GitHubRepoRelease]: + releases_url = f"https://api.github.com/repos/{org}/{repo}/releases" + error_msg = "Cannot fetch extension releases" + releases = await gihub_api_get(releases_url, error_msg) + return [GitHubRepoRelease.parse_obj(r) for r in releases] + + +async def gihub_api_get(url: str, error_msg: Optional[str]) -> Any: + async with httpx.AsyncClient() as client: + headers = ( + {"Authorization": "Bearer " + settings.lnbits_ext_github_token} + if settings.lnbits_ext_github_token + else None + ) + resp = await client.get( + url, + headers=headers, + ) + if resp.status_code != 200: + logger.warning(f"{error_msg} ({url}): {resp.text}") + resp.raise_for_status() + return resp.json() + + +def icon_to_github_url(source_repo: str, path: Optional[str]) -> str: + if not path: + return "" + _, _, *rest = path.split("/") + tail = "/".join(rest) + return f"https://github.com/{source_repo}/raw/main/{tail}" + + class Extension(NamedTuple): code: str is_valid: bool @@ -97,12 +222,12 @@ class ExtensionRelease(BaseModel): version: str archive: str source_repo: str - is_github_release = False - hash: Optional[str] - html_url: Optional[str] - description: Optional[str] + is_github_release: bool = False + hash: Optional[str] = None + html_url: Optional[str] = None + description: Optional[str] = None details_html: Optional[str] = None - icon: Optional[str] + icon: Optional[str] = None @classmethod def from_github_release( @@ -132,52 +257,6 @@ class ExtensionRelease(BaseModel): return [] -class ExplicitRelease(BaseModel): - id: str - name: str - version: str - archive: str - hash: str - dependencies: List[str] = [] - icon: Optional[str] - short_description: Optional[str] - html_url: Optional[str] - details: Optional[str] - info_notification: Optional[str] - critical_notification: Optional[str] - - -class GitHubRelease(BaseModel): - id: str - organisation: str - repository: str - - -class Manifest(BaseModel): - featured: List[str] = [] - extensions: List["ExplicitRelease"] = [] - repos: List["GitHubRelease"] = [] - - -class GitHubRepoRelease(BaseModel): - name: str - tag_name: str - zipball_url: str - html_url: str - - -class GitHubRepo(BaseModel): - stargazers_count: str - html_url: str - default_branch: str - - -class ExtensionConfig(BaseModel): - name: str - short_description: str - tile: str = "" - - class InstallableExtension(BaseModel): id: str name: str @@ -187,8 +266,9 @@ class InstallableExtension(BaseModel): is_admin_only: bool = False stars: int = 0 featured = False - latest_release: Optional[ExtensionRelease] - installed_release: Optional[ExtensionRelease] + latest_release: Optional[ExtensionRelease] = None + installed_release: Optional[ExtensionRelease] = None + archive: Optional[str] = None @property def hash(self) -> str: @@ -234,6 +314,7 @@ class InstallableExtension(BaseModel): if ext_zip_file.is_file(): os.remove(ext_zip_file) try: + assert self.installed_release, "installed_release is none." download_url(self.installed_release.archive, ext_zip_file) except Exception as ex: logger.warning(ex) @@ -334,8 +415,7 @@ class InstallableExtension(BaseModel): id=github_release.id, name=config.name, short_description=config.short_description, - version="0", - stars=repo.stargazers_count, + stars=int(repo.stargazers_count), icon=icon_to_github_url( f"{github_release.organisation}/{github_release.repository}", config.tile, @@ -354,7 +434,6 @@ class InstallableExtension(BaseModel): id=e.id, name=e.name, archive=e.archive, - hash=e.hash, short_description=e.short_description, icon=e.icon, dependencies=e.dependencies, @@ -453,82 +532,3 @@ def get_valid_extensions() -> List[Extension]: return [ extension for extension in ExtensionManager().extensions if extension.is_valid ] - - -def download_url(url, save_path): - with urllib.request.urlopen(url) as dl_file: - with open(save_path, "wb") as out_file: - out_file.write(dl_file.read()) - - -def file_hash(filename): - h = hashlib.sha256() - b = bytearray(128 * 1024) - mv = memoryview(b) - with open(filename, "rb", buffering=0) as f: - while n := f.readinto(mv): - h.update(mv[:n]) - return h.hexdigest() - - -def icon_to_github_url(source_repo: str, path: Optional[str]) -> str: - if not path: - return "" - _, _, *rest = path.split("/") - tail = "/".join(rest) - return f"https://github.com/{source_repo}/raw/main/{tail}" - - -async def fetch_github_repo_info( - org: str, repository: str -) -> Tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]: - repo_url = f"https://api.github.com/repos/{org}/{repository}" - error_msg = "Cannot fetch extension repo" - repo = await gihub_api_get(repo_url, error_msg) - github_repo = GitHubRepo.parse_obj(repo) - - lates_release_url = ( - f"https://api.github.com/repos/{org}/{repository}/releases/latest" - ) - error_msg = "Cannot fetch extension releases" - latest_release: Any = await gihub_api_get(lates_release_url, error_msg) - - config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json" - error_msg = "Cannot fetch config for extension" - config = await gihub_api_get(config_url, error_msg) - - return ( - github_repo, - GitHubRepoRelease.parse_obj(latest_release), - ExtensionConfig.parse_obj(config), - ) - - -async def fetch_manifest(url) -> Manifest: - error_msg = "Cannot fetch extensions manifest" - manifest = await gihub_api_get(url, error_msg) - return Manifest.parse_obj(manifest) - - -async def fetch_github_releases(org: str, repo: str) -> List[GitHubRepoRelease]: - releases_url = f"https://api.github.com/repos/{org}/{repo}/releases" - error_msg = "Cannot fetch extension releases" - releases = await gihub_api_get(releases_url, error_msg) - return [GitHubRepoRelease.parse_obj(r) for r in releases] - - -async def gihub_api_get(url: str, error_msg: Optional[str]) -> Any: - async with httpx.AsyncClient() as client: - headers = ( - {"Authorization": f"Bearer {settings.lnbits_ext_github_token}"} - if settings.lnbits_ext_github_token - else None - ) - resp = await client.get( - url, - headers=headers, - ) - if resp.status_code != 200: - logger.warning(f"{error_msg} ({url}): {resp.text}") - resp.raise_for_status() - return resp.json() diff --git a/lnbits/jinja2_templating.py b/lnbits/jinja2_templating.py index 5dfe36c3..f4539442 100644 --- a/lnbits/jinja2_templating.py +++ b/lnbits/jinja2_templating.py @@ -1,25 +1,18 @@ -# Borrowed from the excellent accent-starlette -# https://github.com/accent-starlette/starlette-core/blob/master/starlette_core/templating.py - import typing -from starlette import templating +from jinja2 import BaseLoader, Environment, pass_context from starlette.datastructures import QueryParams from starlette.requests import Request - -try: - import jinja2 -except ImportError: # pragma: nocover - jinja2 = None # type: ignore +from starlette.templating import Jinja2Templates as SuperJinja2Templates -class Jinja2Templates(templating.Jinja2Templates): - def __init__(self, loader: jinja2.BaseLoader) -> None: # pylint: disable=W0231 - assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates" +class Jinja2Templates(SuperJinja2Templates): + def __init__(self, loader: BaseLoader) -> None: + super().__init__("") self.env = self.get_environment(loader) - def get_environment(self, loader: "jinja2.BaseLoader") -> "jinja2.Environment": - @jinja2.pass_context + def get_environment(self, loader: BaseLoader) -> Environment: + @pass_context def url_for(context: dict, name: str, **path_params: typing.Any) -> str: request: Request = context["request"] return request.app.url_path_for(name, **path_params) @@ -29,7 +22,7 @@ class Jinja2Templates(templating.Jinja2Templates): values.update(new) return QueryParams(**values) - env = jinja2.Environment(loader=loader, autoescape=True) + env = Environment(loader=loader, autoescape=True) env.globals["url_for"] = url_for env.globals["url_params_update"] = url_params_update return env diff --git a/lnbits/middleware.py b/lnbits/middleware.py index daac03bf..93a5671c 100644 --- a/lnbits/middleware.py +++ b/lnbits/middleware.py @@ -26,6 +26,7 @@ class InstalledExtensionMiddleware: else: _, path_name = path_elements path_type = None + rest = [] # block path for all users if the extension is disabled if path_name in settings.lnbits_deactivated_extensions: @@ -88,7 +89,7 @@ class ExtensionsRedirectMiddleware: if "from_path" not in redirect: return False header_filters = ( - redirect["header_filters"] if "header_filters" in redirect else [] + redirect["header_filters"] if "header_filters" in redirect else {} ) return self._has_common_path(redirect["from_path"], path) and self._has_headers( header_filters, req_headers diff --git a/lnbits/settings.py b/lnbits/settings.py index 62751b4f..c5bdc43d 100644 --- a/lnbits/settings.py +++ b/lnbits/settings.py @@ -24,6 +24,7 @@ def list_parse_fallback(v): class LNbitsSettings(BaseSettings): + @classmethod def validate(cls, val): if type(val) == str: val = val.split(",") if val else [] @@ -103,6 +104,8 @@ class FakeWalletFundingSource(LNbitsSettings): class LNbitsFundingSource(LNbitsSettings): lnbits_endpoint: str = Field(default="https://legend.lnbits.com") lnbits_key: Optional[str] = Field(default=None) + lnbits_admin_key: Optional[str] = Field(default=None) + lnbits_invoice_key: Optional[str] = Field(default=None) class ClicheFundingSource(LNbitsSettings): @@ -145,11 +148,14 @@ class LnPayFundingSource(LNbitsSettings): lnpay_api_endpoint: Optional[str] = Field(default=None) lnpay_api_key: Optional[str] = Field(default=None) lnpay_wallet_key: Optional[str] = Field(default=None) + lnpay_admin_key: Optional[str] = Field(default=None) class OpenNodeFundingSource(LNbitsSettings): opennode_api_endpoint: Optional[str] = Field(default=None) opennode_key: Optional[str] = Field(default=None) + opennode_admin_key: Optional[str] = Field(default=None) + opennode_invoice_key: Optional[str] = Field(default=None) class SparkFundingSource(LNbitsSettings): @@ -208,8 +214,9 @@ class EditableSettings( "lnbits_admin_extensions", pre=True, ) + @classmethod def validate_editable_settings(cls, val): - return super().validate(cls, val) + return super().validate(val) @classmethod def from_dict(cls, d: dict): @@ -281,8 +288,9 @@ class ReadOnlySettings( "lnbits_allowed_funding_sources", pre=True, ) + @classmethod def validate_readonly_settings(cls, val): - return super().validate(cls, val) + return super().validate(val) @classmethod def readonly_fields(cls): diff --git a/lnbits/tasks.py b/lnbits/tasks.py index f4d3bf7b..6c482256 100644 --- a/lnbits/tasks.py +++ b/lnbits/tasks.py @@ -3,7 +3,7 @@ import time import traceback import uuid from http import HTTPStatus -from typing import Dict +from typing import Dict, Optional from fastapi.exceptions import HTTPException from loguru import logger @@ -42,7 +42,7 @@ class SseListenersDict(dict): A dict of sse listeners. """ - def __init__(self, name: str = None): + def __init__(self, name: Optional[str] = None): self.name = name or f"sse_listener_{str(uuid.uuid4())[:8]}" def __setitem__(self, key, value): @@ -65,7 +65,7 @@ class SseListenersDict(dict): invoice_listeners: Dict[str, asyncio.Queue] = SseListenersDict("invoice_listeners") -def register_invoice_listener(send_chan: asyncio.Queue, name: str = None): +def register_invoice_listener(send_chan: asyncio.Queue, name: Optional[str] = None): """ A method intended for extensions (and core/tasks.py) to call when they want to be notified about new invoice payments incoming. Will emit all incoming payments. @@ -164,7 +164,7 @@ async def check_pending_payments(): async def perform_balance_checks(): while True: for bc in await get_balance_checks(): - redeem_lnurl_withdraw(bc.wallet, bc.url) + await redeem_lnurl_withdraw(bc.wallet, bc.url) await asyncio.sleep(60 * 60 * 6) # every 6 hours diff --git a/lnbits/wallets/__init__.py b/lnbits/wallets/__init__.py index c16fb42e..b1b8c25c 100644 --- a/lnbits/wallets/__init__.py +++ b/lnbits/wallets/__init__.py @@ -1,8 +1,6 @@ # flake8: noqa: F401 - - from .cliche import ClicheWallet -from .cln import CoreLightningWallet # legacy .env support +from .cln import CoreLightningWallet from .cln import CoreLightningWallet as CLightningWallet from .eclair import EclairWallet from .fake import FakeWallet diff --git a/lnbits/wallets/cliche.py b/lnbits/wallets/cliche.py index 211ba4f3..cb11f520 100644 --- a/lnbits/wallets/cliche.py +++ b/lnbits/wallets/cliche.py @@ -22,6 +22,8 @@ class ClicheWallet(Wallet): def __init__(self): self.endpoint = settings.cliche_endpoint + if not self.endpoint: + raise Exception("cannot initialize cliche") async def status(self) -> StatusResponse: try: @@ -36,7 +38,7 @@ class ClicheWallet(Wallet): data = json.loads(r) except: return StatusResponse( - f"Failed to connect to {self.endpoint}, got: '{r.text[:200]}...'", 0 + f"Failed to connect to {self.endpoint}, got: '{r[:200]}...'", 0 ) return StatusResponse(None, data["result"]["wallets"][0]["balance"]) @@ -89,6 +91,13 @@ class ClicheWallet(Wallet): async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse: ws = create_connection(self.endpoint) ws.send(f"pay-invoice --invoice {bolt11}") + checking_id, fee_msat, preimage, error_message, payment_ok = ( + None, + None, + None, + None, + None, + ) for _ in range(2): r = ws.recv() data = json.loads(r) @@ -151,9 +160,9 @@ class ClicheWallet(Wallet): async def paid_invoices_stream(self) -> AsyncGenerator[str, None]: while True: try: - ws = await create_connection(self.endpoint) + ws = create_connection(self.endpoint) while True: - r = await ws.recv() + r = ws.recv() data = json.loads(r) print(data) try: diff --git a/lnbits/wallets/eclair.py b/lnbits/wallets/eclair.py index a45123b1..073f486e 100644 --- a/lnbits/wallets/eclair.py +++ b/lnbits/wallets/eclair.py @@ -7,10 +7,7 @@ from typing import AsyncGenerator, Dict, Optional import httpx from loguru import logger - -# TODO: https://github.com/lnbits/lnbits/issues/764 -# mypy https://github.com/aaugustin/websockets/issues/940 -from websockets import connect # type: ignore +from websockets.client import connect from lnbits.settings import settings @@ -34,11 +31,13 @@ class UnknownError(Exception): class EclairWallet(Wallet): def __init__(self): url = settings.eclair_url - self.url = url[:-1] if url.endswith("/") else url + passw = settings.eclair_pass + if not url or not passw: + raise Exception("cannot initialize eclair") + self.url = url[:-1] if url.endswith("/") else url self.ws_url = f"ws://{urllib.parse.urlsplit(self.url).netloc}/ws" - passw = settings.eclair_pass encodedAuth = base64.b64encode(f":{passw}".encode()) auth = str(encodedAuth, "utf-8") self.auth = {"Authorization": f"Basic {auth}"} @@ -71,7 +70,11 @@ class EclairWallet(Wallet): **kwargs, ) -> InvoiceResponse: - data: Dict = {"amountMsat": amount * 1000} + data: Dict = { + "amountMsat": amount * 1000, + "description_hash": b"", + "description": memo, + } if kwargs.get("expiry"): data["expireIn"] = kwargs["expiry"] @@ -79,8 +82,6 @@ class EclairWallet(Wallet): data["descriptionHash"] = description_hash.hex() elif unhashed_description: data["descriptionHash"] = hashlib.sha256(unhashed_description).hexdigest() - else: - data["description"] = memo or "" async with httpx.AsyncClient() as client: r = await client.post( @@ -149,6 +150,7 @@ class EclairWallet(Wallet): } data = r.json()[-1] + fee_msat = 0 if data["status"]["type"] == "sent": fee_msat = -data["status"]["feesPaid"] preimage = data["status"]["paymentPreimage"] @@ -223,10 +225,10 @@ class EclairWallet(Wallet): ) as ws: while True: message = await ws.recv() - message = json.loads(message) + message_json = json.loads(message) - if message and message["type"] == "payment-received": - yield message["paymentHash"] + if message_json and message_json["type"] == "payment-received": + yield message_json["paymentHash"] except Exception as exc: logger.error( diff --git a/lnbits/wallets/fake.py b/lnbits/wallets/fake.py index 93e7d3f0..0927cd8b 100644 --- a/lnbits/wallets/fake.py +++ b/lnbits/wallets/fake.py @@ -48,16 +48,15 @@ class FakeWallet(Wallet): "amount": amount, "currency": "bc", "privkey": self.privkey, - "memo": None, - "description_hash": None, + "memo": memo, + "description_hash": b"", "description": "", "fallback": None, - "expires": None, + "expires": kwargs.get("expiry"), + "timestamp": datetime.now().timestamp(), "route": None, + "tags_set": [], } - data["expires"] = kwargs.get("expiry") - data["amount"] = amount * 1000 - data["timestamp"] = datetime.now().timestamp() if description_hash: data["tags_set"] = ["h"] data["description_hash"] = description_hash @@ -69,7 +68,7 @@ class FakeWallet(Wallet): data["memo"] = memo data["description"] = memo randomHash = ( - data["privkey"][:6] + self.privkey[:6] + hashlib.sha256(str(random.getrandbits(256)).encode()).hexdigest()[6:] ) data["paymenthash"] = randomHash @@ -78,12 +77,10 @@ class FakeWallet(Wallet): return InvoiceResponse(True, checking_id, payment_request) - async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse: + async def pay_invoice(self, bolt11: str, _: int) -> PaymentResponse: invoice = decode(bolt11) - if ( - hasattr(invoice, "checking_id") - and invoice.checking_id[:6] == self.privkey[:6] # type: ignore - ): + + if invoice.payment_hash[:6] == self.privkey[:6]: await self.queue.put(invoice) return PaymentResponse(True, invoice.payment_hash, 0) else: @@ -91,10 +88,10 @@ class FakeWallet(Wallet): ok=False, error_message="Only internal invoices can be used!" ) - async def get_invoice_status(self, checking_id: str) -> PaymentStatus: + async def get_invoice_status(self, _: str) -> PaymentStatus: return PaymentStatus(None) - async def get_payment_status(self, checking_id: str) -> PaymentStatus: + async def get_payment_status(self, _: str) -> PaymentStatus: return PaymentStatus(None) async def paid_invoices_stream(self) -> AsyncGenerator[str, None]: diff --git a/lnbits/wallets/lnbits.py b/lnbits/wallets/lnbits.py index 74c9efcc..902711d6 100644 --- a/lnbits/wallets/lnbits.py +++ b/lnbits/wallets/lnbits.py @@ -21,12 +21,13 @@ class LNbitsWallet(Wallet): def __init__(self): self.endpoint = settings.lnbits_endpoint - key = ( settings.lnbits_key or settings.lnbits_admin_key or settings.lnbits_invoice_key ) + if not self.endpoint or not key: + raise Exception("cannot initialize lnbits wallet") self.key = {"X-Api-Key": key} async def status(self) -> StatusResponse: @@ -60,7 +61,7 @@ class LNbitsWallet(Wallet): unhashed_description: Optional[bytes] = None, **kwargs, ) -> InvoiceResponse: - data: Dict = {"out": False, "amount": amount} + data: Dict = {"out": False, "amount": amount, "memo": memo or ""} if kwargs.get("expiry"): data["expiry"] = kwargs["expiry"] if description_hash: @@ -68,8 +69,6 @@ class LNbitsWallet(Wallet): if unhashed_description: data["unhashed_description"] = unhashed_description.hex() - data["memo"] = memo or "" - async with httpx.AsyncClient() as client: r = await client.post( url=f"{self.endpoint}/api/v1/payments", headers=self.key, json=data diff --git a/lnbits/wallets/lndgrpc.py b/lnbits/wallets/lndgrpc.py index 4173e79e..5cb90a7c 100644 --- a/lnbits/wallets/lndgrpc.py +++ b/lnbits/wallets/lndgrpc.py @@ -105,9 +105,6 @@ class LndWallet(Wallet): ) endpoint = settings.lnd_grpc_endpoint - self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint - self.port = int(settings.lnd_grpc_port) - self.cert_path = settings.lnd_grpc_cert or settings.lnd_cert macaroon = ( settings.lnd_grpc_macaroon @@ -122,8 +119,17 @@ class LndWallet(Wallet): macaroon = AESCipher(description="macaroon decryption").decrypt( encrypted_macaroon ) - self.macaroon = load_macaroon(macaroon) + cert_path = settings.lnd_grpc_cert or settings.lnd_cert + if not endpoint or not macaroon or not cert_path or not settings.lnd_grpc_port: + raise Exception("cannot initialize lndrest") + + self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint + self.port = int(settings.lnd_grpc_port) + self.cert_path = settings.lnd_grpc_cert or settings.lnd_cert + + self.macaroon = load_macaroon(macaroon) + self.cert_path = cert_path cert = open(self.cert_path, "rb").read() creds = grpc.ssl_channel_credentials(cert) auth_creds = grpc.metadata_call_credentials(self.metadata_callback) @@ -140,8 +146,6 @@ class LndWallet(Wallet): async def status(self) -> StatusResponse: try: resp = await self.rpc.ChannelBalance(ln.ChannelBalanceRequest()) - except RpcError as exc: - return StatusResponse(str(exc._details), 0) except Exception as exc: return StatusResponse(str(exc), 0) @@ -155,20 +159,23 @@ class LndWallet(Wallet): unhashed_description: Optional[bytes] = None, **kwargs, ) -> InvoiceResponse: - params: Dict = {"value": amount, "private": True} + data: Dict = { + "description_hash": b"", + "value": amount, + "private": True, + "memo": memo or "", + } if kwargs.get("expiry"): - params["expiry"] = kwargs["expiry"] + data["expiry"] = kwargs["expiry"] if description_hash: - params["description_hash"] = description_hash + data["description_hash"] = description_hash elif unhashed_description: - params["description_hash"] = hashlib.sha256( + data["description_hash"] = hashlib.sha256( unhashed_description ).digest() # as bytes directly - else: - params["memo"] = memo or "" try: - req = ln.Invoice(**params) + req = ln.Invoice(**data) resp = await self.rpc.AddInvoice(req) except Exception as exc: error_message = str(exc) @@ -188,8 +195,6 @@ class LndWallet(Wallet): ) try: resp = await self.routerpc.SendPaymentV2(req).read() - except RpcError as exc: - return PaymentResponse(False, None, None, None, exc._details) except Exception as exc: return PaymentResponse(False, None, None, None, str(exc)) diff --git a/lnbits/wallets/lndrest.py b/lnbits/wallets/lndrest.py index c0e344ae..303b4879 100644 --- a/lnbits/wallets/lndrest.py +++ b/lnbits/wallets/lndrest.py @@ -24,11 +24,6 @@ class LndRestWallet(Wallet): def __init__(self): endpoint = settings.lnd_rest_endpoint - endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint - endpoint = ( - f"https://{endpoint}" if not endpoint.startswith("http") else endpoint - ) - self.endpoint = endpoint macaroon = ( settings.lnd_rest_macaroon @@ -43,6 +38,15 @@ class LndRestWallet(Wallet): macaroon = AESCipher(description="macaroon decryption").decrypt( encrypted_macaroon ) + + if not endpoint or not macaroon or not settings.lnd_rest_cert: + raise Exception("cannot initialize lndrest") + + endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint + endpoint = ( + f"https://{endpoint}" if not endpoint.startswith("http") else endpoint + ) + self.endpoint = endpoint self.macaroon = load_macaroon(macaroon) self.auth = {"Grpc-Metadata-macaroon": self.macaroon} @@ -74,7 +78,7 @@ class LndRestWallet(Wallet): unhashed_description: Optional[bytes] = None, **kwargs, ) -> InvoiceResponse: - data: Dict = {"value": amount, "private": True} + data: Dict = {"value": amount, "private": True, "memo": memo or ""} if kwargs.get("expiry"): data["expiry"] = kwargs["expiry"] if description_hash: @@ -85,8 +89,6 @@ class LndRestWallet(Wallet): data["description_hash"] = base64.b64encode( hashlib.sha256(unhashed_description).digest() ).decode("ascii") - else: - data["memo"] = memo or "" async with httpx.AsyncClient(verify=self.cert) as client: r = await client.post( diff --git a/lnbits/wallets/lnpay.py b/lnbits/wallets/lnpay.py index ccc5254c..f05e4432 100644 --- a/lnbits/wallets/lnpay.py +++ b/lnbits/wallets/lnpay.py @@ -5,7 +5,7 @@ from http import HTTPStatus from typing import AsyncGenerator, Dict, Optional import httpx -from fastapi.exceptions import HTTPException +from fastapi import HTTPException from loguru import logger from lnbits.settings import settings @@ -24,8 +24,13 @@ class LNPayWallet(Wallet): def __init__(self): endpoint = settings.lnpay_api_endpoint + wallet_key = settings.lnpay_wallet_key or settings.lnpay_admin_key + + if not endpoint or not wallet_key or not settings.lnpay_api_key: + raise Exception("cannot initialize lnpay") + + self.wallet_key = wallet_key self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint - self.wallet_key = settings.lnpay_wallet_key or settings.lnpay_admin_key self.auth = {"X-Api-Key": settings.lnpay_api_key} async def status(self) -> StatusResponse: @@ -134,7 +139,9 @@ class LNPayWallet(Wallet): yield value async def webhook_listener(self): - text: str = await request.get_data() + # TODO: request.get_data is undefined, was it something with Flask or quart? + # probably issue introduced when refactoring? + text: str = await request.get_data() # type: ignore try: data = json.loads(text) except json.decoder.JSONDecodeError: diff --git a/lnbits/wallets/lntips.py b/lnbits/wallets/lntips.py index 4551a207..be8159b9 100644 --- a/lnbits/wallets/lntips.py +++ b/lnbits/wallets/lntips.py @@ -21,13 +21,14 @@ from .base import ( class LnTipsWallet(Wallet): def __init__(self): endpoint = settings.lntips_api_endpoint - self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint - key = ( settings.lntips_api_key or settings.lntips_admin_key or settings.lntips_invoice_key ) + if not endpoint or not key: + raise Exception("cannot initialize lntxbod") + self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint self.auth = {"Authorization": f"Basic {key}"} async def status(self) -> StatusResponse: @@ -55,13 +56,11 @@ class LnTipsWallet(Wallet): unhashed_description: Optional[bytes] = None, **kwargs, ) -> InvoiceResponse: - data: Dict = {"amount": amount} + data: Dict = {"amount": amount, "description_hash": "", "memo": memo or ""} if description_hash: data["description_hash"] = description_hash.hex() elif unhashed_description: data["description_hash"] = hashlib.sha256(unhashed_description).hexdigest() - else: - data["memo"] = memo or "" async with httpx.AsyncClient() as client: r = await client.post( diff --git a/lnbits/wallets/opennode.py b/lnbits/wallets/opennode.py index 89c7f1d5..08b234ee 100644 --- a/lnbits/wallets/opennode.py +++ b/lnbits/wallets/opennode.py @@ -4,7 +4,7 @@ from http import HTTPStatus from typing import AsyncGenerator, Optional import httpx -from fastapi.exceptions import HTTPException +from fastapi import HTTPException from loguru import logger from lnbits.settings import settings @@ -24,13 +24,15 @@ class OpenNodeWallet(Wallet): def __init__(self): endpoint = settings.opennode_api_endpoint - self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint - key = ( settings.opennode_key or settings.opennode_admin_key or settings.opennode_invoice_key ) + if not endpoint or not key: + raise Exception("cannot initialize opennode") + + self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint self.auth = {"Authorization": key} async def status(self) -> StatusResponse: @@ -140,7 +142,9 @@ class OpenNodeWallet(Wallet): yield value async def webhook_listener(self): - data = await request.form + # TODO: request.form is undefined, was it something with Flask or quart? + # probably issue introduced when refactoring? + data = await request.form # type: ignore if "status" not in data or data["status"] != "paid": raise HTTPException(status_code=HTTPStatus.NO_CONTENT) diff --git a/lnbits/wallets/spark.py b/lnbits/wallets/spark.py index 66cfba36..8f41e372 100644 --- a/lnbits/wallets/spark.py +++ b/lnbits/wallets/spark.py @@ -28,6 +28,7 @@ class UnknownError(Exception): class SparkWallet(Wallet): def __init__(self): + assert settings.spark_url, "spark url does not exist" self.url = settings.spark_url.replace("/rpc", "") self.token = settings.spark_token @@ -46,6 +47,7 @@ class SparkWallet(Wallet): try: async with httpx.AsyncClient() as client: + assert self.token, "spark wallet token does not exist" r = await client.post( self.url + "/rpc", headers={"X-Access": self.token}, @@ -133,38 +135,49 @@ class SparkWallet(Wallet): bolt11=bolt11, maxfee=fee_limit_msat, ) + fee_msat = -int(r["msatoshi_sent"] - r["msatoshi"]) + preimage = r["payment_preimage"] + return PaymentResponse(True, r["payment_hash"], fee_msat, preimage, None) + except (SparkError, UnknownError) as exc: listpays = await self.listpays(bolt11) - if listpays: - pays = listpays["pays"] + if not listpays: + return PaymentResponse(False, None, None, None, str(exc)) - if len(pays) == 0: - return PaymentResponse(False, None, None, None, str(exc)) + pays = listpays["pays"] - pay = pays[0] - payment_hash = pay["payment_hash"] + if len(pays) == 0: + return PaymentResponse(False, None, None, None, str(exc)) - if len(pays) > 1: - raise SparkError( - f"listpays({payment_hash}) returned an unexpected response: {listpays}" - ) + pay = pays[0] + payment_hash = pay["payment_hash"] - if pay["status"] == "failed": - return PaymentResponse(False, None, None, None, str(exc)) - elif pay["status"] == "pending": - return PaymentResponse(None, payment_hash, None, None, None) - elif pay["status"] == "complete": - r = pay - r["payment_preimage"] = pay["preimage"] - r["msatoshi"] = int(pay["amount_msat"][0:-4]) - r["msatoshi_sent"] = int(pay["amount_sent_msat"][0:-4]) - # this may result in an error if it was paid previously - # our database won't allow the same payment_hash to be added twice - # this is good + if len(pays) > 1: + raise SparkError( + f"listpays({payment_hash}) returned an unexpected response: {listpays}" + ) - fee_msat = -int(r["msatoshi_sent"] - r["msatoshi"]) - preimage = r["payment_preimage"] - return PaymentResponse(True, r["payment_hash"], fee_msat, preimage, None) + if pay["status"] == "failed": + return PaymentResponse(False, None, None, None, str(exc)) + + if pay["status"] == "pending": + return PaymentResponse(None, payment_hash, None, None, None) + + if pay["status"] == "complete": + r = pay + r["payment_preimage"] = pay["preimage"] + r["msatoshi"] = int(pay["amount_msat"][0:-4]) + r["msatoshi_sent"] = int(pay["amount_sent_msat"][0:-4]) + # this may result in an error if it was paid previously + # our database won't allow the same payment_hash to be added twice + # this is good + fee_msat = -int(r["msatoshi_sent"] - r["msatoshi"]) + preimage = r["payment_preimage"] + return PaymentResponse( + True, r["payment_hash"], fee_msat, preimage, None + ) + else: + return PaymentResponse(False, None, None, None, str(exc)) async def get_invoice_status(self, checking_id: str) -> PaymentStatus: try: @@ -205,7 +218,7 @@ class SparkWallet(Wallet): - int(r["pays"][0]["amount_msat"][0:-4]) ) return PaymentStatus(True, fee_msat, r["pays"][0]["preimage"]) - elif status == "failed": + if status == "failed": return PaymentStatus(False) return PaymentStatus(None) raise KeyError("supplied an invalid checking_id") diff --git a/pyproject.toml b/pyproject.toml index 492e0c45..37692c6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,9 +69,6 @@ include = [ ] exclude = [ "lnbits/wallets/lnd_grpc_files", - "lnbits/wallets", - "lnbits/core", - "lnbits/*.py", "lnbits/extensions", ]