Merge pull request #1468 from lnbits/pyright3

introduce pyright + fix issues (supersedes #1444)
This commit is contained in:
Arc 2023-04-05 12:42:23 +01:00 committed by GitHub
commit 47df94178e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 383 additions and 334 deletions

View File

@ -1,6 +1,6 @@
.PHONY: test .PHONY: test
all: format check requirements.txt all: format check
format: prettier isort black format: prettier isort black

View File

@ -66,11 +66,12 @@ def decode(pr: str) -> Invoice:
invoice.amount_msat = _unshorten_amount(amountstr) invoice.amount_msat = _unshorten_amount(amountstr)
# pull out date # pull out date
invoice.date = data.read(35).uint date_bin = data.read(35)
invoice.date = date_bin.uint # type: ignore
while data.pos != data.len: while data.pos != data.len:
tag, tagdata, data = _pull_tagged(data) tag, tagdata, data = _pull_tagged(data)
data_length = len(tagdata) / 5 data_length = len(tagdata or []) / 5
if tag == "d": if tag == "d":
invoice.description = _trim_to_bytes(tagdata).decode() invoice.description = _trim_to_bytes(tagdata).decode()
@ -79,7 +80,7 @@ def decode(pr: str) -> Invoice:
elif tag == "p" and data_length == 52: elif tag == "p" and data_length == 52:
invoice.payment_hash = _trim_to_bytes(tagdata).hex() invoice.payment_hash = _trim_to_bytes(tagdata).hex()
elif tag == "x": elif tag == "x":
invoice.expiry = tagdata.uint invoice.expiry = tagdata.uint # type: ignore
elif tag == "n": elif tag == "n":
invoice.payee = _trim_to_bytes(tagdata).hex() invoice.payee = _trim_to_bytes(tagdata).hex()
# this won't work in most cases, we must extract the payee # this won't work in most cases, we must extract the payee
@ -90,11 +91,11 @@ def decode(pr: str) -> Invoice:
s = bitstring.ConstBitStream(tagdata) s = bitstring.ConstBitStream(tagdata)
while s.pos + 264 + 64 + 32 + 32 + 16 < s.len: while s.pos + 264 + 64 + 32 + 32 + 16 < s.len:
route = Route( route = Route(
pubkey=s.read(264).tobytes().hex(), pubkey=s.read(264).tobytes().hex(), # type: ignore
short_channel_id=_readable_scid(s.read(64).intbe), short_channel_id=_readable_scid(s.read(64).intbe), # type: ignore
base_fee_msat=s.read(32).intbe, base_fee_msat=s.read(32).intbe, # type: ignore
ppm_fee=s.read(32).intbe, ppm_fee=s.read(32).intbe, # type: ignore
cltv=s.read(16).intbe, cltv=s.read(16).intbe, # type: ignore
) )
invoice.route_hints.append(route) invoice.route_hints.append(route)
@ -202,7 +203,8 @@ def lnencode(addr, privkey):
) )
data += tagged("r", route) data += tagged("r", route)
elif k == "f": elif k == "f":
data += encode_fallback(v, addr.currency) # NOTE: there was an error fallback here that's now removed
continue
elif k == "d": elif k == "d":
data += tagged_bytes("d", v.encode()) data += tagged_bytes("d", v.encode())
elif k == "x": elif k == "x":
@ -244,7 +246,13 @@ def lnencode(addr, privkey):
class LnAddr: class LnAddr:
def __init__( def __init__(
self, paymenthash=None, amount=None, currency="bc", tags=None, date=None self,
paymenthash=None,
amount=None,
currency="bc",
tags=None,
date=None,
fallback=None,
): ):
self.date = int(time.time()) if not date else int(date) self.date = int(time.time()) if not date else int(date)
self.tags = [] if not tags else tags self.tags = [] if not tags else tags
@ -252,11 +260,13 @@ class LnAddr:
self.paymenthash = paymenthash self.paymenthash = paymenthash
self.signature = None self.signature = None
self.pubkey = None self.pubkey = None
self.fallback = fallback
self.currency = currency self.currency = currency
self.amount = amount self.amount = amount
def __str__(self): def __str__(self):
pubkey = bytes.hex(self.pubkey.serialize()).decode() assert self.pubkey, "LnAddr, pubkey must be set"
pubkey = bytes.hex(self.pubkey.serialize())
tags = ", ".join([f"{k}={v}" for k, v in self.tags]) tags = ", ".join([f"{k}={v}" for k, v in self.tags])
return f"LnAddr[{pubkey}, amount={self.amount}{self.currency} tags=[{tags}]]" return f"LnAddr[{pubkey}, amount={self.amount}{self.currency} tags=[{tags}]]"
@ -266,6 +276,7 @@ def shorten_amount(amount):
# Convert to pico initially # Convert to pico initially
amount = int(amount * 10**12) amount = int(amount * 10**12)
units = ["p", "n", "u", "m", ""] units = ["p", "n", "u", "m", ""]
unit = ""
for unit in units: for unit in units:
if amount % 1000 == 0: if amount % 1000 == 0:
amount //= 1000 amount //= 1000
@ -304,14 +315,6 @@ def _pull_tagged(stream):
return (CHARSET[tag], stream.read(length * 5), stream) return (CHARSET[tag], stream.read(length * 5), stream)
def is_p2pkh(currency, prefix):
return prefix == base58_prefix_map[currency][0]
def is_p2sh(currency, prefix):
return prefix == base58_prefix_map[currency][1]
# Tagged field containing BitArray # Tagged field containing BitArray
def tagged(char, l): def tagged(char, l):
# Tagged fields need to be zero-padded to 5 bits. # Tagged fields need to be zero-padded to 5 bits.
@ -359,5 +362,5 @@ def bitarray_to_u5(barr):
ret = [] ret = []
s = bitstring.ConstBitStream(barr) s = bitstring.ConstBitStream(barr)
while s.pos != s.len: while s.pos != s.len:
ret.append(s.read(5).uint) ret.append(s.read(5).uint) # type: ignore
return ret return ret

View File

@ -41,6 +41,7 @@ async def migrate_databases():
"""Creates the necessary databases if they don't exist already; or migrates them.""" """Creates the necessary databases if they don't exist already; or migrates them."""
async with core_db.connect() as conn: async with core_db.connect() as conn:
exists = False
if conn.type == SQLITE: if conn.type == SQLITE:
exists = await conn.fetchone( exists = await conn.fetchone(
"SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'" "SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'"

View File

@ -206,7 +206,7 @@ async def create_wallet(
async def update_wallet( async def update_wallet(
wallet_id: str, new_name: str, conn: Optional[Connection] = None wallet_id: str, new_name: str, conn: Optional[Connection] = None
) -> Optional[Wallet]: ) -> Optional[Wallet]:
return await (conn or db).execute( await (conn or db).execute(
""" """
UPDATE wallets SET UPDATE wallets SET
name = ? name = ?
@ -214,6 +214,9 @@ async def update_wallet(
""", """,
(new_name, wallet_id), (new_name, wallet_id),
) )
wallet = await get_wallet(wallet_id=wallet_id, conn=conn)
assert wallet, "updated created wallet couldn't be retrieved"
return wallet
async def delete_wallet( async def delete_wallet(
@ -393,7 +396,7 @@ async def get_payments(
clause.append("checking_id NOT LIKE 'internal_%'") clause.append("checking_id NOT LIKE 'internal_%'")
if not filters: if not filters:
filters = Filters() filters = Filters(limit=None, offset=None)
rows = await (conn or db).fetchall( rows = await (conn or db).fetchall(
f""" f"""
@ -712,15 +715,19 @@ async def update_admin_settings(data: EditableSettings):
await db.execute("UPDATE settings SET editable_settings = ?", (json.dumps(data),)) await db.execute("UPDATE settings SET editable_settings = ?", (json.dumps(data),))
async def update_super_user(super_user: str): async def update_super_user(super_user: str) -> SuperSettings:
await db.execute("UPDATE settings SET super_user = ?", (super_user,)) await db.execute("UPDATE settings SET super_user = ?", (super_user,))
return await get_super_settings() settings = await get_super_settings()
assert settings, "updated super_user settings could not be retrieved"
return settings
async def create_admin_settings(super_user: str, new_settings: dict): async def create_admin_settings(super_user: str, new_settings: dict):
sql = "INSERT INTO settings (super_user, editable_settings) VALUES (?, ?)" sql = "INSERT INTO settings (super_user, editable_settings) VALUES (?, ?)"
await db.execute(sql, (super_user, json.dumps(new_settings))) await db.execute(sql, (super_user, json.dumps(new_settings)))
return await get_super_settings() settings = await get_super_settings()
assert settings, "created admin settings could not be retrieved"
return settings
# db versions # db versions

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
import json import json
from io import BytesIO from io import BytesIO
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple, TypedDict
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
import httpx import httpx
@ -17,6 +17,7 @@ from lnbits.helpers import url_for
from lnbits.settings import ( from lnbits.settings import (
FAKE_WALLET, FAKE_WALLET,
EditableSettings, EditableSettings,
SuperSettings,
get_wallet_class, get_wallet_class,
readonly_variables, readonly_variables,
send_admin_user_to_saas, send_admin_user_to_saas,
@ -43,11 +44,6 @@ from .crud import (
) )
from .models import Payment from .models import Payment
try:
from typing import TypedDict
except ImportError: # pragma: nocover
from typing_extensions import TypedDict
class PaymentFailure(Exception): class PaymentFailure(Exception):
pass pass
@ -188,7 +184,7 @@ async def pay_invoice(
# do the balance check # do the balance check
wallet = await get_wallet(wallet_id, conn=conn) wallet = await get_wallet(wallet_id, conn=conn)
assert wallet assert wallet, "Wallet for balancecheck could not be fetched"
if wallet.balance_msat < 0: if wallet.balance_msat < 0:
logger.debug("balance is too low, deleting temporary payment") logger.debug("balance is too low, deleting temporary payment")
if not internal_checking_id and wallet.balance_msat > -fee_reserve_msat: if not internal_checking_id and wallet.balance_msat > -fee_reserve_msat:
@ -336,19 +332,19 @@ async def perform_lnurlauth(
return b return b
def encode_strict_der(r_int, s_int, order): def encode_strict_der(r: int, s: int, order: int):
# if s > order/2 verification will fail sometimes # if s > order/2 verification will fail sometimes
# so we must fix it here (see https://github.com/indutny/elliptic/blob/e71b2d9359c5fe9437fbf46f1f05096de447de57/lib/elliptic/ec/index.js#L146-L147) # so we must fix it here (see https://github.com/indutny/elliptic/blob/e71b2d9359c5fe9437fbf46f1f05096de447de57/lib/elliptic/ec/index.js#L146-L147)
if s_int > order // 2: if s > order // 2:
s_int = order - s_int s = order - s
# now we do the strict DER encoding copied from # now we do the strict DER encoding copied from
# https://github.com/KiriKiri/bip66 (without any checks) # https://github.com/KiriKiri/bip66 (without any checks)
r = int_to_bytes_suitable_der(r_int) r_temp = int_to_bytes_suitable_der(r)
s = int_to_bytes_suitable_der(s_int) s_temp = int_to_bytes_suitable_der(s)
r_len = len(r) r_len = len(r_temp)
s_len = len(s) s_len = len(s_temp)
sign_len = 6 + r_len + s_len sign_len = 6 + r_len + s_len
signature = BytesIO() signature = BytesIO()
@ -356,16 +352,17 @@ async def perform_lnurlauth(
signature.write((sign_len - 2).to_bytes(1, "big", signed=False)) signature.write((sign_len - 2).to_bytes(1, "big", signed=False))
signature.write(0x02.to_bytes(1, "big", signed=False)) signature.write(0x02.to_bytes(1, "big", signed=False))
signature.write(r_len.to_bytes(1, "big", signed=False)) signature.write(r_len.to_bytes(1, "big", signed=False))
signature.write(r) signature.write(r_temp)
signature.write(0x02.to_bytes(1, "big", signed=False)) signature.write(0x02.to_bytes(1, "big", signed=False))
signature.write(s_len.to_bytes(1, "big", signed=False)) signature.write(s_len.to_bytes(1, "big", signed=False))
signature.write(s) signature.write(s_temp)
return signature.getvalue() return signature.getvalue()
sig = key.sign_digest_deterministic(k1, sigencode=encode_strict_der) sig = key.sign_digest_deterministic(k1, sigencode=encode_strict_der)
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
assert key.verifying_key, "LNURLauth verifying_key does not exist"
r = await client.get( r = await client.get(
callback, callback,
params={ params={
@ -469,7 +466,7 @@ def update_cached_settings(sets_dict: dict):
setattr(settings, "super_user", sets_dict["super_user"]) setattr(settings, "super_user", sets_dict["super_user"])
async def init_admin_settings(super_user: str = None): async def init_admin_settings(super_user: Optional[str] = None) -> SuperSettings:
account = None account = None
if super_user: if super_user:
account = await get_account(super_user) account = await get_account(super_user)

View File

@ -411,8 +411,7 @@ async def subscribe_wallet_invoices(request: Request, wallet: Wallet):
typ, data = await send_queue.get() typ, data = await send_queue.get()
if data: if data:
jdata = json.dumps(dict(data.dict(), pending=False)) jdata = json.dumps(dict(data.dict(), pending=False))
yield dict(data=jdata, event=typ)
yield dict(data=jdata, event=typ)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.debug(f"removing listener for wallet {uid}") logger.debug(f"removing listener for wallet {uid}")
api_invoice_listeners.pop(uid) api_invoice_listeners.pop(uid)
@ -431,11 +430,12 @@ async def api_payments_sse(
) )
# TODO: refactor this route into a public and admin one
@core_app.get("/api/v1/payments/{payment_hash}") @core_app.get("/api/v1/payments/{payment_hash}")
async def api_payment(payment_hash, X_Api_Key: Optional[str] = Header(None)): async def api_payment(payment_hash, X_Api_Key: Optional[str] = Header(None)):
# We use X_Api_Key here because we want this call to work with and without keys # We use X_Api_Key here because we want this call to work with and without keys
# If a valid key is given, we also return the field "details", otherwise not # If a valid key is given, we also return the field "details", otherwise not
wallet = await get_wallet_for_key(X_Api_Key) if type(X_Api_Key) == str else None wallet = await get_wallet_for_key(X_Api_Key) if type(X_Api_Key) == str else None # type: ignore
# we have to specify the wallet id here, because postgres and sqlite return internal payments in different order # we have to specify the wallet id here, because postgres and sqlite return internal payments in different order
# and get_standalone_payment otherwise just fetches the first one, causing unpredictable results # and get_standalone_payment otherwise just fetches the first one, causing unpredictable results
@ -505,6 +505,7 @@ async def api_lnurlscan(code: str, wallet: WalletTypeInfo = Depends(get_key_type
params.update(callback=url) # with k1 already in it params.update(callback=url) # with k1 already in it
lnurlauth_key = wallet.wallet.lnurlauth_key(domain) lnurlauth_key = wallet.wallet.lnurlauth_key(domain)
assert lnurlauth_key.verifying_key
params.update(pubkey=lnurlauth_key.verifying_key.to_string("compressed").hex()) params.update(pubkey=lnurlauth_key.verifying_key.to_string("compressed").hex())
else: else:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
@ -693,7 +694,7 @@ async def api_auditor():
if not error_message: if not error_message:
delta = node_balance - total_balance delta = node_balance - total_balance
else: else:
node_balance, delta = None, None node_balance, delta = 0, 0
return { return {
"node_balance_msats": int(node_balance), "node_balance_msats": int(node_balance),
@ -745,6 +746,7 @@ async def api_install_extension(
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, detail="Release not found" status_code=HTTPStatus.NOT_FOUND, detail="Release not found"
) )
ext_info = InstallableExtension( ext_info = InstallableExtension(
id=data.ext_id, name=data.ext_id, installed_release=release, icon=release.icon id=data.ext_id, name=data.ext_id, installed_release=release, icon=release.icon
) )
@ -824,8 +826,10 @@ async def api_uninstall_extension(ext_id: str, user: User = Depends(check_admin)
) )
@core_app.get("/api/v1/extension/{ext_id}/releases") @core_app.get(
async def get_extension_releases(ext_id: str, user: User = Depends(check_admin)): "/api/v1/extension/{ext_id}/releases", dependencies=[Depends(check_admin)]
)
async def get_extension_releases(ext_id: str):
try: try:
extension_releases: List[ extension_releases: List[
ExtensionRelease ExtensionRelease

View File

@ -40,19 +40,18 @@ async def api_public_payment_longpolling(payment_hash):
response = None response = None
async def payment_info_receiver(cancel_scope): async def payment_info_receiver():
async for payment in payment_queue.get(): for payment in await payment_queue.get():
if payment.payment_hash == payment_hash: if payment.payment_hash == payment_hash:
nonlocal response nonlocal response
response = {"status": "paid"} response = {"status": "paid"}
cancel_scope.cancel()
async def timeouter(cancel_scope): async def timeouter(cancel_scope):
await asyncio.sleep(45) await asyncio.sleep(45)
cancel_scope.cancel() cancel_scope.cancel()
asyncio.create_task(payment_info_receiver()) cancel_scope = asyncio.create_task(payment_info_receiver())
asyncio.create_task(timeouter()) asyncio.create_task(timeouter(cancel_scope))
if response: if response:
return response return response

View File

@ -131,7 +131,7 @@ class Database(Compat):
else: else:
self.type = POSTGRES self.type = POSTGRES
import psycopg2 from psycopg2.extensions import DECIMAL, new_type, register_type
def _parse_timestamp(value, _): def _parse_timestamp(value, _):
if value is None: if value is None:
@ -141,15 +141,15 @@ class Database(Compat):
f = "%Y-%m-%d %H:%M:%S" f = "%Y-%m-%d %H:%M:%S"
return time.mktime(datetime.datetime.strptime(value, f).timetuple()) return time.mktime(datetime.datetime.strptime(value, f).timetuple())
psycopg2.extensions.register_type( register_type(
psycopg2.extensions.new_type( new_type(
psycopg2.extensions.DECIMAL.values, DECIMAL.values,
"DEC2FLOAT", "DEC2FLOAT",
lambda value, curs: float(value) if value is not None else None, lambda value, curs: float(value) if value is not None else None,
) )
) )
psycopg2.extensions.register_type( register_type(
psycopg2.extensions.new_type( new_type(
(1082, 1083, 1266), (1082, 1083, 1266),
"DATE2INT", "DATE2INT",
lambda value, curs: time.mktime(value.timetuple()) lambda value, curs: time.mktime(value.timetuple())
@ -158,11 +158,7 @@ class Database(Compat):
) )
) )
psycopg2.extensions.register_type( register_type(new_type((1184, 1114), "TIMESTAMP2INT", _parse_timestamp))
psycopg2.extensions.new_type(
(1184, 1114), "TIMESTAMP2INT", _parse_timestamp
)
)
else: else:
if os.path.isdir(settings.lnbits_data_folder): if os.path.isdir(settings.lnbits_data_folder):
self.path = os.path.join( self.path = os.path.join(
@ -189,7 +185,7 @@ class Database(Compat):
async def connect(self): async def connect(self):
await self.lock.acquire() await self.lock.acquire()
try: try:
async with self.engine.connect() as conn: async with self.engine.connect() as conn: # type: ignore
async with conn.begin() as txn: async with conn.begin() as txn:
wconn = Connection(conn, txn, self.type, self.name, self.schema) wconn = Connection(conn, txn, self.type, self.name, self.schema)

View File

@ -1,14 +1,12 @@
from http import HTTPStatus from http import HTTPStatus
from typing import Optional, Type from typing import Optional, Type
from fastapi import Security, status from fastapi import HTTPException, Request, Security, status
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.api_key import APIKeyHeader, APIKeyQuery from fastapi.security import APIKeyHeader, APIKeyQuery
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.types import UUID4 from pydantic.types import UUID4
from starlette.requests import Request
from lnbits.core.crud import get_user, get_wallet_for_key from lnbits.core.crud import get_user, get_wallet_for_key
from lnbits.core.models import User, Wallet from lnbits.core.models import User, Wallet
@ -17,9 +15,13 @@ from lnbits.requestvars import g
from lnbits.settings import settings from lnbits.settings import settings
# TODO: fix type ignores
class KeyChecker(SecurityBase): class KeyChecker(SecurityBase):
def __init__( def __init__(
self, scheme_name: str = None, auto_error: bool = True, api_key: str = None self,
scheme_name: Optional[str] = None,
auto_error: bool = True,
api_key: Optional[str] = None,
): ):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@ -27,13 +29,13 @@ class KeyChecker(SecurityBase):
self._api_key = api_key self._api_key = api_key
if api_key: if api_key:
key = APIKey( key = APIKey(
**{"in": APIKeyIn.query}, **{"in": APIKeyIn.query}, # type: ignore
name="X-API-KEY", name="X-API-KEY",
description="Wallet API Key - QUERY", description="Wallet API Key - QUERY",
) )
else: else:
key = APIKey( key = APIKey(
**{"in": APIKeyIn.header}, **{"in": APIKeyIn.header}, # type: ignore
name="X-API-KEY", name="X-API-KEY",
description="Wallet API Key - HEADER", description="Wallet API Key - HEADER",
) )
@ -73,7 +75,10 @@ class WalletInvoiceKeyChecker(KeyChecker):
""" """
def __init__( def __init__(
self, scheme_name: str = None, auto_error: bool = True, api_key: str = None self,
scheme_name: Optional[str] = None,
auto_error: bool = True,
api_key: Optional[str] = None,
): ):
super().__init__(scheme_name, auto_error, api_key) super().__init__(scheme_name, auto_error, api_key)
self._key_type = "invoice" self._key_type = "invoice"
@ -89,7 +94,10 @@ class WalletAdminKeyChecker(KeyChecker):
""" """
def __init__( def __init__(
self, scheme_name: str = None, auto_error: bool = True, api_key: str = None self,
scheme_name: Optional[str] = None,
auto_error: bool = True,
api_key: Optional[str] = None,
): ):
super().__init__(scheme_name, auto_error, api_key) super().__init__(scheme_name, auto_error, api_key)
self._key_type = "admin" self._key_type = "admin"

View File

@ -3,20 +3,145 @@ import json
import os import os
import shutil import shutil
import sys import sys
import urllib.request
import zipfile import zipfile
from http import HTTPStatus from http import HTTPStatus
from pathlib import Path from pathlib import Path
from typing import Any, List, NamedTuple, Optional, Tuple from typing import Any, List, NamedTuple, Optional, Tuple
from urllib import request
import httpx import httpx
from fastapi.exceptions import HTTPException from fastapi import HTTPException
from loguru import logger from loguru import logger
from pydantic import BaseModel from pydantic import BaseModel
from lnbits.settings import settings from lnbits.settings import settings
class ExplicitRelease(BaseModel):
id: str
name: str
version: str
archive: str
hash: str
dependencies: List[str] = []
icon: Optional[str]
short_description: Optional[str]
html_url: Optional[str]
details: Optional[str]
info_notification: Optional[str]
critical_notification: Optional[str]
class GitHubRelease(BaseModel):
id: str
organisation: str
repository: str
class Manifest(BaseModel):
featured: List[str] = []
extensions: List["ExplicitRelease"] = []
repos: List["GitHubRelease"] = []
class GitHubRepoRelease(BaseModel):
name: str
tag_name: str
zipball_url: str
html_url: str
class GitHubRepo(BaseModel):
stargazers_count: str
html_url: str
default_branch: str
class ExtensionConfig(BaseModel):
name: str
short_description: str
tile: str = ""
def download_url(url, save_path):
with request.urlopen(url) as dl_file:
with open(save_path, "wb") as out_file:
out_file.write(dl_file.read())
def file_hash(filename):
h = hashlib.sha256()
b = bytearray(128 * 1024)
mv = memoryview(b)
with open(filename, "rb", buffering=0) as f:
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest()
async def fetch_github_repo_info(
org: str, repository: str
) -> Tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]:
repo_url = f"https://api.github.com/repos/{org}/{repository}"
error_msg = "Cannot fetch extension repo"
repo = await gihub_api_get(repo_url, error_msg)
github_repo = GitHubRepo.parse_obj(repo)
lates_release_url = (
f"https://api.github.com/repos/{org}/{repository}/releases/latest"
)
error_msg = "Cannot fetch extension releases"
latest_release: Any = await gihub_api_get(lates_release_url, error_msg)
config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json"
error_msg = "Cannot fetch config for extension"
config = await gihub_api_get(config_url, error_msg)
return (
github_repo,
GitHubRepoRelease.parse_obj(latest_release),
ExtensionConfig.parse_obj(config),
)
async def fetch_manifest(url) -> Manifest:
error_msg = "Cannot fetch extensions manifest"
manifest = await gihub_api_get(url, error_msg)
return Manifest.parse_obj(manifest)
async def fetch_github_releases(org: str, repo: str) -> List[GitHubRepoRelease]:
releases_url = f"https://api.github.com/repos/{org}/{repo}/releases"
error_msg = "Cannot fetch extension releases"
releases = await gihub_api_get(releases_url, error_msg)
return [GitHubRepoRelease.parse_obj(r) for r in releases]
async def gihub_api_get(url: str, error_msg: Optional[str]) -> Any:
async with httpx.AsyncClient() as client:
headers = (
{"Authorization": "Bearer " + settings.lnbits_ext_github_token}
if settings.lnbits_ext_github_token
else None
)
resp = await client.get(
url,
headers=headers,
)
if resp.status_code != 200:
logger.warning(f"{error_msg} ({url}): {resp.text}")
resp.raise_for_status()
return resp.json()
def icon_to_github_url(source_repo: str, path: Optional[str]) -> str:
if not path:
return ""
_, _, *rest = path.split("/")
tail = "/".join(rest)
return f"https://github.com/{source_repo}/raw/main/{tail}"
class Extension(NamedTuple): class Extension(NamedTuple):
code: str code: str
is_valid: bool is_valid: bool
@ -97,12 +222,12 @@ class ExtensionRelease(BaseModel):
version: str version: str
archive: str archive: str
source_repo: str source_repo: str
is_github_release = False is_github_release: bool = False
hash: Optional[str] hash: Optional[str] = None
html_url: Optional[str] html_url: Optional[str] = None
description: Optional[str] description: Optional[str] = None
details_html: Optional[str] = None details_html: Optional[str] = None
icon: Optional[str] icon: Optional[str] = None
@classmethod @classmethod
def from_github_release( def from_github_release(
@ -132,52 +257,6 @@ class ExtensionRelease(BaseModel):
return [] return []
class ExplicitRelease(BaseModel):
id: str
name: str
version: str
archive: str
hash: str
dependencies: List[str] = []
icon: Optional[str]
short_description: Optional[str]
html_url: Optional[str]
details: Optional[str]
info_notification: Optional[str]
critical_notification: Optional[str]
class GitHubRelease(BaseModel):
id: str
organisation: str
repository: str
class Manifest(BaseModel):
featured: List[str] = []
extensions: List["ExplicitRelease"] = []
repos: List["GitHubRelease"] = []
class GitHubRepoRelease(BaseModel):
name: str
tag_name: str
zipball_url: str
html_url: str
class GitHubRepo(BaseModel):
stargazers_count: str
html_url: str
default_branch: str
class ExtensionConfig(BaseModel):
name: str
short_description: str
tile: str = ""
class InstallableExtension(BaseModel): class InstallableExtension(BaseModel):
id: str id: str
name: str name: str
@ -187,8 +266,9 @@ class InstallableExtension(BaseModel):
is_admin_only: bool = False is_admin_only: bool = False
stars: int = 0 stars: int = 0
featured = False featured = False
latest_release: Optional[ExtensionRelease] latest_release: Optional[ExtensionRelease] = None
installed_release: Optional[ExtensionRelease] installed_release: Optional[ExtensionRelease] = None
archive: Optional[str] = None
@property @property
def hash(self) -> str: def hash(self) -> str:
@ -234,6 +314,7 @@ class InstallableExtension(BaseModel):
if ext_zip_file.is_file(): if ext_zip_file.is_file():
os.remove(ext_zip_file) os.remove(ext_zip_file)
try: try:
assert self.installed_release, "installed_release is none."
download_url(self.installed_release.archive, ext_zip_file) download_url(self.installed_release.archive, ext_zip_file)
except Exception as ex: except Exception as ex:
logger.warning(ex) logger.warning(ex)
@ -334,8 +415,7 @@ class InstallableExtension(BaseModel):
id=github_release.id, id=github_release.id,
name=config.name, name=config.name,
short_description=config.short_description, short_description=config.short_description,
version="0", stars=int(repo.stargazers_count),
stars=repo.stargazers_count,
icon=icon_to_github_url( icon=icon_to_github_url(
f"{github_release.organisation}/{github_release.repository}", f"{github_release.organisation}/{github_release.repository}",
config.tile, config.tile,
@ -354,7 +434,6 @@ class InstallableExtension(BaseModel):
id=e.id, id=e.id,
name=e.name, name=e.name,
archive=e.archive, archive=e.archive,
hash=e.hash,
short_description=e.short_description, short_description=e.short_description,
icon=e.icon, icon=e.icon,
dependencies=e.dependencies, dependencies=e.dependencies,
@ -453,82 +532,3 @@ def get_valid_extensions() -> List[Extension]:
return [ return [
extension for extension in ExtensionManager().extensions if extension.is_valid extension for extension in ExtensionManager().extensions if extension.is_valid
] ]
def download_url(url, save_path):
with urllib.request.urlopen(url) as dl_file:
with open(save_path, "wb") as out_file:
out_file.write(dl_file.read())
def file_hash(filename):
h = hashlib.sha256()
b = bytearray(128 * 1024)
mv = memoryview(b)
with open(filename, "rb", buffering=0) as f:
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest()
def icon_to_github_url(source_repo: str, path: Optional[str]) -> str:
if not path:
return ""
_, _, *rest = path.split("/")
tail = "/".join(rest)
return f"https://github.com/{source_repo}/raw/main/{tail}"
async def fetch_github_repo_info(
org: str, repository: str
) -> Tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]:
repo_url = f"https://api.github.com/repos/{org}/{repository}"
error_msg = "Cannot fetch extension repo"
repo = await gihub_api_get(repo_url, error_msg)
github_repo = GitHubRepo.parse_obj(repo)
lates_release_url = (
f"https://api.github.com/repos/{org}/{repository}/releases/latest"
)
error_msg = "Cannot fetch extension releases"
latest_release: Any = await gihub_api_get(lates_release_url, error_msg)
config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json"
error_msg = "Cannot fetch config for extension"
config = await gihub_api_get(config_url, error_msg)
return (
github_repo,
GitHubRepoRelease.parse_obj(latest_release),
ExtensionConfig.parse_obj(config),
)
async def fetch_manifest(url) -> Manifest:
error_msg = "Cannot fetch extensions manifest"
manifest = await gihub_api_get(url, error_msg)
return Manifest.parse_obj(manifest)
async def fetch_github_releases(org: str, repo: str) -> List[GitHubRepoRelease]:
releases_url = f"https://api.github.com/repos/{org}/{repo}/releases"
error_msg = "Cannot fetch extension releases"
releases = await gihub_api_get(releases_url, error_msg)
return [GitHubRepoRelease.parse_obj(r) for r in releases]
async def gihub_api_get(url: str, error_msg: Optional[str]) -> Any:
async with httpx.AsyncClient() as client:
headers = (
{"Authorization": f"Bearer {settings.lnbits_ext_github_token}"}
if settings.lnbits_ext_github_token
else None
)
resp = await client.get(
url,
headers=headers,
)
if resp.status_code != 200:
logger.warning(f"{error_msg} ({url}): {resp.text}")
resp.raise_for_status()
return resp.json()

View File

@ -1,25 +1,18 @@
# Borrowed from the excellent accent-starlette
# https://github.com/accent-starlette/starlette-core/blob/master/starlette_core/templating.py
import typing import typing
from starlette import templating from jinja2 import BaseLoader, Environment, pass_context
from starlette.datastructures import QueryParams from starlette.datastructures import QueryParams
from starlette.requests import Request from starlette.requests import Request
from starlette.templating import Jinja2Templates as SuperJinja2Templates
try:
import jinja2
except ImportError: # pragma: nocover
jinja2 = None # type: ignore
class Jinja2Templates(templating.Jinja2Templates): class Jinja2Templates(SuperJinja2Templates):
def __init__(self, loader: jinja2.BaseLoader) -> None: # pylint: disable=W0231 def __init__(self, loader: BaseLoader) -> None:
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates" super().__init__("")
self.env = self.get_environment(loader) self.env = self.get_environment(loader)
def get_environment(self, loader: "jinja2.BaseLoader") -> "jinja2.Environment": def get_environment(self, loader: BaseLoader) -> Environment:
@jinja2.pass_context @pass_context
def url_for(context: dict, name: str, **path_params: typing.Any) -> str: def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
request: Request = context["request"] request: Request = context["request"]
return request.app.url_path_for(name, **path_params) return request.app.url_path_for(name, **path_params)
@ -29,7 +22,7 @@ class Jinja2Templates(templating.Jinja2Templates):
values.update(new) values.update(new)
return QueryParams(**values) return QueryParams(**values)
env = jinja2.Environment(loader=loader, autoescape=True) env = Environment(loader=loader, autoescape=True)
env.globals["url_for"] = url_for env.globals["url_for"] = url_for
env.globals["url_params_update"] = url_params_update env.globals["url_params_update"] = url_params_update
return env return env

View File

@ -26,6 +26,7 @@ class InstalledExtensionMiddleware:
else: else:
_, path_name = path_elements _, path_name = path_elements
path_type = None path_type = None
rest = []
# block path for all users if the extension is disabled # block path for all users if the extension is disabled
if path_name in settings.lnbits_deactivated_extensions: if path_name in settings.lnbits_deactivated_extensions:
@ -88,7 +89,7 @@ class ExtensionsRedirectMiddleware:
if "from_path" not in redirect: if "from_path" not in redirect:
return False return False
header_filters = ( header_filters = (
redirect["header_filters"] if "header_filters" in redirect else [] redirect["header_filters"] if "header_filters" in redirect else {}
) )
return self._has_common_path(redirect["from_path"], path) and self._has_headers( return self._has_common_path(redirect["from_path"], path) and self._has_headers(
header_filters, req_headers header_filters, req_headers

View File

@ -24,6 +24,7 @@ def list_parse_fallback(v):
class LNbitsSettings(BaseSettings): class LNbitsSettings(BaseSettings):
@classmethod
def validate(cls, val): def validate(cls, val):
if type(val) == str: if type(val) == str:
val = val.split(",") if val else [] val = val.split(",") if val else []
@ -103,6 +104,8 @@ class FakeWalletFundingSource(LNbitsSettings):
class LNbitsFundingSource(LNbitsSettings): class LNbitsFundingSource(LNbitsSettings):
lnbits_endpoint: str = Field(default="https://legend.lnbits.com") lnbits_endpoint: str = Field(default="https://legend.lnbits.com")
lnbits_key: Optional[str] = Field(default=None) lnbits_key: Optional[str] = Field(default=None)
lnbits_admin_key: Optional[str] = Field(default=None)
lnbits_invoice_key: Optional[str] = Field(default=None)
class ClicheFundingSource(LNbitsSettings): class ClicheFundingSource(LNbitsSettings):
@ -145,11 +148,14 @@ class LnPayFundingSource(LNbitsSettings):
lnpay_api_endpoint: Optional[str] = Field(default=None) lnpay_api_endpoint: Optional[str] = Field(default=None)
lnpay_api_key: Optional[str] = Field(default=None) lnpay_api_key: Optional[str] = Field(default=None)
lnpay_wallet_key: Optional[str] = Field(default=None) lnpay_wallet_key: Optional[str] = Field(default=None)
lnpay_admin_key: Optional[str] = Field(default=None)
class OpenNodeFundingSource(LNbitsSettings): class OpenNodeFundingSource(LNbitsSettings):
opennode_api_endpoint: Optional[str] = Field(default=None) opennode_api_endpoint: Optional[str] = Field(default=None)
opennode_key: Optional[str] = Field(default=None) opennode_key: Optional[str] = Field(default=None)
opennode_admin_key: Optional[str] = Field(default=None)
opennode_invoice_key: Optional[str] = Field(default=None)
class SparkFundingSource(LNbitsSettings): class SparkFundingSource(LNbitsSettings):
@ -208,8 +214,9 @@ class EditableSettings(
"lnbits_admin_extensions", "lnbits_admin_extensions",
pre=True, pre=True,
) )
@classmethod
def validate_editable_settings(cls, val): def validate_editable_settings(cls, val):
return super().validate(cls, val) return super().validate(val)
@classmethod @classmethod
def from_dict(cls, d: dict): def from_dict(cls, d: dict):
@ -281,8 +288,9 @@ class ReadOnlySettings(
"lnbits_allowed_funding_sources", "lnbits_allowed_funding_sources",
pre=True, pre=True,
) )
@classmethod
def validate_readonly_settings(cls, val): def validate_readonly_settings(cls, val):
return super().validate(cls, val) return super().validate(val)
@classmethod @classmethod
def readonly_fields(cls): def readonly_fields(cls):

View File

@ -3,7 +3,7 @@ import time
import traceback import traceback
import uuid import uuid
from http import HTTPStatus from http import HTTPStatus
from typing import Dict from typing import Dict, Optional
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from loguru import logger from loguru import logger
@ -42,7 +42,7 @@ class SseListenersDict(dict):
A dict of sse listeners. A dict of sse listeners.
""" """
def __init__(self, name: str = None): def __init__(self, name: Optional[str] = None):
self.name = name or f"sse_listener_{str(uuid.uuid4())[:8]}" self.name = name or f"sse_listener_{str(uuid.uuid4())[:8]}"
def __setitem__(self, key, value): def __setitem__(self, key, value):
@ -65,7 +65,7 @@ class SseListenersDict(dict):
invoice_listeners: Dict[str, asyncio.Queue] = SseListenersDict("invoice_listeners") invoice_listeners: Dict[str, asyncio.Queue] = SseListenersDict("invoice_listeners")
def register_invoice_listener(send_chan: asyncio.Queue, name: str = None): def register_invoice_listener(send_chan: asyncio.Queue, name: Optional[str] = None):
""" """
A method intended for extensions (and core/tasks.py) to call when they want to be notified about A method intended for extensions (and core/tasks.py) to call when they want to be notified about
new invoice payments incoming. Will emit all incoming payments. new invoice payments incoming. Will emit all incoming payments.
@ -164,7 +164,7 @@ async def check_pending_payments():
async def perform_balance_checks(): async def perform_balance_checks():
while True: while True:
for bc in await get_balance_checks(): for bc in await get_balance_checks():
redeem_lnurl_withdraw(bc.wallet, bc.url) await redeem_lnurl_withdraw(bc.wallet, bc.url)
await asyncio.sleep(60 * 60 * 6) # every 6 hours await asyncio.sleep(60 * 60 * 6) # every 6 hours

View File

@ -1,8 +1,6 @@
# flake8: noqa: F401 # flake8: noqa: F401
from .cliche import ClicheWallet from .cliche import ClicheWallet
from .cln import CoreLightningWallet # legacy .env support from .cln import CoreLightningWallet
from .cln import CoreLightningWallet as CLightningWallet from .cln import CoreLightningWallet as CLightningWallet
from .eclair import EclairWallet from .eclair import EclairWallet
from .fake import FakeWallet from .fake import FakeWallet

View File

@ -22,6 +22,8 @@ class ClicheWallet(Wallet):
def __init__(self): def __init__(self):
self.endpoint = settings.cliche_endpoint self.endpoint = settings.cliche_endpoint
if not self.endpoint:
raise Exception("cannot initialize cliche")
async def status(self) -> StatusResponse: async def status(self) -> StatusResponse:
try: try:
@ -36,7 +38,7 @@ class ClicheWallet(Wallet):
data = json.loads(r) data = json.loads(r)
except: except:
return StatusResponse( return StatusResponse(
f"Failed to connect to {self.endpoint}, got: '{r.text[:200]}...'", 0 f"Failed to connect to {self.endpoint}, got: '{r[:200]}...'", 0
) )
return StatusResponse(None, data["result"]["wallets"][0]["balance"]) return StatusResponse(None, data["result"]["wallets"][0]["balance"])
@ -89,6 +91,13 @@ class ClicheWallet(Wallet):
async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse: async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse:
ws = create_connection(self.endpoint) ws = create_connection(self.endpoint)
ws.send(f"pay-invoice --invoice {bolt11}") ws.send(f"pay-invoice --invoice {bolt11}")
checking_id, fee_msat, preimage, error_message, payment_ok = (
None,
None,
None,
None,
None,
)
for _ in range(2): for _ in range(2):
r = ws.recv() r = ws.recv()
data = json.loads(r) data = json.loads(r)
@ -151,9 +160,9 @@ class ClicheWallet(Wallet):
async def paid_invoices_stream(self) -> AsyncGenerator[str, None]: async def paid_invoices_stream(self) -> AsyncGenerator[str, None]:
while True: while True:
try: try:
ws = await create_connection(self.endpoint) ws = create_connection(self.endpoint)
while True: while True:
r = await ws.recv() r = ws.recv()
data = json.loads(r) data = json.loads(r)
print(data) print(data)
try: try:

View File

@ -7,10 +7,7 @@ from typing import AsyncGenerator, Dict, Optional
import httpx import httpx
from loguru import logger from loguru import logger
from websockets.client import connect
# TODO: https://github.com/lnbits/lnbits/issues/764
# mypy https://github.com/aaugustin/websockets/issues/940
from websockets import connect # type: ignore
from lnbits.settings import settings from lnbits.settings import settings
@ -34,11 +31,13 @@ class UnknownError(Exception):
class EclairWallet(Wallet): class EclairWallet(Wallet):
def __init__(self): def __init__(self):
url = settings.eclair_url url = settings.eclair_url
self.url = url[:-1] if url.endswith("/") else url passw = settings.eclair_pass
if not url or not passw:
raise Exception("cannot initialize eclair")
self.url = url[:-1] if url.endswith("/") else url
self.ws_url = f"ws://{urllib.parse.urlsplit(self.url).netloc}/ws" self.ws_url = f"ws://{urllib.parse.urlsplit(self.url).netloc}/ws"
passw = settings.eclair_pass
encodedAuth = base64.b64encode(f":{passw}".encode()) encodedAuth = base64.b64encode(f":{passw}".encode())
auth = str(encodedAuth, "utf-8") auth = str(encodedAuth, "utf-8")
self.auth = {"Authorization": f"Basic {auth}"} self.auth = {"Authorization": f"Basic {auth}"}
@ -71,7 +70,11 @@ class EclairWallet(Wallet):
**kwargs, **kwargs,
) -> InvoiceResponse: ) -> InvoiceResponse:
data: Dict = {"amountMsat": amount * 1000} data: Dict = {
"amountMsat": amount * 1000,
"description_hash": b"",
"description": memo,
}
if kwargs.get("expiry"): if kwargs.get("expiry"):
data["expireIn"] = kwargs["expiry"] data["expireIn"] = kwargs["expiry"]
@ -79,8 +82,6 @@ class EclairWallet(Wallet):
data["descriptionHash"] = description_hash.hex() data["descriptionHash"] = description_hash.hex()
elif unhashed_description: elif unhashed_description:
data["descriptionHash"] = hashlib.sha256(unhashed_description).hexdigest() data["descriptionHash"] = hashlib.sha256(unhashed_description).hexdigest()
else:
data["description"] = memo or ""
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
r = await client.post( r = await client.post(
@ -149,6 +150,7 @@ class EclairWallet(Wallet):
} }
data = r.json()[-1] data = r.json()[-1]
fee_msat = 0
if data["status"]["type"] == "sent": if data["status"]["type"] == "sent":
fee_msat = -data["status"]["feesPaid"] fee_msat = -data["status"]["feesPaid"]
preimage = data["status"]["paymentPreimage"] preimage = data["status"]["paymentPreimage"]
@ -223,10 +225,10 @@ class EclairWallet(Wallet):
) as ws: ) as ws:
while True: while True:
message = await ws.recv() message = await ws.recv()
message = json.loads(message) message_json = json.loads(message)
if message and message["type"] == "payment-received": if message_json and message_json["type"] == "payment-received":
yield message["paymentHash"] yield message_json["paymentHash"]
except Exception as exc: except Exception as exc:
logger.error( logger.error(

View File

@ -48,16 +48,15 @@ class FakeWallet(Wallet):
"amount": amount, "amount": amount,
"currency": "bc", "currency": "bc",
"privkey": self.privkey, "privkey": self.privkey,
"memo": None, "memo": memo,
"description_hash": None, "description_hash": b"",
"description": "", "description": "",
"fallback": None, "fallback": None,
"expires": None, "expires": kwargs.get("expiry"),
"timestamp": datetime.now().timestamp(),
"route": None, "route": None,
"tags_set": [],
} }
data["expires"] = kwargs.get("expiry")
data["amount"] = amount * 1000
data["timestamp"] = datetime.now().timestamp()
if description_hash: if description_hash:
data["tags_set"] = ["h"] data["tags_set"] = ["h"]
data["description_hash"] = description_hash data["description_hash"] = description_hash
@ -69,7 +68,7 @@ class FakeWallet(Wallet):
data["memo"] = memo data["memo"] = memo
data["description"] = memo data["description"] = memo
randomHash = ( randomHash = (
data["privkey"][:6] self.privkey[:6]
+ hashlib.sha256(str(random.getrandbits(256)).encode()).hexdigest()[6:] + hashlib.sha256(str(random.getrandbits(256)).encode()).hexdigest()[6:]
) )
data["paymenthash"] = randomHash data["paymenthash"] = randomHash
@ -78,12 +77,10 @@ class FakeWallet(Wallet):
return InvoiceResponse(True, checking_id, payment_request) return InvoiceResponse(True, checking_id, payment_request)
async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse: async def pay_invoice(self, bolt11: str, _: int) -> PaymentResponse:
invoice = decode(bolt11) invoice = decode(bolt11)
if (
hasattr(invoice, "checking_id") if invoice.payment_hash[:6] == self.privkey[:6]:
and invoice.checking_id[:6] == self.privkey[:6] # type: ignore
):
await self.queue.put(invoice) await self.queue.put(invoice)
return PaymentResponse(True, invoice.payment_hash, 0) return PaymentResponse(True, invoice.payment_hash, 0)
else: else:
@ -91,10 +88,10 @@ class FakeWallet(Wallet):
ok=False, error_message="Only internal invoices can be used!" ok=False, error_message="Only internal invoices can be used!"
) )
async def get_invoice_status(self, checking_id: str) -> PaymentStatus: async def get_invoice_status(self, _: str) -> PaymentStatus:
return PaymentStatus(None) return PaymentStatus(None)
async def get_payment_status(self, checking_id: str) -> PaymentStatus: async def get_payment_status(self, _: str) -> PaymentStatus:
return PaymentStatus(None) return PaymentStatus(None)
async def paid_invoices_stream(self) -> AsyncGenerator[str, None]: async def paid_invoices_stream(self) -> AsyncGenerator[str, None]:

View File

@ -21,12 +21,13 @@ class LNbitsWallet(Wallet):
def __init__(self): def __init__(self):
self.endpoint = settings.lnbits_endpoint self.endpoint = settings.lnbits_endpoint
key = ( key = (
settings.lnbits_key settings.lnbits_key
or settings.lnbits_admin_key or settings.lnbits_admin_key
or settings.lnbits_invoice_key or settings.lnbits_invoice_key
) )
if not self.endpoint or not key:
raise Exception("cannot initialize lnbits wallet")
self.key = {"X-Api-Key": key} self.key = {"X-Api-Key": key}
async def status(self) -> StatusResponse: async def status(self) -> StatusResponse:
@ -60,7 +61,7 @@ class LNbitsWallet(Wallet):
unhashed_description: Optional[bytes] = None, unhashed_description: Optional[bytes] = None,
**kwargs, **kwargs,
) -> InvoiceResponse: ) -> InvoiceResponse:
data: Dict = {"out": False, "amount": amount} data: Dict = {"out": False, "amount": amount, "memo": memo or ""}
if kwargs.get("expiry"): if kwargs.get("expiry"):
data["expiry"] = kwargs["expiry"] data["expiry"] = kwargs["expiry"]
if description_hash: if description_hash:
@ -68,8 +69,6 @@ class LNbitsWallet(Wallet):
if unhashed_description: if unhashed_description:
data["unhashed_description"] = unhashed_description.hex() data["unhashed_description"] = unhashed_description.hex()
data["memo"] = memo or ""
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
r = await client.post( r = await client.post(
url=f"{self.endpoint}/api/v1/payments", headers=self.key, json=data url=f"{self.endpoint}/api/v1/payments", headers=self.key, json=data

View File

@ -105,9 +105,6 @@ class LndWallet(Wallet):
) )
endpoint = settings.lnd_grpc_endpoint endpoint = settings.lnd_grpc_endpoint
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
self.port = int(settings.lnd_grpc_port)
self.cert_path = settings.lnd_grpc_cert or settings.lnd_cert
macaroon = ( macaroon = (
settings.lnd_grpc_macaroon settings.lnd_grpc_macaroon
@ -122,8 +119,17 @@ class LndWallet(Wallet):
macaroon = AESCipher(description="macaroon decryption").decrypt( macaroon = AESCipher(description="macaroon decryption").decrypt(
encrypted_macaroon encrypted_macaroon
) )
self.macaroon = load_macaroon(macaroon)
cert_path = settings.lnd_grpc_cert or settings.lnd_cert
if not endpoint or not macaroon or not cert_path or not settings.lnd_grpc_port:
raise Exception("cannot initialize lndrest")
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
self.port = int(settings.lnd_grpc_port)
self.cert_path = settings.lnd_grpc_cert or settings.lnd_cert
self.macaroon = load_macaroon(macaroon)
self.cert_path = cert_path
cert = open(self.cert_path, "rb").read() cert = open(self.cert_path, "rb").read()
creds = grpc.ssl_channel_credentials(cert) creds = grpc.ssl_channel_credentials(cert)
auth_creds = grpc.metadata_call_credentials(self.metadata_callback) auth_creds = grpc.metadata_call_credentials(self.metadata_callback)
@ -140,8 +146,6 @@ class LndWallet(Wallet):
async def status(self) -> StatusResponse: async def status(self) -> StatusResponse:
try: try:
resp = await self.rpc.ChannelBalance(ln.ChannelBalanceRequest()) resp = await self.rpc.ChannelBalance(ln.ChannelBalanceRequest())
except RpcError as exc:
return StatusResponse(str(exc._details), 0)
except Exception as exc: except Exception as exc:
return StatusResponse(str(exc), 0) return StatusResponse(str(exc), 0)
@ -155,20 +159,23 @@ class LndWallet(Wallet):
unhashed_description: Optional[bytes] = None, unhashed_description: Optional[bytes] = None,
**kwargs, **kwargs,
) -> InvoiceResponse: ) -> InvoiceResponse:
params: Dict = {"value": amount, "private": True} data: Dict = {
"description_hash": b"",
"value": amount,
"private": True,
"memo": memo or "",
}
if kwargs.get("expiry"): if kwargs.get("expiry"):
params["expiry"] = kwargs["expiry"] data["expiry"] = kwargs["expiry"]
if description_hash: if description_hash:
params["description_hash"] = description_hash data["description_hash"] = description_hash
elif unhashed_description: elif unhashed_description:
params["description_hash"] = hashlib.sha256( data["description_hash"] = hashlib.sha256(
unhashed_description unhashed_description
).digest() # as bytes directly ).digest() # as bytes directly
else:
params["memo"] = memo or ""
try: try:
req = ln.Invoice(**params) req = ln.Invoice(**data)
resp = await self.rpc.AddInvoice(req) resp = await self.rpc.AddInvoice(req)
except Exception as exc: except Exception as exc:
error_message = str(exc) error_message = str(exc)
@ -188,8 +195,6 @@ class LndWallet(Wallet):
) )
try: try:
resp = await self.routerpc.SendPaymentV2(req).read() resp = await self.routerpc.SendPaymentV2(req).read()
except RpcError as exc:
return PaymentResponse(False, None, None, None, exc._details)
except Exception as exc: except Exception as exc:
return PaymentResponse(False, None, None, None, str(exc)) return PaymentResponse(False, None, None, None, str(exc))

View File

@ -24,11 +24,6 @@ class LndRestWallet(Wallet):
def __init__(self): def __init__(self):
endpoint = settings.lnd_rest_endpoint endpoint = settings.lnd_rest_endpoint
endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
endpoint = (
f"https://{endpoint}" if not endpoint.startswith("http") else endpoint
)
self.endpoint = endpoint
macaroon = ( macaroon = (
settings.lnd_rest_macaroon settings.lnd_rest_macaroon
@ -43,6 +38,15 @@ class LndRestWallet(Wallet):
macaroon = AESCipher(description="macaroon decryption").decrypt( macaroon = AESCipher(description="macaroon decryption").decrypt(
encrypted_macaroon encrypted_macaroon
) )
if not endpoint or not macaroon or not settings.lnd_rest_cert:
raise Exception("cannot initialize lndrest")
endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
endpoint = (
f"https://{endpoint}" if not endpoint.startswith("http") else endpoint
)
self.endpoint = endpoint
self.macaroon = load_macaroon(macaroon) self.macaroon = load_macaroon(macaroon)
self.auth = {"Grpc-Metadata-macaroon": self.macaroon} self.auth = {"Grpc-Metadata-macaroon": self.macaroon}
@ -74,7 +78,7 @@ class LndRestWallet(Wallet):
unhashed_description: Optional[bytes] = None, unhashed_description: Optional[bytes] = None,
**kwargs, **kwargs,
) -> InvoiceResponse: ) -> InvoiceResponse:
data: Dict = {"value": amount, "private": True} data: Dict = {"value": amount, "private": True, "memo": memo or ""}
if kwargs.get("expiry"): if kwargs.get("expiry"):
data["expiry"] = kwargs["expiry"] data["expiry"] = kwargs["expiry"]
if description_hash: if description_hash:
@ -85,8 +89,6 @@ class LndRestWallet(Wallet):
data["description_hash"] = base64.b64encode( data["description_hash"] = base64.b64encode(
hashlib.sha256(unhashed_description).digest() hashlib.sha256(unhashed_description).digest()
).decode("ascii") ).decode("ascii")
else:
data["memo"] = memo or ""
async with httpx.AsyncClient(verify=self.cert) as client: async with httpx.AsyncClient(verify=self.cert) as client:
r = await client.post( r = await client.post(

View File

@ -5,7 +5,7 @@ from http import HTTPStatus
from typing import AsyncGenerator, Dict, Optional from typing import AsyncGenerator, Dict, Optional
import httpx import httpx
from fastapi.exceptions import HTTPException from fastapi import HTTPException
from loguru import logger from loguru import logger
from lnbits.settings import settings from lnbits.settings import settings
@ -24,8 +24,13 @@ class LNPayWallet(Wallet):
def __init__(self): def __init__(self):
endpoint = settings.lnpay_api_endpoint endpoint = settings.lnpay_api_endpoint
wallet_key = settings.lnpay_wallet_key or settings.lnpay_admin_key
if not endpoint or not wallet_key or not settings.lnpay_api_key:
raise Exception("cannot initialize lnpay")
self.wallet_key = wallet_key
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
self.wallet_key = settings.lnpay_wallet_key or settings.lnpay_admin_key
self.auth = {"X-Api-Key": settings.lnpay_api_key} self.auth = {"X-Api-Key": settings.lnpay_api_key}
async def status(self) -> StatusResponse: async def status(self) -> StatusResponse:
@ -134,7 +139,9 @@ class LNPayWallet(Wallet):
yield value yield value
async def webhook_listener(self): async def webhook_listener(self):
text: str = await request.get_data() # TODO: request.get_data is undefined, was it something with Flask or quart?
# probably issue introduced when refactoring?
text: str = await request.get_data() # type: ignore
try: try:
data = json.loads(text) data = json.loads(text)
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:

View File

@ -21,13 +21,14 @@ from .base import (
class LnTipsWallet(Wallet): class LnTipsWallet(Wallet):
def __init__(self): def __init__(self):
endpoint = settings.lntips_api_endpoint endpoint = settings.lntips_api_endpoint
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
key = ( key = (
settings.lntips_api_key settings.lntips_api_key
or settings.lntips_admin_key or settings.lntips_admin_key
or settings.lntips_invoice_key or settings.lntips_invoice_key
) )
if not endpoint or not key:
raise Exception("cannot initialize lntxbod")
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
self.auth = {"Authorization": f"Basic {key}"} self.auth = {"Authorization": f"Basic {key}"}
async def status(self) -> StatusResponse: async def status(self) -> StatusResponse:
@ -55,13 +56,11 @@ class LnTipsWallet(Wallet):
unhashed_description: Optional[bytes] = None, unhashed_description: Optional[bytes] = None,
**kwargs, **kwargs,
) -> InvoiceResponse: ) -> InvoiceResponse:
data: Dict = {"amount": amount} data: Dict = {"amount": amount, "description_hash": "", "memo": memo or ""}
if description_hash: if description_hash:
data["description_hash"] = description_hash.hex() data["description_hash"] = description_hash.hex()
elif unhashed_description: elif unhashed_description:
data["description_hash"] = hashlib.sha256(unhashed_description).hexdigest() data["description_hash"] = hashlib.sha256(unhashed_description).hexdigest()
else:
data["memo"] = memo or ""
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
r = await client.post( r = await client.post(

View File

@ -4,7 +4,7 @@ from http import HTTPStatus
from typing import AsyncGenerator, Optional from typing import AsyncGenerator, Optional
import httpx import httpx
from fastapi.exceptions import HTTPException from fastapi import HTTPException
from loguru import logger from loguru import logger
from lnbits.settings import settings from lnbits.settings import settings
@ -24,13 +24,15 @@ class OpenNodeWallet(Wallet):
def __init__(self): def __init__(self):
endpoint = settings.opennode_api_endpoint endpoint = settings.opennode_api_endpoint
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
key = ( key = (
settings.opennode_key settings.opennode_key
or settings.opennode_admin_key or settings.opennode_admin_key
or settings.opennode_invoice_key or settings.opennode_invoice_key
) )
if not endpoint or not key:
raise Exception("cannot initialize opennode")
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
self.auth = {"Authorization": key} self.auth = {"Authorization": key}
async def status(self) -> StatusResponse: async def status(self) -> StatusResponse:
@ -140,7 +142,9 @@ class OpenNodeWallet(Wallet):
yield value yield value
async def webhook_listener(self): async def webhook_listener(self):
data = await request.form # TODO: request.form is undefined, was it something with Flask or quart?
# probably issue introduced when refactoring?
data = await request.form # type: ignore
if "status" not in data or data["status"] != "paid": if "status" not in data or data["status"] != "paid":
raise HTTPException(status_code=HTTPStatus.NO_CONTENT) raise HTTPException(status_code=HTTPStatus.NO_CONTENT)

View File

@ -28,6 +28,7 @@ class UnknownError(Exception):
class SparkWallet(Wallet): class SparkWallet(Wallet):
def __init__(self): def __init__(self):
assert settings.spark_url, "spark url does not exist"
self.url = settings.spark_url.replace("/rpc", "") self.url = settings.spark_url.replace("/rpc", "")
self.token = settings.spark_token self.token = settings.spark_token
@ -46,6 +47,7 @@ class SparkWallet(Wallet):
try: try:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
assert self.token, "spark wallet token does not exist"
r = await client.post( r = await client.post(
self.url + "/rpc", self.url + "/rpc",
headers={"X-Access": self.token}, headers={"X-Access": self.token},
@ -133,38 +135,49 @@ class SparkWallet(Wallet):
bolt11=bolt11, bolt11=bolt11,
maxfee=fee_limit_msat, maxfee=fee_limit_msat,
) )
fee_msat = -int(r["msatoshi_sent"] - r["msatoshi"])
preimage = r["payment_preimage"]
return PaymentResponse(True, r["payment_hash"], fee_msat, preimage, None)
except (SparkError, UnknownError) as exc: except (SparkError, UnknownError) as exc:
listpays = await self.listpays(bolt11) listpays = await self.listpays(bolt11)
if listpays: if not listpays:
pays = listpays["pays"] return PaymentResponse(False, None, None, None, str(exc))
if len(pays) == 0: pays = listpays["pays"]
return PaymentResponse(False, None, None, None, str(exc))
pay = pays[0] if len(pays) == 0:
payment_hash = pay["payment_hash"] return PaymentResponse(False, None, None, None, str(exc))
if len(pays) > 1: pay = pays[0]
raise SparkError( payment_hash = pay["payment_hash"]
f"listpays({payment_hash}) returned an unexpected response: {listpays}"
)
if pay["status"] == "failed": if len(pays) > 1:
return PaymentResponse(False, None, None, None, str(exc)) raise SparkError(
elif pay["status"] == "pending": f"listpays({payment_hash}) returned an unexpected response: {listpays}"
return PaymentResponse(None, payment_hash, None, None, None) )
elif pay["status"] == "complete":
r = pay
r["payment_preimage"] = pay["preimage"]
r["msatoshi"] = int(pay["amount_msat"][0:-4])
r["msatoshi_sent"] = int(pay["amount_sent_msat"][0:-4])
# this may result in an error if it was paid previously
# our database won't allow the same payment_hash to be added twice
# this is good
fee_msat = -int(r["msatoshi_sent"] - r["msatoshi"]) if pay["status"] == "failed":
preimage = r["payment_preimage"] return PaymentResponse(False, None, None, None, str(exc))
return PaymentResponse(True, r["payment_hash"], fee_msat, preimage, None)
if pay["status"] == "pending":
return PaymentResponse(None, payment_hash, None, None, None)
if pay["status"] == "complete":
r = pay
r["payment_preimage"] = pay["preimage"]
r["msatoshi"] = int(pay["amount_msat"][0:-4])
r["msatoshi_sent"] = int(pay["amount_sent_msat"][0:-4])
# this may result in an error if it was paid previously
# our database won't allow the same payment_hash to be added twice
# this is good
fee_msat = -int(r["msatoshi_sent"] - r["msatoshi"])
preimage = r["payment_preimage"]
return PaymentResponse(
True, r["payment_hash"], fee_msat, preimage, None
)
else:
return PaymentResponse(False, None, None, None, str(exc))
async def get_invoice_status(self, checking_id: str) -> PaymentStatus: async def get_invoice_status(self, checking_id: str) -> PaymentStatus:
try: try:
@ -205,7 +218,7 @@ class SparkWallet(Wallet):
- int(r["pays"][0]["amount_msat"][0:-4]) - int(r["pays"][0]["amount_msat"][0:-4])
) )
return PaymentStatus(True, fee_msat, r["pays"][0]["preimage"]) return PaymentStatus(True, fee_msat, r["pays"][0]["preimage"])
elif status == "failed": if status == "failed":
return PaymentStatus(False) return PaymentStatus(False)
return PaymentStatus(None) return PaymentStatus(None)
raise KeyError("supplied an invalid checking_id") raise KeyError("supplied an invalid checking_id")

View File

@ -69,9 +69,6 @@ include = [
] ]
exclude = [ exclude = [
"lnbits/wallets/lnd_grpc_files", "lnbits/wallets/lnd_grpc_files",
"lnbits/wallets",
"lnbits/core",
"lnbits/*.py",
"lnbits/extensions", "lnbits/extensions",
] ]