Merge pull request #1468 from lnbits/pyright3
introduce pyright + fix issues (supersedes #1444)
This commit is contained in:
commit
47df94178e
2
Makefile
2
Makefile
|
@ -1,6 +1,6 @@
|
|||
.PHONY: test
|
||||
|
||||
all: format check requirements.txt
|
||||
all: format check
|
||||
|
||||
format: prettier isort black
|
||||
|
||||
|
|
|
@ -66,11 +66,12 @@ def decode(pr: str) -> Invoice:
|
|||
invoice.amount_msat = _unshorten_amount(amountstr)
|
||||
|
||||
# pull out date
|
||||
invoice.date = data.read(35).uint
|
||||
date_bin = data.read(35)
|
||||
invoice.date = date_bin.uint # type: ignore
|
||||
|
||||
while data.pos != data.len:
|
||||
tag, tagdata, data = _pull_tagged(data)
|
||||
data_length = len(tagdata) / 5
|
||||
data_length = len(tagdata or []) / 5
|
||||
|
||||
if tag == "d":
|
||||
invoice.description = _trim_to_bytes(tagdata).decode()
|
||||
|
@ -79,7 +80,7 @@ def decode(pr: str) -> Invoice:
|
|||
elif tag == "p" and data_length == 52:
|
||||
invoice.payment_hash = _trim_to_bytes(tagdata).hex()
|
||||
elif tag == "x":
|
||||
invoice.expiry = tagdata.uint
|
||||
invoice.expiry = tagdata.uint # type: ignore
|
||||
elif tag == "n":
|
||||
invoice.payee = _trim_to_bytes(tagdata).hex()
|
||||
# this won't work in most cases, we must extract the payee
|
||||
|
@ -90,11 +91,11 @@ def decode(pr: str) -> Invoice:
|
|||
s = bitstring.ConstBitStream(tagdata)
|
||||
while s.pos + 264 + 64 + 32 + 32 + 16 < s.len:
|
||||
route = Route(
|
||||
pubkey=s.read(264).tobytes().hex(),
|
||||
short_channel_id=_readable_scid(s.read(64).intbe),
|
||||
base_fee_msat=s.read(32).intbe,
|
||||
ppm_fee=s.read(32).intbe,
|
||||
cltv=s.read(16).intbe,
|
||||
pubkey=s.read(264).tobytes().hex(), # type: ignore
|
||||
short_channel_id=_readable_scid(s.read(64).intbe), # type: ignore
|
||||
base_fee_msat=s.read(32).intbe, # type: ignore
|
||||
ppm_fee=s.read(32).intbe, # type: ignore
|
||||
cltv=s.read(16).intbe, # type: ignore
|
||||
)
|
||||
invoice.route_hints.append(route)
|
||||
|
||||
|
@ -202,7 +203,8 @@ def lnencode(addr, privkey):
|
|||
)
|
||||
data += tagged("r", route)
|
||||
elif k == "f":
|
||||
data += encode_fallback(v, addr.currency)
|
||||
# NOTE: there was an error fallback here that's now removed
|
||||
continue
|
||||
elif k == "d":
|
||||
data += tagged_bytes("d", v.encode())
|
||||
elif k == "x":
|
||||
|
@ -244,7 +246,13 @@ def lnencode(addr, privkey):
|
|||
|
||||
class LnAddr:
|
||||
def __init__(
|
||||
self, paymenthash=None, amount=None, currency="bc", tags=None, date=None
|
||||
self,
|
||||
paymenthash=None,
|
||||
amount=None,
|
||||
currency="bc",
|
||||
tags=None,
|
||||
date=None,
|
||||
fallback=None,
|
||||
):
|
||||
self.date = int(time.time()) if not date else int(date)
|
||||
self.tags = [] if not tags else tags
|
||||
|
@ -252,11 +260,13 @@ class LnAddr:
|
|||
self.paymenthash = paymenthash
|
||||
self.signature = None
|
||||
self.pubkey = None
|
||||
self.fallback = fallback
|
||||
self.currency = currency
|
||||
self.amount = amount
|
||||
|
||||
def __str__(self):
|
||||
pubkey = bytes.hex(self.pubkey.serialize()).decode()
|
||||
assert self.pubkey, "LnAddr, pubkey must be set"
|
||||
pubkey = bytes.hex(self.pubkey.serialize())
|
||||
tags = ", ".join([f"{k}={v}" for k, v in self.tags])
|
||||
return f"LnAddr[{pubkey}, amount={self.amount}{self.currency} tags=[{tags}]]"
|
||||
|
||||
|
@ -266,6 +276,7 @@ def shorten_amount(amount):
|
|||
# Convert to pico initially
|
||||
amount = int(amount * 10**12)
|
||||
units = ["p", "n", "u", "m", ""]
|
||||
unit = ""
|
||||
for unit in units:
|
||||
if amount % 1000 == 0:
|
||||
amount //= 1000
|
||||
|
@ -304,14 +315,6 @@ def _pull_tagged(stream):
|
|||
return (CHARSET[tag], stream.read(length * 5), stream)
|
||||
|
||||
|
||||
def is_p2pkh(currency, prefix):
|
||||
return prefix == base58_prefix_map[currency][0]
|
||||
|
||||
|
||||
def is_p2sh(currency, prefix):
|
||||
return prefix == base58_prefix_map[currency][1]
|
||||
|
||||
|
||||
# Tagged field containing BitArray
|
||||
def tagged(char, l):
|
||||
# Tagged fields need to be zero-padded to 5 bits.
|
||||
|
@ -359,5 +362,5 @@ def bitarray_to_u5(barr):
|
|||
ret = []
|
||||
s = bitstring.ConstBitStream(barr)
|
||||
while s.pos != s.len:
|
||||
ret.append(s.read(5).uint)
|
||||
ret.append(s.read(5).uint) # type: ignore
|
||||
return ret
|
||||
|
|
|
@ -41,6 +41,7 @@ async def migrate_databases():
|
|||
"""Creates the necessary databases if they don't exist already; or migrates them."""
|
||||
|
||||
async with core_db.connect() as conn:
|
||||
exists = False
|
||||
if conn.type == SQLITE:
|
||||
exists = await conn.fetchone(
|
||||
"SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'"
|
||||
|
|
|
@ -206,7 +206,7 @@ async def create_wallet(
|
|||
async def update_wallet(
|
||||
wallet_id: str, new_name: str, conn: Optional[Connection] = None
|
||||
) -> Optional[Wallet]:
|
||||
return await (conn or db).execute(
|
||||
await (conn or db).execute(
|
||||
"""
|
||||
UPDATE wallets SET
|
||||
name = ?
|
||||
|
@ -214,6 +214,9 @@ async def update_wallet(
|
|||
""",
|
||||
(new_name, wallet_id),
|
||||
)
|
||||
wallet = await get_wallet(wallet_id=wallet_id, conn=conn)
|
||||
assert wallet, "updated created wallet couldn't be retrieved"
|
||||
return wallet
|
||||
|
||||
|
||||
async def delete_wallet(
|
||||
|
@ -393,7 +396,7 @@ async def get_payments(
|
|||
clause.append("checking_id NOT LIKE 'internal_%'")
|
||||
|
||||
if not filters:
|
||||
filters = Filters()
|
||||
filters = Filters(limit=None, offset=None)
|
||||
|
||||
rows = await (conn or db).fetchall(
|
||||
f"""
|
||||
|
@ -712,15 +715,19 @@ async def update_admin_settings(data: EditableSettings):
|
|||
await db.execute("UPDATE settings SET editable_settings = ?", (json.dumps(data),))
|
||||
|
||||
|
||||
async def update_super_user(super_user: str):
|
||||
async def update_super_user(super_user: str) -> SuperSettings:
|
||||
await db.execute("UPDATE settings SET super_user = ?", (super_user,))
|
||||
return await get_super_settings()
|
||||
settings = await get_super_settings()
|
||||
assert settings, "updated super_user settings could not be retrieved"
|
||||
return settings
|
||||
|
||||
|
||||
async def create_admin_settings(super_user: str, new_settings: dict):
|
||||
sql = "INSERT INTO settings (super_user, editable_settings) VALUES (?, ?)"
|
||||
await db.execute(sql, (super_user, json.dumps(new_settings)))
|
||||
return await get_super_settings()
|
||||
settings = await get_super_settings()
|
||||
assert settings, "created admin settings could not be retrieved"
|
||||
return settings
|
||||
|
||||
|
||||
# db versions
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
from io import BytesIO
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, TypedDict
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import httpx
|
||||
|
@ -17,6 +17,7 @@ from lnbits.helpers import url_for
|
|||
from lnbits.settings import (
|
||||
FAKE_WALLET,
|
||||
EditableSettings,
|
||||
SuperSettings,
|
||||
get_wallet_class,
|
||||
readonly_variables,
|
||||
send_admin_user_to_saas,
|
||||
|
@ -43,11 +44,6 @@ from .crud import (
|
|||
)
|
||||
from .models import Payment
|
||||
|
||||
try:
|
||||
from typing import TypedDict
|
||||
except ImportError: # pragma: nocover
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class PaymentFailure(Exception):
|
||||
pass
|
||||
|
@ -188,7 +184,7 @@ async def pay_invoice(
|
|||
|
||||
# do the balance check
|
||||
wallet = await get_wallet(wallet_id, conn=conn)
|
||||
assert wallet
|
||||
assert wallet, "Wallet for balancecheck could not be fetched"
|
||||
if wallet.balance_msat < 0:
|
||||
logger.debug("balance is too low, deleting temporary payment")
|
||||
if not internal_checking_id and wallet.balance_msat > -fee_reserve_msat:
|
||||
|
@ -336,19 +332,19 @@ async def perform_lnurlauth(
|
|||
|
||||
return b
|
||||
|
||||
def encode_strict_der(r_int, s_int, order):
|
||||
def encode_strict_der(r: int, s: int, order: int):
|
||||
# if s > order/2 verification will fail sometimes
|
||||
# so we must fix it here (see https://github.com/indutny/elliptic/blob/e71b2d9359c5fe9437fbf46f1f05096de447de57/lib/elliptic/ec/index.js#L146-L147)
|
||||
if s_int > order // 2:
|
||||
s_int = order - s_int
|
||||
if s > order // 2:
|
||||
s = order - s
|
||||
|
||||
# now we do the strict DER encoding copied from
|
||||
# https://github.com/KiriKiri/bip66 (without any checks)
|
||||
r = int_to_bytes_suitable_der(r_int)
|
||||
s = int_to_bytes_suitable_der(s_int)
|
||||
r_temp = int_to_bytes_suitable_der(r)
|
||||
s_temp = int_to_bytes_suitable_der(s)
|
||||
|
||||
r_len = len(r)
|
||||
s_len = len(s)
|
||||
r_len = len(r_temp)
|
||||
s_len = len(s_temp)
|
||||
sign_len = 6 + r_len + s_len
|
||||
|
||||
signature = BytesIO()
|
||||
|
@ -356,16 +352,17 @@ async def perform_lnurlauth(
|
|||
signature.write((sign_len - 2).to_bytes(1, "big", signed=False))
|
||||
signature.write(0x02.to_bytes(1, "big", signed=False))
|
||||
signature.write(r_len.to_bytes(1, "big", signed=False))
|
||||
signature.write(r)
|
||||
signature.write(r_temp)
|
||||
signature.write(0x02.to_bytes(1, "big", signed=False))
|
||||
signature.write(s_len.to_bytes(1, "big", signed=False))
|
||||
signature.write(s)
|
||||
signature.write(s_temp)
|
||||
|
||||
return signature.getvalue()
|
||||
|
||||
sig = key.sign_digest_deterministic(k1, sigencode=encode_strict_der)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
assert key.verifying_key, "LNURLauth verifying_key does not exist"
|
||||
r = await client.get(
|
||||
callback,
|
||||
params={
|
||||
|
@ -469,7 +466,7 @@ def update_cached_settings(sets_dict: dict):
|
|||
setattr(settings, "super_user", sets_dict["super_user"])
|
||||
|
||||
|
||||
async def init_admin_settings(super_user: str = None):
|
||||
async def init_admin_settings(super_user: Optional[str] = None) -> SuperSettings:
|
||||
account = None
|
||||
if super_user:
|
||||
account = await get_account(super_user)
|
||||
|
|
|
@ -411,8 +411,7 @@ async def subscribe_wallet_invoices(request: Request, wallet: Wallet):
|
|||
typ, data = await send_queue.get()
|
||||
if data:
|
||||
jdata = json.dumps(dict(data.dict(), pending=False))
|
||||
|
||||
yield dict(data=jdata, event=typ)
|
||||
yield dict(data=jdata, event=typ)
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"removing listener for wallet {uid}")
|
||||
api_invoice_listeners.pop(uid)
|
||||
|
@ -431,11 +430,12 @@ async def api_payments_sse(
|
|||
)
|
||||
|
||||
|
||||
# TODO: refactor this route into a public and admin one
|
||||
@core_app.get("/api/v1/payments/{payment_hash}")
|
||||
async def api_payment(payment_hash, X_Api_Key: Optional[str] = Header(None)):
|
||||
# We use X_Api_Key here because we want this call to work with and without keys
|
||||
# If a valid key is given, we also return the field "details", otherwise not
|
||||
wallet = await get_wallet_for_key(X_Api_Key) if type(X_Api_Key) == str else None
|
||||
wallet = await get_wallet_for_key(X_Api_Key) if type(X_Api_Key) == str else None # type: ignore
|
||||
|
||||
# we have to specify the wallet id here, because postgres and sqlite return internal payments in different order
|
||||
# and get_standalone_payment otherwise just fetches the first one, causing unpredictable results
|
||||
|
@ -505,6 +505,7 @@ async def api_lnurlscan(code: str, wallet: WalletTypeInfo = Depends(get_key_type
|
|||
params.update(callback=url) # with k1 already in it
|
||||
|
||||
lnurlauth_key = wallet.wallet.lnurlauth_key(domain)
|
||||
assert lnurlauth_key.verifying_key
|
||||
params.update(pubkey=lnurlauth_key.verifying_key.to_string("compressed").hex())
|
||||
else:
|
||||
async with httpx.AsyncClient() as client:
|
||||
|
@ -693,7 +694,7 @@ async def api_auditor():
|
|||
if not error_message:
|
||||
delta = node_balance - total_balance
|
||||
else:
|
||||
node_balance, delta = None, None
|
||||
node_balance, delta = 0, 0
|
||||
|
||||
return {
|
||||
"node_balance_msats": int(node_balance),
|
||||
|
@ -745,6 +746,7 @@ async def api_install_extension(
|
|||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND, detail="Release not found"
|
||||
)
|
||||
|
||||
ext_info = InstallableExtension(
|
||||
id=data.ext_id, name=data.ext_id, installed_release=release, icon=release.icon
|
||||
)
|
||||
|
@ -824,8 +826,10 @@ async def api_uninstall_extension(ext_id: str, user: User = Depends(check_admin)
|
|||
)
|
||||
|
||||
|
||||
@core_app.get("/api/v1/extension/{ext_id}/releases")
|
||||
async def get_extension_releases(ext_id: str, user: User = Depends(check_admin)):
|
||||
@core_app.get(
|
||||
"/api/v1/extension/{ext_id}/releases", dependencies=[Depends(check_admin)]
|
||||
)
|
||||
async def get_extension_releases(ext_id: str):
|
||||
try:
|
||||
extension_releases: List[
|
||||
ExtensionRelease
|
||||
|
|
|
@ -40,19 +40,18 @@ async def api_public_payment_longpolling(payment_hash):
|
|||
|
||||
response = None
|
||||
|
||||
async def payment_info_receiver(cancel_scope):
|
||||
async for payment in payment_queue.get():
|
||||
async def payment_info_receiver():
|
||||
for payment in await payment_queue.get():
|
||||
if payment.payment_hash == payment_hash:
|
||||
nonlocal response
|
||||
response = {"status": "paid"}
|
||||
cancel_scope.cancel()
|
||||
|
||||
async def timeouter(cancel_scope):
|
||||
await asyncio.sleep(45)
|
||||
cancel_scope.cancel()
|
||||
|
||||
asyncio.create_task(payment_info_receiver())
|
||||
asyncio.create_task(timeouter())
|
||||
cancel_scope = asyncio.create_task(payment_info_receiver())
|
||||
asyncio.create_task(timeouter(cancel_scope))
|
||||
|
||||
if response:
|
||||
return response
|
||||
|
|
20
lnbits/db.py
20
lnbits/db.py
|
@ -131,7 +131,7 @@ class Database(Compat):
|
|||
else:
|
||||
self.type = POSTGRES
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extensions import DECIMAL, new_type, register_type
|
||||
|
||||
def _parse_timestamp(value, _):
|
||||
if value is None:
|
||||
|
@ -141,15 +141,15 @@ class Database(Compat):
|
|||
f = "%Y-%m-%d %H:%M:%S"
|
||||
return time.mktime(datetime.datetime.strptime(value, f).timetuple())
|
||||
|
||||
psycopg2.extensions.register_type(
|
||||
psycopg2.extensions.new_type(
|
||||
psycopg2.extensions.DECIMAL.values,
|
||||
register_type(
|
||||
new_type(
|
||||
DECIMAL.values,
|
||||
"DEC2FLOAT",
|
||||
lambda value, curs: float(value) if value is not None else None,
|
||||
)
|
||||
)
|
||||
psycopg2.extensions.register_type(
|
||||
psycopg2.extensions.new_type(
|
||||
register_type(
|
||||
new_type(
|
||||
(1082, 1083, 1266),
|
||||
"DATE2INT",
|
||||
lambda value, curs: time.mktime(value.timetuple())
|
||||
|
@ -158,11 +158,7 @@ class Database(Compat):
|
|||
)
|
||||
)
|
||||
|
||||
psycopg2.extensions.register_type(
|
||||
psycopg2.extensions.new_type(
|
||||
(1184, 1114), "TIMESTAMP2INT", _parse_timestamp
|
||||
)
|
||||
)
|
||||
register_type(new_type((1184, 1114), "TIMESTAMP2INT", _parse_timestamp))
|
||||
else:
|
||||
if os.path.isdir(settings.lnbits_data_folder):
|
||||
self.path = os.path.join(
|
||||
|
@ -189,7 +185,7 @@ class Database(Compat):
|
|||
async def connect(self):
|
||||
await self.lock.acquire()
|
||||
try:
|
||||
async with self.engine.connect() as conn:
|
||||
async with self.engine.connect() as conn: # type: ignore
|
||||
async with conn.begin() as txn:
|
||||
wconn = Connection(conn, txn, self.type, self.name, self.schema)
|
||||
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
from http import HTTPStatus
|
||||
from typing import Optional, Type
|
||||
|
||||
from fastapi import Security, status
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi import HTTPException, Request, Security, status
|
||||
from fastapi.openapi.models import APIKey, APIKeyIn
|
||||
from fastapi.security.api_key import APIKeyHeader, APIKeyQuery
|
||||
from fastapi.security import APIKeyHeader, APIKeyQuery
|
||||
from fastapi.security.base import SecurityBase
|
||||
from pydantic import BaseModel
|
||||
from pydantic.types import UUID4
|
||||
from starlette.requests import Request
|
||||
|
||||
from lnbits.core.crud import get_user, get_wallet_for_key
|
||||
from lnbits.core.models import User, Wallet
|
||||
|
@ -17,9 +15,13 @@ from lnbits.requestvars import g
|
|||
from lnbits.settings import settings
|
||||
|
||||
|
||||
# TODO: fix type ignores
|
||||
class KeyChecker(SecurityBase):
|
||||
def __init__(
|
||||
self, scheme_name: str = None, auto_error: bool = True, api_key: str = None
|
||||
self,
|
||||
scheme_name: Optional[str] = None,
|
||||
auto_error: bool = True,
|
||||
api_key: Optional[str] = None,
|
||||
):
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
@ -27,13 +29,13 @@ class KeyChecker(SecurityBase):
|
|||
self._api_key = api_key
|
||||
if api_key:
|
||||
key = APIKey(
|
||||
**{"in": APIKeyIn.query},
|
||||
**{"in": APIKeyIn.query}, # type: ignore
|
||||
name="X-API-KEY",
|
||||
description="Wallet API Key - QUERY",
|
||||
)
|
||||
else:
|
||||
key = APIKey(
|
||||
**{"in": APIKeyIn.header},
|
||||
**{"in": APIKeyIn.header}, # type: ignore
|
||||
name="X-API-KEY",
|
||||
description="Wallet API Key - HEADER",
|
||||
)
|
||||
|
@ -73,7 +75,10 @@ class WalletInvoiceKeyChecker(KeyChecker):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, scheme_name: str = None, auto_error: bool = True, api_key: str = None
|
||||
self,
|
||||
scheme_name: Optional[str] = None,
|
||||
auto_error: bool = True,
|
||||
api_key: Optional[str] = None,
|
||||
):
|
||||
super().__init__(scheme_name, auto_error, api_key)
|
||||
self._key_type = "invoice"
|
||||
|
@ -89,7 +94,10 @@ class WalletAdminKeyChecker(KeyChecker):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, scheme_name: str = None, auto_error: bool = True, api_key: str = None
|
||||
self,
|
||||
scheme_name: Optional[str] = None,
|
||||
auto_error: bool = True,
|
||||
api_key: Optional[str] = None,
|
||||
):
|
||||
super().__init__(scheme_name, auto_error, api_key)
|
||||
self._key_type = "admin"
|
||||
|
|
|
@ -3,20 +3,145 @@ import json
|
|||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import urllib.request
|
||||
import zipfile
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from typing import Any, List, NamedTuple, Optional, Tuple
|
||||
from urllib import request
|
||||
|
||||
import httpx
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from lnbits.settings import settings
|
||||
|
||||
|
||||
class ExplicitRelease(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
version: str
|
||||
archive: str
|
||||
hash: str
|
||||
dependencies: List[str] = []
|
||||
icon: Optional[str]
|
||||
short_description: Optional[str]
|
||||
html_url: Optional[str]
|
||||
details: Optional[str]
|
||||
info_notification: Optional[str]
|
||||
critical_notification: Optional[str]
|
||||
|
||||
|
||||
class GitHubRelease(BaseModel):
|
||||
id: str
|
||||
organisation: str
|
||||
repository: str
|
||||
|
||||
|
||||
class Manifest(BaseModel):
|
||||
featured: List[str] = []
|
||||
extensions: List["ExplicitRelease"] = []
|
||||
repos: List["GitHubRelease"] = []
|
||||
|
||||
|
||||
class GitHubRepoRelease(BaseModel):
|
||||
name: str
|
||||
tag_name: str
|
||||
zipball_url: str
|
||||
html_url: str
|
||||
|
||||
|
||||
class GitHubRepo(BaseModel):
|
||||
stargazers_count: str
|
||||
html_url: str
|
||||
default_branch: str
|
||||
|
||||
|
||||
class ExtensionConfig(BaseModel):
|
||||
name: str
|
||||
short_description: str
|
||||
tile: str = ""
|
||||
|
||||
|
||||
def download_url(url, save_path):
|
||||
with request.urlopen(url) as dl_file:
|
||||
with open(save_path, "wb") as out_file:
|
||||
out_file.write(dl_file.read())
|
||||
|
||||
|
||||
def file_hash(filename):
|
||||
h = hashlib.sha256()
|
||||
b = bytearray(128 * 1024)
|
||||
mv = memoryview(b)
|
||||
with open(filename, "rb", buffering=0) as f:
|
||||
while n := f.readinto(mv):
|
||||
h.update(mv[:n])
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
async def fetch_github_repo_info(
|
||||
org: str, repository: str
|
||||
) -> Tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]:
|
||||
repo_url = f"https://api.github.com/repos/{org}/{repository}"
|
||||
error_msg = "Cannot fetch extension repo"
|
||||
repo = await gihub_api_get(repo_url, error_msg)
|
||||
github_repo = GitHubRepo.parse_obj(repo)
|
||||
|
||||
lates_release_url = (
|
||||
f"https://api.github.com/repos/{org}/{repository}/releases/latest"
|
||||
)
|
||||
error_msg = "Cannot fetch extension releases"
|
||||
latest_release: Any = await gihub_api_get(lates_release_url, error_msg)
|
||||
|
||||
config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json"
|
||||
error_msg = "Cannot fetch config for extension"
|
||||
config = await gihub_api_get(config_url, error_msg)
|
||||
|
||||
return (
|
||||
github_repo,
|
||||
GitHubRepoRelease.parse_obj(latest_release),
|
||||
ExtensionConfig.parse_obj(config),
|
||||
)
|
||||
|
||||
|
||||
async def fetch_manifest(url) -> Manifest:
|
||||
error_msg = "Cannot fetch extensions manifest"
|
||||
manifest = await gihub_api_get(url, error_msg)
|
||||
return Manifest.parse_obj(manifest)
|
||||
|
||||
|
||||
async def fetch_github_releases(org: str, repo: str) -> List[GitHubRepoRelease]:
|
||||
releases_url = f"https://api.github.com/repos/{org}/{repo}/releases"
|
||||
error_msg = "Cannot fetch extension releases"
|
||||
releases = await gihub_api_get(releases_url, error_msg)
|
||||
return [GitHubRepoRelease.parse_obj(r) for r in releases]
|
||||
|
||||
|
||||
async def gihub_api_get(url: str, error_msg: Optional[str]) -> Any:
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = (
|
||||
{"Authorization": "Bearer " + settings.lnbits_ext_github_token}
|
||||
if settings.lnbits_ext_github_token
|
||||
else None
|
||||
)
|
||||
resp = await client.get(
|
||||
url,
|
||||
headers=headers,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.warning(f"{error_msg} ({url}): {resp.text}")
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def icon_to_github_url(source_repo: str, path: Optional[str]) -> str:
|
||||
if not path:
|
||||
return ""
|
||||
_, _, *rest = path.split("/")
|
||||
tail = "/".join(rest)
|
||||
return f"https://github.com/{source_repo}/raw/main/{tail}"
|
||||
|
||||
|
||||
class Extension(NamedTuple):
|
||||
code: str
|
||||
is_valid: bool
|
||||
|
@ -97,12 +222,12 @@ class ExtensionRelease(BaseModel):
|
|||
version: str
|
||||
archive: str
|
||||
source_repo: str
|
||||
is_github_release = False
|
||||
hash: Optional[str]
|
||||
html_url: Optional[str]
|
||||
description: Optional[str]
|
||||
is_github_release: bool = False
|
||||
hash: Optional[str] = None
|
||||
html_url: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
details_html: Optional[str] = None
|
||||
icon: Optional[str]
|
||||
icon: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_github_release(
|
||||
|
@ -132,52 +257,6 @@ class ExtensionRelease(BaseModel):
|
|||
return []
|
||||
|
||||
|
||||
class ExplicitRelease(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
version: str
|
||||
archive: str
|
||||
hash: str
|
||||
dependencies: List[str] = []
|
||||
icon: Optional[str]
|
||||
short_description: Optional[str]
|
||||
html_url: Optional[str]
|
||||
details: Optional[str]
|
||||
info_notification: Optional[str]
|
||||
critical_notification: Optional[str]
|
||||
|
||||
|
||||
class GitHubRelease(BaseModel):
|
||||
id: str
|
||||
organisation: str
|
||||
repository: str
|
||||
|
||||
|
||||
class Manifest(BaseModel):
|
||||
featured: List[str] = []
|
||||
extensions: List["ExplicitRelease"] = []
|
||||
repos: List["GitHubRelease"] = []
|
||||
|
||||
|
||||
class GitHubRepoRelease(BaseModel):
|
||||
name: str
|
||||
tag_name: str
|
||||
zipball_url: str
|
||||
html_url: str
|
||||
|
||||
|
||||
class GitHubRepo(BaseModel):
|
||||
stargazers_count: str
|
||||
html_url: str
|
||||
default_branch: str
|
||||
|
||||
|
||||
class ExtensionConfig(BaseModel):
|
||||
name: str
|
||||
short_description: str
|
||||
tile: str = ""
|
||||
|
||||
|
||||
class InstallableExtension(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
@ -187,8 +266,9 @@ class InstallableExtension(BaseModel):
|
|||
is_admin_only: bool = False
|
||||
stars: int = 0
|
||||
featured = False
|
||||
latest_release: Optional[ExtensionRelease]
|
||||
installed_release: Optional[ExtensionRelease]
|
||||
latest_release: Optional[ExtensionRelease] = None
|
||||
installed_release: Optional[ExtensionRelease] = None
|
||||
archive: Optional[str] = None
|
||||
|
||||
@property
|
||||
def hash(self) -> str:
|
||||
|
@ -234,6 +314,7 @@ class InstallableExtension(BaseModel):
|
|||
if ext_zip_file.is_file():
|
||||
os.remove(ext_zip_file)
|
||||
try:
|
||||
assert self.installed_release, "installed_release is none."
|
||||
download_url(self.installed_release.archive, ext_zip_file)
|
||||
except Exception as ex:
|
||||
logger.warning(ex)
|
||||
|
@ -334,8 +415,7 @@ class InstallableExtension(BaseModel):
|
|||
id=github_release.id,
|
||||
name=config.name,
|
||||
short_description=config.short_description,
|
||||
version="0",
|
||||
stars=repo.stargazers_count,
|
||||
stars=int(repo.stargazers_count),
|
||||
icon=icon_to_github_url(
|
||||
f"{github_release.organisation}/{github_release.repository}",
|
||||
config.tile,
|
||||
|
@ -354,7 +434,6 @@ class InstallableExtension(BaseModel):
|
|||
id=e.id,
|
||||
name=e.name,
|
||||
archive=e.archive,
|
||||
hash=e.hash,
|
||||
short_description=e.short_description,
|
||||
icon=e.icon,
|
||||
dependencies=e.dependencies,
|
||||
|
@ -453,82 +532,3 @@ def get_valid_extensions() -> List[Extension]:
|
|||
return [
|
||||
extension for extension in ExtensionManager().extensions if extension.is_valid
|
||||
]
|
||||
|
||||
|
||||
def download_url(url, save_path):
|
||||
with urllib.request.urlopen(url) as dl_file:
|
||||
with open(save_path, "wb") as out_file:
|
||||
out_file.write(dl_file.read())
|
||||
|
||||
|
||||
def file_hash(filename):
|
||||
h = hashlib.sha256()
|
||||
b = bytearray(128 * 1024)
|
||||
mv = memoryview(b)
|
||||
with open(filename, "rb", buffering=0) as f:
|
||||
while n := f.readinto(mv):
|
||||
h.update(mv[:n])
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def icon_to_github_url(source_repo: str, path: Optional[str]) -> str:
|
||||
if not path:
|
||||
return ""
|
||||
_, _, *rest = path.split("/")
|
||||
tail = "/".join(rest)
|
||||
return f"https://github.com/{source_repo}/raw/main/{tail}"
|
||||
|
||||
|
||||
async def fetch_github_repo_info(
|
||||
org: str, repository: str
|
||||
) -> Tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]:
|
||||
repo_url = f"https://api.github.com/repos/{org}/{repository}"
|
||||
error_msg = "Cannot fetch extension repo"
|
||||
repo = await gihub_api_get(repo_url, error_msg)
|
||||
github_repo = GitHubRepo.parse_obj(repo)
|
||||
|
||||
lates_release_url = (
|
||||
f"https://api.github.com/repos/{org}/{repository}/releases/latest"
|
||||
)
|
||||
error_msg = "Cannot fetch extension releases"
|
||||
latest_release: Any = await gihub_api_get(lates_release_url, error_msg)
|
||||
|
||||
config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json"
|
||||
error_msg = "Cannot fetch config for extension"
|
||||
config = await gihub_api_get(config_url, error_msg)
|
||||
|
||||
return (
|
||||
github_repo,
|
||||
GitHubRepoRelease.parse_obj(latest_release),
|
||||
ExtensionConfig.parse_obj(config),
|
||||
)
|
||||
|
||||
|
||||
async def fetch_manifest(url) -> Manifest:
|
||||
error_msg = "Cannot fetch extensions manifest"
|
||||
manifest = await gihub_api_get(url, error_msg)
|
||||
return Manifest.parse_obj(manifest)
|
||||
|
||||
|
||||
async def fetch_github_releases(org: str, repo: str) -> List[GitHubRepoRelease]:
|
||||
releases_url = f"https://api.github.com/repos/{org}/{repo}/releases"
|
||||
error_msg = "Cannot fetch extension releases"
|
||||
releases = await gihub_api_get(releases_url, error_msg)
|
||||
return [GitHubRepoRelease.parse_obj(r) for r in releases]
|
||||
|
||||
|
||||
async def gihub_api_get(url: str, error_msg: Optional[str]) -> Any:
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = (
|
||||
{"Authorization": f"Bearer {settings.lnbits_ext_github_token}"}
|
||||
if settings.lnbits_ext_github_token
|
||||
else None
|
||||
)
|
||||
resp = await client.get(
|
||||
url,
|
||||
headers=headers,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.warning(f"{error_msg} ({url}): {resp.text}")
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
|
|
@ -1,25 +1,18 @@
|
|||
# Borrowed from the excellent accent-starlette
|
||||
# https://github.com/accent-starlette/starlette-core/blob/master/starlette_core/templating.py
|
||||
|
||||
import typing
|
||||
|
||||
from starlette import templating
|
||||
from jinja2 import BaseLoader, Environment, pass_context
|
||||
from starlette.datastructures import QueryParams
|
||||
from starlette.requests import Request
|
||||
|
||||
try:
|
||||
import jinja2
|
||||
except ImportError: # pragma: nocover
|
||||
jinja2 = None # type: ignore
|
||||
from starlette.templating import Jinja2Templates as SuperJinja2Templates
|
||||
|
||||
|
||||
class Jinja2Templates(templating.Jinja2Templates):
|
||||
def __init__(self, loader: jinja2.BaseLoader) -> None: # pylint: disable=W0231
|
||||
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
|
||||
class Jinja2Templates(SuperJinja2Templates):
|
||||
def __init__(self, loader: BaseLoader) -> None:
|
||||
super().__init__("")
|
||||
self.env = self.get_environment(loader)
|
||||
|
||||
def get_environment(self, loader: "jinja2.BaseLoader") -> "jinja2.Environment":
|
||||
@jinja2.pass_context
|
||||
def get_environment(self, loader: BaseLoader) -> Environment:
|
||||
@pass_context
|
||||
def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
|
||||
request: Request = context["request"]
|
||||
return request.app.url_path_for(name, **path_params)
|
||||
|
@ -29,7 +22,7 @@ class Jinja2Templates(templating.Jinja2Templates):
|
|||
values.update(new)
|
||||
return QueryParams(**values)
|
||||
|
||||
env = jinja2.Environment(loader=loader, autoescape=True)
|
||||
env = Environment(loader=loader, autoescape=True)
|
||||
env.globals["url_for"] = url_for
|
||||
env.globals["url_params_update"] = url_params_update
|
||||
return env
|
||||
|
|
|
@ -26,6 +26,7 @@ class InstalledExtensionMiddleware:
|
|||
else:
|
||||
_, path_name = path_elements
|
||||
path_type = None
|
||||
rest = []
|
||||
|
||||
# block path for all users if the extension is disabled
|
||||
if path_name in settings.lnbits_deactivated_extensions:
|
||||
|
@ -88,7 +89,7 @@ class ExtensionsRedirectMiddleware:
|
|||
if "from_path" not in redirect:
|
||||
return False
|
||||
header_filters = (
|
||||
redirect["header_filters"] if "header_filters" in redirect else []
|
||||
redirect["header_filters"] if "header_filters" in redirect else {}
|
||||
)
|
||||
return self._has_common_path(redirect["from_path"], path) and self._has_headers(
|
||||
header_filters, req_headers
|
||||
|
|
|
@ -24,6 +24,7 @@ def list_parse_fallback(v):
|
|||
|
||||
|
||||
class LNbitsSettings(BaseSettings):
|
||||
@classmethod
|
||||
def validate(cls, val):
|
||||
if type(val) == str:
|
||||
val = val.split(",") if val else []
|
||||
|
@ -103,6 +104,8 @@ class FakeWalletFundingSource(LNbitsSettings):
|
|||
class LNbitsFundingSource(LNbitsSettings):
|
||||
lnbits_endpoint: str = Field(default="https://legend.lnbits.com")
|
||||
lnbits_key: Optional[str] = Field(default=None)
|
||||
lnbits_admin_key: Optional[str] = Field(default=None)
|
||||
lnbits_invoice_key: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class ClicheFundingSource(LNbitsSettings):
|
||||
|
@ -145,11 +148,14 @@ class LnPayFundingSource(LNbitsSettings):
|
|||
lnpay_api_endpoint: Optional[str] = Field(default=None)
|
||||
lnpay_api_key: Optional[str] = Field(default=None)
|
||||
lnpay_wallet_key: Optional[str] = Field(default=None)
|
||||
lnpay_admin_key: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class OpenNodeFundingSource(LNbitsSettings):
|
||||
opennode_api_endpoint: Optional[str] = Field(default=None)
|
||||
opennode_key: Optional[str] = Field(default=None)
|
||||
opennode_admin_key: Optional[str] = Field(default=None)
|
||||
opennode_invoice_key: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class SparkFundingSource(LNbitsSettings):
|
||||
|
@ -208,8 +214,9 @@ class EditableSettings(
|
|||
"lnbits_admin_extensions",
|
||||
pre=True,
|
||||
)
|
||||
@classmethod
|
||||
def validate_editable_settings(cls, val):
|
||||
return super().validate(cls, val)
|
||||
return super().validate(val)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict):
|
||||
|
@ -281,8 +288,9 @@ class ReadOnlySettings(
|
|||
"lnbits_allowed_funding_sources",
|
||||
pre=True,
|
||||
)
|
||||
@classmethod
|
||||
def validate_readonly_settings(cls, val):
|
||||
return super().validate(cls, val)
|
||||
return super().validate(val)
|
||||
|
||||
@classmethod
|
||||
def readonly_fields(cls):
|
||||
|
|
|
@ -3,7 +3,7 @@ import time
|
|||
import traceback
|
||||
import uuid
|
||||
from http import HTTPStatus
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi.exceptions import HTTPException
|
||||
from loguru import logger
|
||||
|
@ -42,7 +42,7 @@ class SseListenersDict(dict):
|
|||
A dict of sse listeners.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str = None):
|
||||
def __init__(self, name: Optional[str] = None):
|
||||
self.name = name or f"sse_listener_{str(uuid.uuid4())[:8]}"
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
|
@ -65,7 +65,7 @@ class SseListenersDict(dict):
|
|||
invoice_listeners: Dict[str, asyncio.Queue] = SseListenersDict("invoice_listeners")
|
||||
|
||||
|
||||
def register_invoice_listener(send_chan: asyncio.Queue, name: str = None):
|
||||
def register_invoice_listener(send_chan: asyncio.Queue, name: Optional[str] = None):
|
||||
"""
|
||||
A method intended for extensions (and core/tasks.py) to call when they want to be notified about
|
||||
new invoice payments incoming. Will emit all incoming payments.
|
||||
|
@ -164,7 +164,7 @@ async def check_pending_payments():
|
|||
async def perform_balance_checks():
|
||||
while True:
|
||||
for bc in await get_balance_checks():
|
||||
redeem_lnurl_withdraw(bc.wallet, bc.url)
|
||||
await redeem_lnurl_withdraw(bc.wallet, bc.url)
|
||||
|
||||
await asyncio.sleep(60 * 60 * 6) # every 6 hours
|
||||
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
# flake8: noqa: F401
|
||||
|
||||
|
||||
from .cliche import ClicheWallet
|
||||
from .cln import CoreLightningWallet # legacy .env support
|
||||
from .cln import CoreLightningWallet
|
||||
from .cln import CoreLightningWallet as CLightningWallet
|
||||
from .eclair import EclairWallet
|
||||
from .fake import FakeWallet
|
||||
|
|
|
@ -22,6 +22,8 @@ class ClicheWallet(Wallet):
|
|||
|
||||
def __init__(self):
|
||||
self.endpoint = settings.cliche_endpoint
|
||||
if not self.endpoint:
|
||||
raise Exception("cannot initialize cliche")
|
||||
|
||||
async def status(self) -> StatusResponse:
|
||||
try:
|
||||
|
@ -36,7 +38,7 @@ class ClicheWallet(Wallet):
|
|||
data = json.loads(r)
|
||||
except:
|
||||
return StatusResponse(
|
||||
f"Failed to connect to {self.endpoint}, got: '{r.text[:200]}...'", 0
|
||||
f"Failed to connect to {self.endpoint}, got: '{r[:200]}...'", 0
|
||||
)
|
||||
|
||||
return StatusResponse(None, data["result"]["wallets"][0]["balance"])
|
||||
|
@ -89,6 +91,13 @@ class ClicheWallet(Wallet):
|
|||
async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse:
|
||||
ws = create_connection(self.endpoint)
|
||||
ws.send(f"pay-invoice --invoice {bolt11}")
|
||||
checking_id, fee_msat, preimage, error_message, payment_ok = (
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
for _ in range(2):
|
||||
r = ws.recv()
|
||||
data = json.loads(r)
|
||||
|
@ -151,9 +160,9 @@ class ClicheWallet(Wallet):
|
|||
async def paid_invoices_stream(self) -> AsyncGenerator[str, None]:
|
||||
while True:
|
||||
try:
|
||||
ws = await create_connection(self.endpoint)
|
||||
ws = create_connection(self.endpoint)
|
||||
while True:
|
||||
r = await ws.recv()
|
||||
r = ws.recv()
|
||||
data = json.loads(r)
|
||||
print(data)
|
||||
try:
|
||||
|
|
|
@ -7,10 +7,7 @@ from typing import AsyncGenerator, Dict, Optional
|
|||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
# TODO: https://github.com/lnbits/lnbits/issues/764
|
||||
# mypy https://github.com/aaugustin/websockets/issues/940
|
||||
from websockets import connect # type: ignore
|
||||
from websockets.client import connect
|
||||
|
||||
from lnbits.settings import settings
|
||||
|
||||
|
@ -34,11 +31,13 @@ class UnknownError(Exception):
|
|||
class EclairWallet(Wallet):
|
||||
def __init__(self):
|
||||
url = settings.eclair_url
|
||||
self.url = url[:-1] if url.endswith("/") else url
|
||||
passw = settings.eclair_pass
|
||||
if not url or not passw:
|
||||
raise Exception("cannot initialize eclair")
|
||||
|
||||
self.url = url[:-1] if url.endswith("/") else url
|
||||
self.ws_url = f"ws://{urllib.parse.urlsplit(self.url).netloc}/ws"
|
||||
|
||||
passw = settings.eclair_pass
|
||||
encodedAuth = base64.b64encode(f":{passw}".encode())
|
||||
auth = str(encodedAuth, "utf-8")
|
||||
self.auth = {"Authorization": f"Basic {auth}"}
|
||||
|
@ -71,7 +70,11 @@ class EclairWallet(Wallet):
|
|||
**kwargs,
|
||||
) -> InvoiceResponse:
|
||||
|
||||
data: Dict = {"amountMsat": amount * 1000}
|
||||
data: Dict = {
|
||||
"amountMsat": amount * 1000,
|
||||
"description_hash": b"",
|
||||
"description": memo,
|
||||
}
|
||||
if kwargs.get("expiry"):
|
||||
data["expireIn"] = kwargs["expiry"]
|
||||
|
||||
|
@ -79,8 +82,6 @@ class EclairWallet(Wallet):
|
|||
data["descriptionHash"] = description_hash.hex()
|
||||
elif unhashed_description:
|
||||
data["descriptionHash"] = hashlib.sha256(unhashed_description).hexdigest()
|
||||
else:
|
||||
data["description"] = memo or ""
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.post(
|
||||
|
@ -149,6 +150,7 @@ class EclairWallet(Wallet):
|
|||
}
|
||||
|
||||
data = r.json()[-1]
|
||||
fee_msat = 0
|
||||
if data["status"]["type"] == "sent":
|
||||
fee_msat = -data["status"]["feesPaid"]
|
||||
preimage = data["status"]["paymentPreimage"]
|
||||
|
@ -223,10 +225,10 @@ class EclairWallet(Wallet):
|
|||
) as ws:
|
||||
while True:
|
||||
message = await ws.recv()
|
||||
message = json.loads(message)
|
||||
message_json = json.loads(message)
|
||||
|
||||
if message and message["type"] == "payment-received":
|
||||
yield message["paymentHash"]
|
||||
if message_json and message_json["type"] == "payment-received":
|
||||
yield message_json["paymentHash"]
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
|
|
|
@ -48,16 +48,15 @@ class FakeWallet(Wallet):
|
|||
"amount": amount,
|
||||
"currency": "bc",
|
||||
"privkey": self.privkey,
|
||||
"memo": None,
|
||||
"description_hash": None,
|
||||
"memo": memo,
|
||||
"description_hash": b"",
|
||||
"description": "",
|
||||
"fallback": None,
|
||||
"expires": None,
|
||||
"expires": kwargs.get("expiry"),
|
||||
"timestamp": datetime.now().timestamp(),
|
||||
"route": None,
|
||||
"tags_set": [],
|
||||
}
|
||||
data["expires"] = kwargs.get("expiry")
|
||||
data["amount"] = amount * 1000
|
||||
data["timestamp"] = datetime.now().timestamp()
|
||||
if description_hash:
|
||||
data["tags_set"] = ["h"]
|
||||
data["description_hash"] = description_hash
|
||||
|
@ -69,7 +68,7 @@ class FakeWallet(Wallet):
|
|||
data["memo"] = memo
|
||||
data["description"] = memo
|
||||
randomHash = (
|
||||
data["privkey"][:6]
|
||||
self.privkey[:6]
|
||||
+ hashlib.sha256(str(random.getrandbits(256)).encode()).hexdigest()[6:]
|
||||
)
|
||||
data["paymenthash"] = randomHash
|
||||
|
@ -78,12 +77,10 @@ class FakeWallet(Wallet):
|
|||
|
||||
return InvoiceResponse(True, checking_id, payment_request)
|
||||
|
||||
async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse:
|
||||
async def pay_invoice(self, bolt11: str, _: int) -> PaymentResponse:
|
||||
invoice = decode(bolt11)
|
||||
if (
|
||||
hasattr(invoice, "checking_id")
|
||||
and invoice.checking_id[:6] == self.privkey[:6] # type: ignore
|
||||
):
|
||||
|
||||
if invoice.payment_hash[:6] == self.privkey[:6]:
|
||||
await self.queue.put(invoice)
|
||||
return PaymentResponse(True, invoice.payment_hash, 0)
|
||||
else:
|
||||
|
@ -91,10 +88,10 @@ class FakeWallet(Wallet):
|
|||
ok=False, error_message="Only internal invoices can be used!"
|
||||
)
|
||||
|
||||
async def get_invoice_status(self, checking_id: str) -> PaymentStatus:
|
||||
async def get_invoice_status(self, _: str) -> PaymentStatus:
|
||||
return PaymentStatus(None)
|
||||
|
||||
async def get_payment_status(self, checking_id: str) -> PaymentStatus:
|
||||
async def get_payment_status(self, _: str) -> PaymentStatus:
|
||||
return PaymentStatus(None)
|
||||
|
||||
async def paid_invoices_stream(self) -> AsyncGenerator[str, None]:
|
||||
|
|
|
@ -21,12 +21,13 @@ class LNbitsWallet(Wallet):
|
|||
|
||||
def __init__(self):
|
||||
self.endpoint = settings.lnbits_endpoint
|
||||
|
||||
key = (
|
||||
settings.lnbits_key
|
||||
or settings.lnbits_admin_key
|
||||
or settings.lnbits_invoice_key
|
||||
)
|
||||
if not self.endpoint or not key:
|
||||
raise Exception("cannot initialize lnbits wallet")
|
||||
self.key = {"X-Api-Key": key}
|
||||
|
||||
async def status(self) -> StatusResponse:
|
||||
|
@ -60,7 +61,7 @@ class LNbitsWallet(Wallet):
|
|||
unhashed_description: Optional[bytes] = None,
|
||||
**kwargs,
|
||||
) -> InvoiceResponse:
|
||||
data: Dict = {"out": False, "amount": amount}
|
||||
data: Dict = {"out": False, "amount": amount, "memo": memo or ""}
|
||||
if kwargs.get("expiry"):
|
||||
data["expiry"] = kwargs["expiry"]
|
||||
if description_hash:
|
||||
|
@ -68,8 +69,6 @@ class LNbitsWallet(Wallet):
|
|||
if unhashed_description:
|
||||
data["unhashed_description"] = unhashed_description.hex()
|
||||
|
||||
data["memo"] = memo or ""
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.post(
|
||||
url=f"{self.endpoint}/api/v1/payments", headers=self.key, json=data
|
||||
|
|
|
@ -105,9 +105,6 @@ class LndWallet(Wallet):
|
|||
)
|
||||
|
||||
endpoint = settings.lnd_grpc_endpoint
|
||||
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
||||
self.port = int(settings.lnd_grpc_port)
|
||||
self.cert_path = settings.lnd_grpc_cert or settings.lnd_cert
|
||||
|
||||
macaroon = (
|
||||
settings.lnd_grpc_macaroon
|
||||
|
@ -122,8 +119,17 @@ class LndWallet(Wallet):
|
|||
macaroon = AESCipher(description="macaroon decryption").decrypt(
|
||||
encrypted_macaroon
|
||||
)
|
||||
self.macaroon = load_macaroon(macaroon)
|
||||
|
||||
cert_path = settings.lnd_grpc_cert or settings.lnd_cert
|
||||
if not endpoint or not macaroon or not cert_path or not settings.lnd_grpc_port:
|
||||
raise Exception("cannot initialize lndrest")
|
||||
|
||||
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
||||
self.port = int(settings.lnd_grpc_port)
|
||||
self.cert_path = settings.lnd_grpc_cert or settings.lnd_cert
|
||||
|
||||
self.macaroon = load_macaroon(macaroon)
|
||||
self.cert_path = cert_path
|
||||
cert = open(self.cert_path, "rb").read()
|
||||
creds = grpc.ssl_channel_credentials(cert)
|
||||
auth_creds = grpc.metadata_call_credentials(self.metadata_callback)
|
||||
|
@ -140,8 +146,6 @@ class LndWallet(Wallet):
|
|||
async def status(self) -> StatusResponse:
|
||||
try:
|
||||
resp = await self.rpc.ChannelBalance(ln.ChannelBalanceRequest())
|
||||
except RpcError as exc:
|
||||
return StatusResponse(str(exc._details), 0)
|
||||
except Exception as exc:
|
||||
return StatusResponse(str(exc), 0)
|
||||
|
||||
|
@ -155,20 +159,23 @@ class LndWallet(Wallet):
|
|||
unhashed_description: Optional[bytes] = None,
|
||||
**kwargs,
|
||||
) -> InvoiceResponse:
|
||||
params: Dict = {"value": amount, "private": True}
|
||||
data: Dict = {
|
||||
"description_hash": b"",
|
||||
"value": amount,
|
||||
"private": True,
|
||||
"memo": memo or "",
|
||||
}
|
||||
if kwargs.get("expiry"):
|
||||
params["expiry"] = kwargs["expiry"]
|
||||
data["expiry"] = kwargs["expiry"]
|
||||
if description_hash:
|
||||
params["description_hash"] = description_hash
|
||||
data["description_hash"] = description_hash
|
||||
elif unhashed_description:
|
||||
params["description_hash"] = hashlib.sha256(
|
||||
data["description_hash"] = hashlib.sha256(
|
||||
unhashed_description
|
||||
).digest() # as bytes directly
|
||||
else:
|
||||
params["memo"] = memo or ""
|
||||
|
||||
try:
|
||||
req = ln.Invoice(**params)
|
||||
req = ln.Invoice(**data)
|
||||
resp = await self.rpc.AddInvoice(req)
|
||||
except Exception as exc:
|
||||
error_message = str(exc)
|
||||
|
@ -188,8 +195,6 @@ class LndWallet(Wallet):
|
|||
)
|
||||
try:
|
||||
resp = await self.routerpc.SendPaymentV2(req).read()
|
||||
except RpcError as exc:
|
||||
return PaymentResponse(False, None, None, None, exc._details)
|
||||
except Exception as exc:
|
||||
return PaymentResponse(False, None, None, None, str(exc))
|
||||
|
||||
|
|
|
@ -24,11 +24,6 @@ class LndRestWallet(Wallet):
|
|||
|
||||
def __init__(self):
|
||||
endpoint = settings.lnd_rest_endpoint
|
||||
endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
||||
endpoint = (
|
||||
f"https://{endpoint}" if not endpoint.startswith("http") else endpoint
|
||||
)
|
||||
self.endpoint = endpoint
|
||||
|
||||
macaroon = (
|
||||
settings.lnd_rest_macaroon
|
||||
|
@ -43,6 +38,15 @@ class LndRestWallet(Wallet):
|
|||
macaroon = AESCipher(description="macaroon decryption").decrypt(
|
||||
encrypted_macaroon
|
||||
)
|
||||
|
||||
if not endpoint or not macaroon or not settings.lnd_rest_cert:
|
||||
raise Exception("cannot initialize lndrest")
|
||||
|
||||
endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
||||
endpoint = (
|
||||
f"https://{endpoint}" if not endpoint.startswith("http") else endpoint
|
||||
)
|
||||
self.endpoint = endpoint
|
||||
self.macaroon = load_macaroon(macaroon)
|
||||
|
||||
self.auth = {"Grpc-Metadata-macaroon": self.macaroon}
|
||||
|
@ -74,7 +78,7 @@ class LndRestWallet(Wallet):
|
|||
unhashed_description: Optional[bytes] = None,
|
||||
**kwargs,
|
||||
) -> InvoiceResponse:
|
||||
data: Dict = {"value": amount, "private": True}
|
||||
data: Dict = {"value": amount, "private": True, "memo": memo or ""}
|
||||
if kwargs.get("expiry"):
|
||||
data["expiry"] = kwargs["expiry"]
|
||||
if description_hash:
|
||||
|
@ -85,8 +89,6 @@ class LndRestWallet(Wallet):
|
|||
data["description_hash"] = base64.b64encode(
|
||||
hashlib.sha256(unhashed_description).digest()
|
||||
).decode("ascii")
|
||||
else:
|
||||
data["memo"] = memo or ""
|
||||
|
||||
async with httpx.AsyncClient(verify=self.cert) as client:
|
||||
r = await client.post(
|
||||
|
|
|
@ -5,7 +5,7 @@ from http import HTTPStatus
|
|||
from typing import AsyncGenerator, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from lnbits.settings import settings
|
||||
|
@ -24,8 +24,13 @@ class LNPayWallet(Wallet):
|
|||
|
||||
def __init__(self):
|
||||
endpoint = settings.lnpay_api_endpoint
|
||||
wallet_key = settings.lnpay_wallet_key or settings.lnpay_admin_key
|
||||
|
||||
if not endpoint or not wallet_key or not settings.lnpay_api_key:
|
||||
raise Exception("cannot initialize lnpay")
|
||||
|
||||
self.wallet_key = wallet_key
|
||||
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
||||
self.wallet_key = settings.lnpay_wallet_key or settings.lnpay_admin_key
|
||||
self.auth = {"X-Api-Key": settings.lnpay_api_key}
|
||||
|
||||
async def status(self) -> StatusResponse:
|
||||
|
@ -134,7 +139,9 @@ class LNPayWallet(Wallet):
|
|||
yield value
|
||||
|
||||
async def webhook_listener(self):
|
||||
text: str = await request.get_data()
|
||||
# TODO: request.get_data is undefined, was it something with Flask or quart?
|
||||
# probably issue introduced when refactoring?
|
||||
text: str = await request.get_data() # type: ignore
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except json.decoder.JSONDecodeError:
|
||||
|
|
|
@ -21,13 +21,14 @@ from .base import (
|
|||
class LnTipsWallet(Wallet):
|
||||
def __init__(self):
|
||||
endpoint = settings.lntips_api_endpoint
|
||||
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
||||
|
||||
key = (
|
||||
settings.lntips_api_key
|
||||
or settings.lntips_admin_key
|
||||
or settings.lntips_invoice_key
|
||||
)
|
||||
if not endpoint or not key:
|
||||
raise Exception("cannot initialize lntxbod")
|
||||
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
||||
self.auth = {"Authorization": f"Basic {key}"}
|
||||
|
||||
async def status(self) -> StatusResponse:
|
||||
|
@ -55,13 +56,11 @@ class LnTipsWallet(Wallet):
|
|||
unhashed_description: Optional[bytes] = None,
|
||||
**kwargs,
|
||||
) -> InvoiceResponse:
|
||||
data: Dict = {"amount": amount}
|
||||
data: Dict = {"amount": amount, "description_hash": "", "memo": memo or ""}
|
||||
if description_hash:
|
||||
data["description_hash"] = description_hash.hex()
|
||||
elif unhashed_description:
|
||||
data["description_hash"] = hashlib.sha256(unhashed_description).hexdigest()
|
||||
else:
|
||||
data["memo"] = memo or ""
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.post(
|
||||
|
|
|
@ -4,7 +4,7 @@ from http import HTTPStatus
|
|||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from lnbits.settings import settings
|
||||
|
@ -24,13 +24,15 @@ class OpenNodeWallet(Wallet):
|
|||
|
||||
def __init__(self):
|
||||
endpoint = settings.opennode_api_endpoint
|
||||
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
||||
|
||||
key = (
|
||||
settings.opennode_key
|
||||
or settings.opennode_admin_key
|
||||
or settings.opennode_invoice_key
|
||||
)
|
||||
if not endpoint or not key:
|
||||
raise Exception("cannot initialize opennode")
|
||||
|
||||
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
||||
self.auth = {"Authorization": key}
|
||||
|
||||
async def status(self) -> StatusResponse:
|
||||
|
@ -140,7 +142,9 @@ class OpenNodeWallet(Wallet):
|
|||
yield value
|
||||
|
||||
async def webhook_listener(self):
|
||||
data = await request.form
|
||||
# TODO: request.form is undefined, was it something with Flask or quart?
|
||||
# probably issue introduced when refactoring?
|
||||
data = await request.form # type: ignore
|
||||
if "status" not in data or data["status"] != "paid":
|
||||
raise HTTPException(status_code=HTTPStatus.NO_CONTENT)
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ class UnknownError(Exception):
|
|||
|
||||
class SparkWallet(Wallet):
|
||||
def __init__(self):
|
||||
assert settings.spark_url, "spark url does not exist"
|
||||
self.url = settings.spark_url.replace("/rpc", "")
|
||||
self.token = settings.spark_token
|
||||
|
||||
|
@ -46,6 +47,7 @@ class SparkWallet(Wallet):
|
|||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
assert self.token, "spark wallet token does not exist"
|
||||
r = await client.post(
|
||||
self.url + "/rpc",
|
||||
headers={"X-Access": self.token},
|
||||
|
@ -133,38 +135,49 @@ class SparkWallet(Wallet):
|
|||
bolt11=bolt11,
|
||||
maxfee=fee_limit_msat,
|
||||
)
|
||||
fee_msat = -int(r["msatoshi_sent"] - r["msatoshi"])
|
||||
preimage = r["payment_preimage"]
|
||||
return PaymentResponse(True, r["payment_hash"], fee_msat, preimage, None)
|
||||
|
||||
except (SparkError, UnknownError) as exc:
|
||||
listpays = await self.listpays(bolt11)
|
||||
if listpays:
|
||||
pays = listpays["pays"]
|
||||
if not listpays:
|
||||
return PaymentResponse(False, None, None, None, str(exc))
|
||||
|
||||
if len(pays) == 0:
|
||||
return PaymentResponse(False, None, None, None, str(exc))
|
||||
pays = listpays["pays"]
|
||||
|
||||
pay = pays[0]
|
||||
payment_hash = pay["payment_hash"]
|
||||
if len(pays) == 0:
|
||||
return PaymentResponse(False, None, None, None, str(exc))
|
||||
|
||||
if len(pays) > 1:
|
||||
raise SparkError(
|
||||
f"listpays({payment_hash}) returned an unexpected response: {listpays}"
|
||||
)
|
||||
pay = pays[0]
|
||||
payment_hash = pay["payment_hash"]
|
||||
|
||||
if pay["status"] == "failed":
|
||||
return PaymentResponse(False, None, None, None, str(exc))
|
||||
elif pay["status"] == "pending":
|
||||
return PaymentResponse(None, payment_hash, None, None, None)
|
||||
elif pay["status"] == "complete":
|
||||
r = pay
|
||||
r["payment_preimage"] = pay["preimage"]
|
||||
r["msatoshi"] = int(pay["amount_msat"][0:-4])
|
||||
r["msatoshi_sent"] = int(pay["amount_sent_msat"][0:-4])
|
||||
# this may result in an error if it was paid previously
|
||||
# our database won't allow the same payment_hash to be added twice
|
||||
# this is good
|
||||
if len(pays) > 1:
|
||||
raise SparkError(
|
||||
f"listpays({payment_hash}) returned an unexpected response: {listpays}"
|
||||
)
|
||||
|
||||
fee_msat = -int(r["msatoshi_sent"] - r["msatoshi"])
|
||||
preimage = r["payment_preimage"]
|
||||
return PaymentResponse(True, r["payment_hash"], fee_msat, preimage, None)
|
||||
if pay["status"] == "failed":
|
||||
return PaymentResponse(False, None, None, None, str(exc))
|
||||
|
||||
if pay["status"] == "pending":
|
||||
return PaymentResponse(None, payment_hash, None, None, None)
|
||||
|
||||
if pay["status"] == "complete":
|
||||
r = pay
|
||||
r["payment_preimage"] = pay["preimage"]
|
||||
r["msatoshi"] = int(pay["amount_msat"][0:-4])
|
||||
r["msatoshi_sent"] = int(pay["amount_sent_msat"][0:-4])
|
||||
# this may result in an error if it was paid previously
|
||||
# our database won't allow the same payment_hash to be added twice
|
||||
# this is good
|
||||
fee_msat = -int(r["msatoshi_sent"] - r["msatoshi"])
|
||||
preimage = r["payment_preimage"]
|
||||
return PaymentResponse(
|
||||
True, r["payment_hash"], fee_msat, preimage, None
|
||||
)
|
||||
else:
|
||||
return PaymentResponse(False, None, None, None, str(exc))
|
||||
|
||||
async def get_invoice_status(self, checking_id: str) -> PaymentStatus:
|
||||
try:
|
||||
|
@ -205,7 +218,7 @@ class SparkWallet(Wallet):
|
|||
- int(r["pays"][0]["amount_msat"][0:-4])
|
||||
)
|
||||
return PaymentStatus(True, fee_msat, r["pays"][0]["preimage"])
|
||||
elif status == "failed":
|
||||
if status == "failed":
|
||||
return PaymentStatus(False)
|
||||
return PaymentStatus(None)
|
||||
raise KeyError("supplied an invalid checking_id")
|
||||
|
|
|
@ -69,9 +69,6 @@ include = [
|
|||
]
|
||||
exclude = [
|
||||
"lnbits/wallets/lnd_grpc_files",
|
||||
"lnbits/wallets",
|
||||
"lnbits/core",
|
||||
"lnbits/*.py",
|
||||
"lnbits/extensions",
|
||||
]
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user