Merge pull request #1468 from lnbits/pyright3
introduce pyright + fix issues (supersedes #1444)
This commit is contained in:
commit
47df94178e
2
Makefile
2
Makefile
|
@ -1,6 +1,6 @@
|
||||||
.PHONY: test
|
.PHONY: test
|
||||||
|
|
||||||
all: format check requirements.txt
|
all: format check
|
||||||
|
|
||||||
format: prettier isort black
|
format: prettier isort black
|
||||||
|
|
||||||
|
|
|
@ -66,11 +66,12 @@ def decode(pr: str) -> Invoice:
|
||||||
invoice.amount_msat = _unshorten_amount(amountstr)
|
invoice.amount_msat = _unshorten_amount(amountstr)
|
||||||
|
|
||||||
# pull out date
|
# pull out date
|
||||||
invoice.date = data.read(35).uint
|
date_bin = data.read(35)
|
||||||
|
invoice.date = date_bin.uint # type: ignore
|
||||||
|
|
||||||
while data.pos != data.len:
|
while data.pos != data.len:
|
||||||
tag, tagdata, data = _pull_tagged(data)
|
tag, tagdata, data = _pull_tagged(data)
|
||||||
data_length = len(tagdata) / 5
|
data_length = len(tagdata or []) / 5
|
||||||
|
|
||||||
if tag == "d":
|
if tag == "d":
|
||||||
invoice.description = _trim_to_bytes(tagdata).decode()
|
invoice.description = _trim_to_bytes(tagdata).decode()
|
||||||
|
@ -79,7 +80,7 @@ def decode(pr: str) -> Invoice:
|
||||||
elif tag == "p" and data_length == 52:
|
elif tag == "p" and data_length == 52:
|
||||||
invoice.payment_hash = _trim_to_bytes(tagdata).hex()
|
invoice.payment_hash = _trim_to_bytes(tagdata).hex()
|
||||||
elif tag == "x":
|
elif tag == "x":
|
||||||
invoice.expiry = tagdata.uint
|
invoice.expiry = tagdata.uint # type: ignore
|
||||||
elif tag == "n":
|
elif tag == "n":
|
||||||
invoice.payee = _trim_to_bytes(tagdata).hex()
|
invoice.payee = _trim_to_bytes(tagdata).hex()
|
||||||
# this won't work in most cases, we must extract the payee
|
# this won't work in most cases, we must extract the payee
|
||||||
|
@ -90,11 +91,11 @@ def decode(pr: str) -> Invoice:
|
||||||
s = bitstring.ConstBitStream(tagdata)
|
s = bitstring.ConstBitStream(tagdata)
|
||||||
while s.pos + 264 + 64 + 32 + 32 + 16 < s.len:
|
while s.pos + 264 + 64 + 32 + 32 + 16 < s.len:
|
||||||
route = Route(
|
route = Route(
|
||||||
pubkey=s.read(264).tobytes().hex(),
|
pubkey=s.read(264).tobytes().hex(), # type: ignore
|
||||||
short_channel_id=_readable_scid(s.read(64).intbe),
|
short_channel_id=_readable_scid(s.read(64).intbe), # type: ignore
|
||||||
base_fee_msat=s.read(32).intbe,
|
base_fee_msat=s.read(32).intbe, # type: ignore
|
||||||
ppm_fee=s.read(32).intbe,
|
ppm_fee=s.read(32).intbe, # type: ignore
|
||||||
cltv=s.read(16).intbe,
|
cltv=s.read(16).intbe, # type: ignore
|
||||||
)
|
)
|
||||||
invoice.route_hints.append(route)
|
invoice.route_hints.append(route)
|
||||||
|
|
||||||
|
@ -202,7 +203,8 @@ def lnencode(addr, privkey):
|
||||||
)
|
)
|
||||||
data += tagged("r", route)
|
data += tagged("r", route)
|
||||||
elif k == "f":
|
elif k == "f":
|
||||||
data += encode_fallback(v, addr.currency)
|
# NOTE: there was an error fallback here that's now removed
|
||||||
|
continue
|
||||||
elif k == "d":
|
elif k == "d":
|
||||||
data += tagged_bytes("d", v.encode())
|
data += tagged_bytes("d", v.encode())
|
||||||
elif k == "x":
|
elif k == "x":
|
||||||
|
@ -244,7 +246,13 @@ def lnencode(addr, privkey):
|
||||||
|
|
||||||
class LnAddr:
|
class LnAddr:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, paymenthash=None, amount=None, currency="bc", tags=None, date=None
|
self,
|
||||||
|
paymenthash=None,
|
||||||
|
amount=None,
|
||||||
|
currency="bc",
|
||||||
|
tags=None,
|
||||||
|
date=None,
|
||||||
|
fallback=None,
|
||||||
):
|
):
|
||||||
self.date = int(time.time()) if not date else int(date)
|
self.date = int(time.time()) if not date else int(date)
|
||||||
self.tags = [] if not tags else tags
|
self.tags = [] if not tags else tags
|
||||||
|
@ -252,11 +260,13 @@ class LnAddr:
|
||||||
self.paymenthash = paymenthash
|
self.paymenthash = paymenthash
|
||||||
self.signature = None
|
self.signature = None
|
||||||
self.pubkey = None
|
self.pubkey = None
|
||||||
|
self.fallback = fallback
|
||||||
self.currency = currency
|
self.currency = currency
|
||||||
self.amount = amount
|
self.amount = amount
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
pubkey = bytes.hex(self.pubkey.serialize()).decode()
|
assert self.pubkey, "LnAddr, pubkey must be set"
|
||||||
|
pubkey = bytes.hex(self.pubkey.serialize())
|
||||||
tags = ", ".join([f"{k}={v}" for k, v in self.tags])
|
tags = ", ".join([f"{k}={v}" for k, v in self.tags])
|
||||||
return f"LnAddr[{pubkey}, amount={self.amount}{self.currency} tags=[{tags}]]"
|
return f"LnAddr[{pubkey}, amount={self.amount}{self.currency} tags=[{tags}]]"
|
||||||
|
|
||||||
|
@ -266,6 +276,7 @@ def shorten_amount(amount):
|
||||||
# Convert to pico initially
|
# Convert to pico initially
|
||||||
amount = int(amount * 10**12)
|
amount = int(amount * 10**12)
|
||||||
units = ["p", "n", "u", "m", ""]
|
units = ["p", "n", "u", "m", ""]
|
||||||
|
unit = ""
|
||||||
for unit in units:
|
for unit in units:
|
||||||
if amount % 1000 == 0:
|
if amount % 1000 == 0:
|
||||||
amount //= 1000
|
amount //= 1000
|
||||||
|
@ -304,14 +315,6 @@ def _pull_tagged(stream):
|
||||||
return (CHARSET[tag], stream.read(length * 5), stream)
|
return (CHARSET[tag], stream.read(length * 5), stream)
|
||||||
|
|
||||||
|
|
||||||
def is_p2pkh(currency, prefix):
|
|
||||||
return prefix == base58_prefix_map[currency][0]
|
|
||||||
|
|
||||||
|
|
||||||
def is_p2sh(currency, prefix):
|
|
||||||
return prefix == base58_prefix_map[currency][1]
|
|
||||||
|
|
||||||
|
|
||||||
# Tagged field containing BitArray
|
# Tagged field containing BitArray
|
||||||
def tagged(char, l):
|
def tagged(char, l):
|
||||||
# Tagged fields need to be zero-padded to 5 bits.
|
# Tagged fields need to be zero-padded to 5 bits.
|
||||||
|
@ -359,5 +362,5 @@ def bitarray_to_u5(barr):
|
||||||
ret = []
|
ret = []
|
||||||
s = bitstring.ConstBitStream(barr)
|
s = bitstring.ConstBitStream(barr)
|
||||||
while s.pos != s.len:
|
while s.pos != s.len:
|
||||||
ret.append(s.read(5).uint)
|
ret.append(s.read(5).uint) # type: ignore
|
||||||
return ret
|
return ret
|
||||||
|
|
|
@ -41,6 +41,7 @@ async def migrate_databases():
|
||||||
"""Creates the necessary databases if they don't exist already; or migrates them."""
|
"""Creates the necessary databases if they don't exist already; or migrates them."""
|
||||||
|
|
||||||
async with core_db.connect() as conn:
|
async with core_db.connect() as conn:
|
||||||
|
exists = False
|
||||||
if conn.type == SQLITE:
|
if conn.type == SQLITE:
|
||||||
exists = await conn.fetchone(
|
exists = await conn.fetchone(
|
||||||
"SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'"
|
"SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'"
|
||||||
|
|
|
@ -206,7 +206,7 @@ async def create_wallet(
|
||||||
async def update_wallet(
|
async def update_wallet(
|
||||||
wallet_id: str, new_name: str, conn: Optional[Connection] = None
|
wallet_id: str, new_name: str, conn: Optional[Connection] = None
|
||||||
) -> Optional[Wallet]:
|
) -> Optional[Wallet]:
|
||||||
return await (conn or db).execute(
|
await (conn or db).execute(
|
||||||
"""
|
"""
|
||||||
UPDATE wallets SET
|
UPDATE wallets SET
|
||||||
name = ?
|
name = ?
|
||||||
|
@ -214,6 +214,9 @@ async def update_wallet(
|
||||||
""",
|
""",
|
||||||
(new_name, wallet_id),
|
(new_name, wallet_id),
|
||||||
)
|
)
|
||||||
|
wallet = await get_wallet(wallet_id=wallet_id, conn=conn)
|
||||||
|
assert wallet, "updated created wallet couldn't be retrieved"
|
||||||
|
return wallet
|
||||||
|
|
||||||
|
|
||||||
async def delete_wallet(
|
async def delete_wallet(
|
||||||
|
@ -393,7 +396,7 @@ async def get_payments(
|
||||||
clause.append("checking_id NOT LIKE 'internal_%'")
|
clause.append("checking_id NOT LIKE 'internal_%'")
|
||||||
|
|
||||||
if not filters:
|
if not filters:
|
||||||
filters = Filters()
|
filters = Filters(limit=None, offset=None)
|
||||||
|
|
||||||
rows = await (conn or db).fetchall(
|
rows = await (conn or db).fetchall(
|
||||||
f"""
|
f"""
|
||||||
|
@ -712,15 +715,19 @@ async def update_admin_settings(data: EditableSettings):
|
||||||
await db.execute("UPDATE settings SET editable_settings = ?", (json.dumps(data),))
|
await db.execute("UPDATE settings SET editable_settings = ?", (json.dumps(data),))
|
||||||
|
|
||||||
|
|
||||||
async def update_super_user(super_user: str):
|
async def update_super_user(super_user: str) -> SuperSettings:
|
||||||
await db.execute("UPDATE settings SET super_user = ?", (super_user,))
|
await db.execute("UPDATE settings SET super_user = ?", (super_user,))
|
||||||
return await get_super_settings()
|
settings = await get_super_settings()
|
||||||
|
assert settings, "updated super_user settings could not be retrieved"
|
||||||
|
return settings
|
||||||
|
|
||||||
|
|
||||||
async def create_admin_settings(super_user: str, new_settings: dict):
|
async def create_admin_settings(super_user: str, new_settings: dict):
|
||||||
sql = "INSERT INTO settings (super_user, editable_settings) VALUES (?, ?)"
|
sql = "INSERT INTO settings (super_user, editable_settings) VALUES (?, ?)"
|
||||||
await db.execute(sql, (super_user, json.dumps(new_settings)))
|
await db.execute(sql, (super_user, json.dumps(new_settings)))
|
||||||
return await get_super_settings()
|
settings = await get_super_settings()
|
||||||
|
assert settings, "created admin settings could not be retrieved"
|
||||||
|
return settings
|
||||||
|
|
||||||
|
|
||||||
# db versions
|
# db versions
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple, TypedDict
|
||||||
from urllib.parse import parse_qs, urlparse
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
@ -17,6 +17,7 @@ from lnbits.helpers import url_for
|
||||||
from lnbits.settings import (
|
from lnbits.settings import (
|
||||||
FAKE_WALLET,
|
FAKE_WALLET,
|
||||||
EditableSettings,
|
EditableSettings,
|
||||||
|
SuperSettings,
|
||||||
get_wallet_class,
|
get_wallet_class,
|
||||||
readonly_variables,
|
readonly_variables,
|
||||||
send_admin_user_to_saas,
|
send_admin_user_to_saas,
|
||||||
|
@ -43,11 +44,6 @@ from .crud import (
|
||||||
)
|
)
|
||||||
from .models import Payment
|
from .models import Payment
|
||||||
|
|
||||||
try:
|
|
||||||
from typing import TypedDict
|
|
||||||
except ImportError: # pragma: nocover
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
|
|
||||||
class PaymentFailure(Exception):
|
class PaymentFailure(Exception):
|
||||||
pass
|
pass
|
||||||
|
@ -188,7 +184,7 @@ async def pay_invoice(
|
||||||
|
|
||||||
# do the balance check
|
# do the balance check
|
||||||
wallet = await get_wallet(wallet_id, conn=conn)
|
wallet = await get_wallet(wallet_id, conn=conn)
|
||||||
assert wallet
|
assert wallet, "Wallet for balancecheck could not be fetched"
|
||||||
if wallet.balance_msat < 0:
|
if wallet.balance_msat < 0:
|
||||||
logger.debug("balance is too low, deleting temporary payment")
|
logger.debug("balance is too low, deleting temporary payment")
|
||||||
if not internal_checking_id and wallet.balance_msat > -fee_reserve_msat:
|
if not internal_checking_id and wallet.balance_msat > -fee_reserve_msat:
|
||||||
|
@ -336,19 +332,19 @@ async def perform_lnurlauth(
|
||||||
|
|
||||||
return b
|
return b
|
||||||
|
|
||||||
def encode_strict_der(r_int, s_int, order):
|
def encode_strict_der(r: int, s: int, order: int):
|
||||||
# if s > order/2 verification will fail sometimes
|
# if s > order/2 verification will fail sometimes
|
||||||
# so we must fix it here (see https://github.com/indutny/elliptic/blob/e71b2d9359c5fe9437fbf46f1f05096de447de57/lib/elliptic/ec/index.js#L146-L147)
|
# so we must fix it here (see https://github.com/indutny/elliptic/blob/e71b2d9359c5fe9437fbf46f1f05096de447de57/lib/elliptic/ec/index.js#L146-L147)
|
||||||
if s_int > order // 2:
|
if s > order // 2:
|
||||||
s_int = order - s_int
|
s = order - s
|
||||||
|
|
||||||
# now we do the strict DER encoding copied from
|
# now we do the strict DER encoding copied from
|
||||||
# https://github.com/KiriKiri/bip66 (without any checks)
|
# https://github.com/KiriKiri/bip66 (without any checks)
|
||||||
r = int_to_bytes_suitable_der(r_int)
|
r_temp = int_to_bytes_suitable_der(r)
|
||||||
s = int_to_bytes_suitable_der(s_int)
|
s_temp = int_to_bytes_suitable_der(s)
|
||||||
|
|
||||||
r_len = len(r)
|
r_len = len(r_temp)
|
||||||
s_len = len(s)
|
s_len = len(s_temp)
|
||||||
sign_len = 6 + r_len + s_len
|
sign_len = 6 + r_len + s_len
|
||||||
|
|
||||||
signature = BytesIO()
|
signature = BytesIO()
|
||||||
|
@ -356,16 +352,17 @@ async def perform_lnurlauth(
|
||||||
signature.write((sign_len - 2).to_bytes(1, "big", signed=False))
|
signature.write((sign_len - 2).to_bytes(1, "big", signed=False))
|
||||||
signature.write(0x02.to_bytes(1, "big", signed=False))
|
signature.write(0x02.to_bytes(1, "big", signed=False))
|
||||||
signature.write(r_len.to_bytes(1, "big", signed=False))
|
signature.write(r_len.to_bytes(1, "big", signed=False))
|
||||||
signature.write(r)
|
signature.write(r_temp)
|
||||||
signature.write(0x02.to_bytes(1, "big", signed=False))
|
signature.write(0x02.to_bytes(1, "big", signed=False))
|
||||||
signature.write(s_len.to_bytes(1, "big", signed=False))
|
signature.write(s_len.to_bytes(1, "big", signed=False))
|
||||||
signature.write(s)
|
signature.write(s_temp)
|
||||||
|
|
||||||
return signature.getvalue()
|
return signature.getvalue()
|
||||||
|
|
||||||
sig = key.sign_digest_deterministic(k1, sigencode=encode_strict_der)
|
sig = key.sign_digest_deterministic(k1, sigencode=encode_strict_der)
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
|
assert key.verifying_key, "LNURLauth verifying_key does not exist"
|
||||||
r = await client.get(
|
r = await client.get(
|
||||||
callback,
|
callback,
|
||||||
params={
|
params={
|
||||||
|
@ -469,7 +466,7 @@ def update_cached_settings(sets_dict: dict):
|
||||||
setattr(settings, "super_user", sets_dict["super_user"])
|
setattr(settings, "super_user", sets_dict["super_user"])
|
||||||
|
|
||||||
|
|
||||||
async def init_admin_settings(super_user: str = None):
|
async def init_admin_settings(super_user: Optional[str] = None) -> SuperSettings:
|
||||||
account = None
|
account = None
|
||||||
if super_user:
|
if super_user:
|
||||||
account = await get_account(super_user)
|
account = await get_account(super_user)
|
||||||
|
|
|
@ -411,8 +411,7 @@ async def subscribe_wallet_invoices(request: Request, wallet: Wallet):
|
||||||
typ, data = await send_queue.get()
|
typ, data = await send_queue.get()
|
||||||
if data:
|
if data:
|
||||||
jdata = json.dumps(dict(data.dict(), pending=False))
|
jdata = json.dumps(dict(data.dict(), pending=False))
|
||||||
|
yield dict(data=jdata, event=typ)
|
||||||
yield dict(data=jdata, event=typ)
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.debug(f"removing listener for wallet {uid}")
|
logger.debug(f"removing listener for wallet {uid}")
|
||||||
api_invoice_listeners.pop(uid)
|
api_invoice_listeners.pop(uid)
|
||||||
|
@ -431,11 +430,12 @@ async def api_payments_sse(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: refactor this route into a public and admin one
|
||||||
@core_app.get("/api/v1/payments/{payment_hash}")
|
@core_app.get("/api/v1/payments/{payment_hash}")
|
||||||
async def api_payment(payment_hash, X_Api_Key: Optional[str] = Header(None)):
|
async def api_payment(payment_hash, X_Api_Key: Optional[str] = Header(None)):
|
||||||
# We use X_Api_Key here because we want this call to work with and without keys
|
# We use X_Api_Key here because we want this call to work with and without keys
|
||||||
# If a valid key is given, we also return the field "details", otherwise not
|
# If a valid key is given, we also return the field "details", otherwise not
|
||||||
wallet = await get_wallet_for_key(X_Api_Key) if type(X_Api_Key) == str else None
|
wallet = await get_wallet_for_key(X_Api_Key) if type(X_Api_Key) == str else None # type: ignore
|
||||||
|
|
||||||
# we have to specify the wallet id here, because postgres and sqlite return internal payments in different order
|
# we have to specify the wallet id here, because postgres and sqlite return internal payments in different order
|
||||||
# and get_standalone_payment otherwise just fetches the first one, causing unpredictable results
|
# and get_standalone_payment otherwise just fetches the first one, causing unpredictable results
|
||||||
|
@ -505,6 +505,7 @@ async def api_lnurlscan(code: str, wallet: WalletTypeInfo = Depends(get_key_type
|
||||||
params.update(callback=url) # with k1 already in it
|
params.update(callback=url) # with k1 already in it
|
||||||
|
|
||||||
lnurlauth_key = wallet.wallet.lnurlauth_key(domain)
|
lnurlauth_key = wallet.wallet.lnurlauth_key(domain)
|
||||||
|
assert lnurlauth_key.verifying_key
|
||||||
params.update(pubkey=lnurlauth_key.verifying_key.to_string("compressed").hex())
|
params.update(pubkey=lnurlauth_key.verifying_key.to_string("compressed").hex())
|
||||||
else:
|
else:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
|
@ -693,7 +694,7 @@ async def api_auditor():
|
||||||
if not error_message:
|
if not error_message:
|
||||||
delta = node_balance - total_balance
|
delta = node_balance - total_balance
|
||||||
else:
|
else:
|
||||||
node_balance, delta = None, None
|
node_balance, delta = 0, 0
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"node_balance_msats": int(node_balance),
|
"node_balance_msats": int(node_balance),
|
||||||
|
@ -745,6 +746,7 @@ async def api_install_extension(
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.NOT_FOUND, detail="Release not found"
|
status_code=HTTPStatus.NOT_FOUND, detail="Release not found"
|
||||||
)
|
)
|
||||||
|
|
||||||
ext_info = InstallableExtension(
|
ext_info = InstallableExtension(
|
||||||
id=data.ext_id, name=data.ext_id, installed_release=release, icon=release.icon
|
id=data.ext_id, name=data.ext_id, installed_release=release, icon=release.icon
|
||||||
)
|
)
|
||||||
|
@ -824,8 +826,10 @@ async def api_uninstall_extension(ext_id: str, user: User = Depends(check_admin)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@core_app.get("/api/v1/extension/{ext_id}/releases")
|
@core_app.get(
|
||||||
async def get_extension_releases(ext_id: str, user: User = Depends(check_admin)):
|
"/api/v1/extension/{ext_id}/releases", dependencies=[Depends(check_admin)]
|
||||||
|
)
|
||||||
|
async def get_extension_releases(ext_id: str):
|
||||||
try:
|
try:
|
||||||
extension_releases: List[
|
extension_releases: List[
|
||||||
ExtensionRelease
|
ExtensionRelease
|
||||||
|
|
|
@ -40,19 +40,18 @@ async def api_public_payment_longpolling(payment_hash):
|
||||||
|
|
||||||
response = None
|
response = None
|
||||||
|
|
||||||
async def payment_info_receiver(cancel_scope):
|
async def payment_info_receiver():
|
||||||
async for payment in payment_queue.get():
|
for payment in await payment_queue.get():
|
||||||
if payment.payment_hash == payment_hash:
|
if payment.payment_hash == payment_hash:
|
||||||
nonlocal response
|
nonlocal response
|
||||||
response = {"status": "paid"}
|
response = {"status": "paid"}
|
||||||
cancel_scope.cancel()
|
|
||||||
|
|
||||||
async def timeouter(cancel_scope):
|
async def timeouter(cancel_scope):
|
||||||
await asyncio.sleep(45)
|
await asyncio.sleep(45)
|
||||||
cancel_scope.cancel()
|
cancel_scope.cancel()
|
||||||
|
|
||||||
asyncio.create_task(payment_info_receiver())
|
cancel_scope = asyncio.create_task(payment_info_receiver())
|
||||||
asyncio.create_task(timeouter())
|
asyncio.create_task(timeouter(cancel_scope))
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
return response
|
return response
|
||||||
|
|
20
lnbits/db.py
20
lnbits/db.py
|
@ -131,7 +131,7 @@ class Database(Compat):
|
||||||
else:
|
else:
|
||||||
self.type = POSTGRES
|
self.type = POSTGRES
|
||||||
|
|
||||||
import psycopg2
|
from psycopg2.extensions import DECIMAL, new_type, register_type
|
||||||
|
|
||||||
def _parse_timestamp(value, _):
|
def _parse_timestamp(value, _):
|
||||||
if value is None:
|
if value is None:
|
||||||
|
@ -141,15 +141,15 @@ class Database(Compat):
|
||||||
f = "%Y-%m-%d %H:%M:%S"
|
f = "%Y-%m-%d %H:%M:%S"
|
||||||
return time.mktime(datetime.datetime.strptime(value, f).timetuple())
|
return time.mktime(datetime.datetime.strptime(value, f).timetuple())
|
||||||
|
|
||||||
psycopg2.extensions.register_type(
|
register_type(
|
||||||
psycopg2.extensions.new_type(
|
new_type(
|
||||||
psycopg2.extensions.DECIMAL.values,
|
DECIMAL.values,
|
||||||
"DEC2FLOAT",
|
"DEC2FLOAT",
|
||||||
lambda value, curs: float(value) if value is not None else None,
|
lambda value, curs: float(value) if value is not None else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
psycopg2.extensions.register_type(
|
register_type(
|
||||||
psycopg2.extensions.new_type(
|
new_type(
|
||||||
(1082, 1083, 1266),
|
(1082, 1083, 1266),
|
||||||
"DATE2INT",
|
"DATE2INT",
|
||||||
lambda value, curs: time.mktime(value.timetuple())
|
lambda value, curs: time.mktime(value.timetuple())
|
||||||
|
@ -158,11 +158,7 @@ class Database(Compat):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
psycopg2.extensions.register_type(
|
register_type(new_type((1184, 1114), "TIMESTAMP2INT", _parse_timestamp))
|
||||||
psycopg2.extensions.new_type(
|
|
||||||
(1184, 1114), "TIMESTAMP2INT", _parse_timestamp
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if os.path.isdir(settings.lnbits_data_folder):
|
if os.path.isdir(settings.lnbits_data_folder):
|
||||||
self.path = os.path.join(
|
self.path = os.path.join(
|
||||||
|
@ -189,7 +185,7 @@ class Database(Compat):
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
await self.lock.acquire()
|
await self.lock.acquire()
|
||||||
try:
|
try:
|
||||||
async with self.engine.connect() as conn:
|
async with self.engine.connect() as conn: # type: ignore
|
||||||
async with conn.begin() as txn:
|
async with conn.begin() as txn:
|
||||||
wconn = Connection(conn, txn, self.type, self.name, self.schema)
|
wconn = Connection(conn, txn, self.type, self.name, self.schema)
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,12 @@
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
from fastapi import Security, status
|
from fastapi import HTTPException, Request, Security, status
|
||||||
from fastapi.exceptions import HTTPException
|
|
||||||
from fastapi.openapi.models import APIKey, APIKeyIn
|
from fastapi.openapi.models import APIKey, APIKeyIn
|
||||||
from fastapi.security.api_key import APIKeyHeader, APIKeyQuery
|
from fastapi.security import APIKeyHeader, APIKeyQuery
|
||||||
from fastapi.security.base import SecurityBase
|
from fastapi.security.base import SecurityBase
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.types import UUID4
|
from pydantic.types import UUID4
|
||||||
from starlette.requests import Request
|
|
||||||
|
|
||||||
from lnbits.core.crud import get_user, get_wallet_for_key
|
from lnbits.core.crud import get_user, get_wallet_for_key
|
||||||
from lnbits.core.models import User, Wallet
|
from lnbits.core.models import User, Wallet
|
||||||
|
@ -17,9 +15,13 @@ from lnbits.requestvars import g
|
||||||
from lnbits.settings import settings
|
from lnbits.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: fix type ignores
|
||||||
class KeyChecker(SecurityBase):
|
class KeyChecker(SecurityBase):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, scheme_name: str = None, auto_error: bool = True, api_key: str = None
|
self,
|
||||||
|
scheme_name: Optional[str] = None,
|
||||||
|
auto_error: bool = True,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self.scheme_name = scheme_name or self.__class__.__name__
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
self.auto_error = auto_error
|
self.auto_error = auto_error
|
||||||
|
@ -27,13 +29,13 @@ class KeyChecker(SecurityBase):
|
||||||
self._api_key = api_key
|
self._api_key = api_key
|
||||||
if api_key:
|
if api_key:
|
||||||
key = APIKey(
|
key = APIKey(
|
||||||
**{"in": APIKeyIn.query},
|
**{"in": APIKeyIn.query}, # type: ignore
|
||||||
name="X-API-KEY",
|
name="X-API-KEY",
|
||||||
description="Wallet API Key - QUERY",
|
description="Wallet API Key - QUERY",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
key = APIKey(
|
key = APIKey(
|
||||||
**{"in": APIKeyIn.header},
|
**{"in": APIKeyIn.header}, # type: ignore
|
||||||
name="X-API-KEY",
|
name="X-API-KEY",
|
||||||
description="Wallet API Key - HEADER",
|
description="Wallet API Key - HEADER",
|
||||||
)
|
)
|
||||||
|
@ -73,7 +75,10 @@ class WalletInvoiceKeyChecker(KeyChecker):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, scheme_name: str = None, auto_error: bool = True, api_key: str = None
|
self,
|
||||||
|
scheme_name: Optional[str] = None,
|
||||||
|
auto_error: bool = True,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
):
|
):
|
||||||
super().__init__(scheme_name, auto_error, api_key)
|
super().__init__(scheme_name, auto_error, api_key)
|
||||||
self._key_type = "invoice"
|
self._key_type = "invoice"
|
||||||
|
@ -89,7 +94,10 @@ class WalletAdminKeyChecker(KeyChecker):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, scheme_name: str = None, auto_error: bool = True, api_key: str = None
|
self,
|
||||||
|
scheme_name: Optional[str] = None,
|
||||||
|
auto_error: bool = True,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
):
|
):
|
||||||
super().__init__(scheme_name, auto_error, api_key)
|
super().__init__(scheme_name, auto_error, api_key)
|
||||||
self._key_type = "admin"
|
self._key_type = "admin"
|
||||||
|
|
|
@ -3,20 +3,145 @@ import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import urllib.request
|
|
||||||
import zipfile
|
import zipfile
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, List, NamedTuple, Optional, Tuple
|
from typing import Any, List, NamedTuple, Optional, Tuple
|
||||||
|
from urllib import request
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi.exceptions import HTTPException
|
from fastapi import HTTPException
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from lnbits.settings import settings
|
from lnbits.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class ExplicitRelease(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
version: str
|
||||||
|
archive: str
|
||||||
|
hash: str
|
||||||
|
dependencies: List[str] = []
|
||||||
|
icon: Optional[str]
|
||||||
|
short_description: Optional[str]
|
||||||
|
html_url: Optional[str]
|
||||||
|
details: Optional[str]
|
||||||
|
info_notification: Optional[str]
|
||||||
|
critical_notification: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubRelease(BaseModel):
|
||||||
|
id: str
|
||||||
|
organisation: str
|
||||||
|
repository: str
|
||||||
|
|
||||||
|
|
||||||
|
class Manifest(BaseModel):
|
||||||
|
featured: List[str] = []
|
||||||
|
extensions: List["ExplicitRelease"] = []
|
||||||
|
repos: List["GitHubRelease"] = []
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubRepoRelease(BaseModel):
|
||||||
|
name: str
|
||||||
|
tag_name: str
|
||||||
|
zipball_url: str
|
||||||
|
html_url: str
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubRepo(BaseModel):
|
||||||
|
stargazers_count: str
|
||||||
|
html_url: str
|
||||||
|
default_branch: str
|
||||||
|
|
||||||
|
|
||||||
|
class ExtensionConfig(BaseModel):
|
||||||
|
name: str
|
||||||
|
short_description: str
|
||||||
|
tile: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
def download_url(url, save_path):
|
||||||
|
with request.urlopen(url) as dl_file:
|
||||||
|
with open(save_path, "wb") as out_file:
|
||||||
|
out_file.write(dl_file.read())
|
||||||
|
|
||||||
|
|
||||||
|
def file_hash(filename):
|
||||||
|
h = hashlib.sha256()
|
||||||
|
b = bytearray(128 * 1024)
|
||||||
|
mv = memoryview(b)
|
||||||
|
with open(filename, "rb", buffering=0) as f:
|
||||||
|
while n := f.readinto(mv):
|
||||||
|
h.update(mv[:n])
|
||||||
|
return h.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_github_repo_info(
|
||||||
|
org: str, repository: str
|
||||||
|
) -> Tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]:
|
||||||
|
repo_url = f"https://api.github.com/repos/{org}/{repository}"
|
||||||
|
error_msg = "Cannot fetch extension repo"
|
||||||
|
repo = await gihub_api_get(repo_url, error_msg)
|
||||||
|
github_repo = GitHubRepo.parse_obj(repo)
|
||||||
|
|
||||||
|
lates_release_url = (
|
||||||
|
f"https://api.github.com/repos/{org}/{repository}/releases/latest"
|
||||||
|
)
|
||||||
|
error_msg = "Cannot fetch extension releases"
|
||||||
|
latest_release: Any = await gihub_api_get(lates_release_url, error_msg)
|
||||||
|
|
||||||
|
config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json"
|
||||||
|
error_msg = "Cannot fetch config for extension"
|
||||||
|
config = await gihub_api_get(config_url, error_msg)
|
||||||
|
|
||||||
|
return (
|
||||||
|
github_repo,
|
||||||
|
GitHubRepoRelease.parse_obj(latest_release),
|
||||||
|
ExtensionConfig.parse_obj(config),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_manifest(url) -> Manifest:
|
||||||
|
error_msg = "Cannot fetch extensions manifest"
|
||||||
|
manifest = await gihub_api_get(url, error_msg)
|
||||||
|
return Manifest.parse_obj(manifest)
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_github_releases(org: str, repo: str) -> List[GitHubRepoRelease]:
|
||||||
|
releases_url = f"https://api.github.com/repos/{org}/{repo}/releases"
|
||||||
|
error_msg = "Cannot fetch extension releases"
|
||||||
|
releases = await gihub_api_get(releases_url, error_msg)
|
||||||
|
return [GitHubRepoRelease.parse_obj(r) for r in releases]
|
||||||
|
|
||||||
|
|
||||||
|
async def gihub_api_get(url: str, error_msg: Optional[str]) -> Any:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
headers = (
|
||||||
|
{"Authorization": "Bearer " + settings.lnbits_ext_github_token}
|
||||||
|
if settings.lnbits_ext_github_token
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
resp = await client.get(
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
if resp.status_code != 200:
|
||||||
|
logger.warning(f"{error_msg} ({url}): {resp.text}")
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
def icon_to_github_url(source_repo: str, path: Optional[str]) -> str:
|
||||||
|
if not path:
|
||||||
|
return ""
|
||||||
|
_, _, *rest = path.split("/")
|
||||||
|
tail = "/".join(rest)
|
||||||
|
return f"https://github.com/{source_repo}/raw/main/{tail}"
|
||||||
|
|
||||||
|
|
||||||
class Extension(NamedTuple):
|
class Extension(NamedTuple):
|
||||||
code: str
|
code: str
|
||||||
is_valid: bool
|
is_valid: bool
|
||||||
|
@ -97,12 +222,12 @@ class ExtensionRelease(BaseModel):
|
||||||
version: str
|
version: str
|
||||||
archive: str
|
archive: str
|
||||||
source_repo: str
|
source_repo: str
|
||||||
is_github_release = False
|
is_github_release: bool = False
|
||||||
hash: Optional[str]
|
hash: Optional[str] = None
|
||||||
html_url: Optional[str]
|
html_url: Optional[str] = None
|
||||||
description: Optional[str]
|
description: Optional[str] = None
|
||||||
details_html: Optional[str] = None
|
details_html: Optional[str] = None
|
||||||
icon: Optional[str]
|
icon: Optional[str] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_github_release(
|
def from_github_release(
|
||||||
|
@ -132,52 +257,6 @@ class ExtensionRelease(BaseModel):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
class ExplicitRelease(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
version: str
|
|
||||||
archive: str
|
|
||||||
hash: str
|
|
||||||
dependencies: List[str] = []
|
|
||||||
icon: Optional[str]
|
|
||||||
short_description: Optional[str]
|
|
||||||
html_url: Optional[str]
|
|
||||||
details: Optional[str]
|
|
||||||
info_notification: Optional[str]
|
|
||||||
critical_notification: Optional[str]
|
|
||||||
|
|
||||||
|
|
||||||
class GitHubRelease(BaseModel):
|
|
||||||
id: str
|
|
||||||
organisation: str
|
|
||||||
repository: str
|
|
||||||
|
|
||||||
|
|
||||||
class Manifest(BaseModel):
|
|
||||||
featured: List[str] = []
|
|
||||||
extensions: List["ExplicitRelease"] = []
|
|
||||||
repos: List["GitHubRelease"] = []
|
|
||||||
|
|
||||||
|
|
||||||
class GitHubRepoRelease(BaseModel):
|
|
||||||
name: str
|
|
||||||
tag_name: str
|
|
||||||
zipball_url: str
|
|
||||||
html_url: str
|
|
||||||
|
|
||||||
|
|
||||||
class GitHubRepo(BaseModel):
|
|
||||||
stargazers_count: str
|
|
||||||
html_url: str
|
|
||||||
default_branch: str
|
|
||||||
|
|
||||||
|
|
||||||
class ExtensionConfig(BaseModel):
|
|
||||||
name: str
|
|
||||||
short_description: str
|
|
||||||
tile: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
class InstallableExtension(BaseModel):
|
class InstallableExtension(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
|
@ -187,8 +266,9 @@ class InstallableExtension(BaseModel):
|
||||||
is_admin_only: bool = False
|
is_admin_only: bool = False
|
||||||
stars: int = 0
|
stars: int = 0
|
||||||
featured = False
|
featured = False
|
||||||
latest_release: Optional[ExtensionRelease]
|
latest_release: Optional[ExtensionRelease] = None
|
||||||
installed_release: Optional[ExtensionRelease]
|
installed_release: Optional[ExtensionRelease] = None
|
||||||
|
archive: Optional[str] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def hash(self) -> str:
|
def hash(self) -> str:
|
||||||
|
@ -234,6 +314,7 @@ class InstallableExtension(BaseModel):
|
||||||
if ext_zip_file.is_file():
|
if ext_zip_file.is_file():
|
||||||
os.remove(ext_zip_file)
|
os.remove(ext_zip_file)
|
||||||
try:
|
try:
|
||||||
|
assert self.installed_release, "installed_release is none."
|
||||||
download_url(self.installed_release.archive, ext_zip_file)
|
download_url(self.installed_release.archive, ext_zip_file)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.warning(ex)
|
logger.warning(ex)
|
||||||
|
@ -334,8 +415,7 @@ class InstallableExtension(BaseModel):
|
||||||
id=github_release.id,
|
id=github_release.id,
|
||||||
name=config.name,
|
name=config.name,
|
||||||
short_description=config.short_description,
|
short_description=config.short_description,
|
||||||
version="0",
|
stars=int(repo.stargazers_count),
|
||||||
stars=repo.stargazers_count,
|
|
||||||
icon=icon_to_github_url(
|
icon=icon_to_github_url(
|
||||||
f"{github_release.organisation}/{github_release.repository}",
|
f"{github_release.organisation}/{github_release.repository}",
|
||||||
config.tile,
|
config.tile,
|
||||||
|
@ -354,7 +434,6 @@ class InstallableExtension(BaseModel):
|
||||||
id=e.id,
|
id=e.id,
|
||||||
name=e.name,
|
name=e.name,
|
||||||
archive=e.archive,
|
archive=e.archive,
|
||||||
hash=e.hash,
|
|
||||||
short_description=e.short_description,
|
short_description=e.short_description,
|
||||||
icon=e.icon,
|
icon=e.icon,
|
||||||
dependencies=e.dependencies,
|
dependencies=e.dependencies,
|
||||||
|
@ -453,82 +532,3 @@ def get_valid_extensions() -> List[Extension]:
|
||||||
return [
|
return [
|
||||||
extension for extension in ExtensionManager().extensions if extension.is_valid
|
extension for extension in ExtensionManager().extensions if extension.is_valid
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def download_url(url, save_path):
|
|
||||||
with urllib.request.urlopen(url) as dl_file:
|
|
||||||
with open(save_path, "wb") as out_file:
|
|
||||||
out_file.write(dl_file.read())
|
|
||||||
|
|
||||||
|
|
||||||
def file_hash(filename):
|
|
||||||
h = hashlib.sha256()
|
|
||||||
b = bytearray(128 * 1024)
|
|
||||||
mv = memoryview(b)
|
|
||||||
with open(filename, "rb", buffering=0) as f:
|
|
||||||
while n := f.readinto(mv):
|
|
||||||
h.update(mv[:n])
|
|
||||||
return h.hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
def icon_to_github_url(source_repo: str, path: Optional[str]) -> str:
|
|
||||||
if not path:
|
|
||||||
return ""
|
|
||||||
_, _, *rest = path.split("/")
|
|
||||||
tail = "/".join(rest)
|
|
||||||
return f"https://github.com/{source_repo}/raw/main/{tail}"
|
|
||||||
|
|
||||||
|
|
||||||
async def fetch_github_repo_info(
|
|
||||||
org: str, repository: str
|
|
||||||
) -> Tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]:
|
|
||||||
repo_url = f"https://api.github.com/repos/{org}/{repository}"
|
|
||||||
error_msg = "Cannot fetch extension repo"
|
|
||||||
repo = await gihub_api_get(repo_url, error_msg)
|
|
||||||
github_repo = GitHubRepo.parse_obj(repo)
|
|
||||||
|
|
||||||
lates_release_url = (
|
|
||||||
f"https://api.github.com/repos/{org}/{repository}/releases/latest"
|
|
||||||
)
|
|
||||||
error_msg = "Cannot fetch extension releases"
|
|
||||||
latest_release: Any = await gihub_api_get(lates_release_url, error_msg)
|
|
||||||
|
|
||||||
config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json"
|
|
||||||
error_msg = "Cannot fetch config for extension"
|
|
||||||
config = await gihub_api_get(config_url, error_msg)
|
|
||||||
|
|
||||||
return (
|
|
||||||
github_repo,
|
|
||||||
GitHubRepoRelease.parse_obj(latest_release),
|
|
||||||
ExtensionConfig.parse_obj(config),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def fetch_manifest(url) -> Manifest:
|
|
||||||
error_msg = "Cannot fetch extensions manifest"
|
|
||||||
manifest = await gihub_api_get(url, error_msg)
|
|
||||||
return Manifest.parse_obj(manifest)
|
|
||||||
|
|
||||||
|
|
||||||
async def fetch_github_releases(org: str, repo: str) -> List[GitHubRepoRelease]:
|
|
||||||
releases_url = f"https://api.github.com/repos/{org}/{repo}/releases"
|
|
||||||
error_msg = "Cannot fetch extension releases"
|
|
||||||
releases = await gihub_api_get(releases_url, error_msg)
|
|
||||||
return [GitHubRepoRelease.parse_obj(r) for r in releases]
|
|
||||||
|
|
||||||
|
|
||||||
async def gihub_api_get(url: str, error_msg: Optional[str]) -> Any:
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
headers = (
|
|
||||||
{"Authorization": f"Bearer {settings.lnbits_ext_github_token}"}
|
|
||||||
if settings.lnbits_ext_github_token
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
resp = await client.get(
|
|
||||||
url,
|
|
||||||
headers=headers,
|
|
||||||
)
|
|
||||||
if resp.status_code != 200:
|
|
||||||
logger.warning(f"{error_msg} ({url}): {resp.text}")
|
|
||||||
resp.raise_for_status()
|
|
||||||
return resp.json()
|
|
||||||
|
|
|
@ -1,25 +1,18 @@
|
||||||
# Borrowed from the excellent accent-starlette
|
|
||||||
# https://github.com/accent-starlette/starlette-core/blob/master/starlette_core/templating.py
|
|
||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from starlette import templating
|
from jinja2 import BaseLoader, Environment, pass_context
|
||||||
from starlette.datastructures import QueryParams
|
from starlette.datastructures import QueryParams
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
from starlette.templating import Jinja2Templates as SuperJinja2Templates
|
||||||
try:
|
|
||||||
import jinja2
|
|
||||||
except ImportError: # pragma: nocover
|
|
||||||
jinja2 = None # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
class Jinja2Templates(templating.Jinja2Templates):
|
class Jinja2Templates(SuperJinja2Templates):
|
||||||
def __init__(self, loader: jinja2.BaseLoader) -> None: # pylint: disable=W0231
|
def __init__(self, loader: BaseLoader) -> None:
|
||||||
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
|
super().__init__("")
|
||||||
self.env = self.get_environment(loader)
|
self.env = self.get_environment(loader)
|
||||||
|
|
||||||
def get_environment(self, loader: "jinja2.BaseLoader") -> "jinja2.Environment":
|
def get_environment(self, loader: BaseLoader) -> Environment:
|
||||||
@jinja2.pass_context
|
@pass_context
|
||||||
def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
|
def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
|
||||||
request: Request = context["request"]
|
request: Request = context["request"]
|
||||||
return request.app.url_path_for(name, **path_params)
|
return request.app.url_path_for(name, **path_params)
|
||||||
|
@ -29,7 +22,7 @@ class Jinja2Templates(templating.Jinja2Templates):
|
||||||
values.update(new)
|
values.update(new)
|
||||||
return QueryParams(**values)
|
return QueryParams(**values)
|
||||||
|
|
||||||
env = jinja2.Environment(loader=loader, autoescape=True)
|
env = Environment(loader=loader, autoescape=True)
|
||||||
env.globals["url_for"] = url_for
|
env.globals["url_for"] = url_for
|
||||||
env.globals["url_params_update"] = url_params_update
|
env.globals["url_params_update"] = url_params_update
|
||||||
return env
|
return env
|
||||||
|
|
|
@ -26,6 +26,7 @@ class InstalledExtensionMiddleware:
|
||||||
else:
|
else:
|
||||||
_, path_name = path_elements
|
_, path_name = path_elements
|
||||||
path_type = None
|
path_type = None
|
||||||
|
rest = []
|
||||||
|
|
||||||
# block path for all users if the extension is disabled
|
# block path for all users if the extension is disabled
|
||||||
if path_name in settings.lnbits_deactivated_extensions:
|
if path_name in settings.lnbits_deactivated_extensions:
|
||||||
|
@ -88,7 +89,7 @@ class ExtensionsRedirectMiddleware:
|
||||||
if "from_path" not in redirect:
|
if "from_path" not in redirect:
|
||||||
return False
|
return False
|
||||||
header_filters = (
|
header_filters = (
|
||||||
redirect["header_filters"] if "header_filters" in redirect else []
|
redirect["header_filters"] if "header_filters" in redirect else {}
|
||||||
)
|
)
|
||||||
return self._has_common_path(redirect["from_path"], path) and self._has_headers(
|
return self._has_common_path(redirect["from_path"], path) and self._has_headers(
|
||||||
header_filters, req_headers
|
header_filters, req_headers
|
||||||
|
|
|
@ -24,6 +24,7 @@ def list_parse_fallback(v):
|
||||||
|
|
||||||
|
|
||||||
class LNbitsSettings(BaseSettings):
|
class LNbitsSettings(BaseSettings):
|
||||||
|
@classmethod
|
||||||
def validate(cls, val):
|
def validate(cls, val):
|
||||||
if type(val) == str:
|
if type(val) == str:
|
||||||
val = val.split(",") if val else []
|
val = val.split(",") if val else []
|
||||||
|
@ -103,6 +104,8 @@ class FakeWalletFundingSource(LNbitsSettings):
|
||||||
class LNbitsFundingSource(LNbitsSettings):
|
class LNbitsFundingSource(LNbitsSettings):
|
||||||
lnbits_endpoint: str = Field(default="https://legend.lnbits.com")
|
lnbits_endpoint: str = Field(default="https://legend.lnbits.com")
|
||||||
lnbits_key: Optional[str] = Field(default=None)
|
lnbits_key: Optional[str] = Field(default=None)
|
||||||
|
lnbits_admin_key: Optional[str] = Field(default=None)
|
||||||
|
lnbits_invoice_key: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
class ClicheFundingSource(LNbitsSettings):
|
class ClicheFundingSource(LNbitsSettings):
|
||||||
|
@ -145,11 +148,14 @@ class LnPayFundingSource(LNbitsSettings):
|
||||||
lnpay_api_endpoint: Optional[str] = Field(default=None)
|
lnpay_api_endpoint: Optional[str] = Field(default=None)
|
||||||
lnpay_api_key: Optional[str] = Field(default=None)
|
lnpay_api_key: Optional[str] = Field(default=None)
|
||||||
lnpay_wallet_key: Optional[str] = Field(default=None)
|
lnpay_wallet_key: Optional[str] = Field(default=None)
|
||||||
|
lnpay_admin_key: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
class OpenNodeFundingSource(LNbitsSettings):
|
class OpenNodeFundingSource(LNbitsSettings):
|
||||||
opennode_api_endpoint: Optional[str] = Field(default=None)
|
opennode_api_endpoint: Optional[str] = Field(default=None)
|
||||||
opennode_key: Optional[str] = Field(default=None)
|
opennode_key: Optional[str] = Field(default=None)
|
||||||
|
opennode_admin_key: Optional[str] = Field(default=None)
|
||||||
|
opennode_invoice_key: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
class SparkFundingSource(LNbitsSettings):
|
class SparkFundingSource(LNbitsSettings):
|
||||||
|
@ -208,8 +214,9 @@ class EditableSettings(
|
||||||
"lnbits_admin_extensions",
|
"lnbits_admin_extensions",
|
||||||
pre=True,
|
pre=True,
|
||||||
)
|
)
|
||||||
|
@classmethod
|
||||||
def validate_editable_settings(cls, val):
|
def validate_editable_settings(cls, val):
|
||||||
return super().validate(cls, val)
|
return super().validate(val)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, d: dict):
|
def from_dict(cls, d: dict):
|
||||||
|
@ -281,8 +288,9 @@ class ReadOnlySettings(
|
||||||
"lnbits_allowed_funding_sources",
|
"lnbits_allowed_funding_sources",
|
||||||
pre=True,
|
pre=True,
|
||||||
)
|
)
|
||||||
|
@classmethod
|
||||||
def validate_readonly_settings(cls, val):
|
def validate_readonly_settings(cls, val):
|
||||||
return super().validate(cls, val)
|
return super().validate(val)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def readonly_fields(cls):
|
def readonly_fields(cls):
|
||||||
|
|
|
@ -3,7 +3,7 @@ import time
|
||||||
import traceback
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Dict
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from fastapi.exceptions import HTTPException
|
from fastapi.exceptions import HTTPException
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
@ -42,7 +42,7 @@ class SseListenersDict(dict):
|
||||||
A dict of sse listeners.
|
A dict of sse listeners.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name: str = None):
|
def __init__(self, name: Optional[str] = None):
|
||||||
self.name = name or f"sse_listener_{str(uuid.uuid4())[:8]}"
|
self.name = name or f"sse_listener_{str(uuid.uuid4())[:8]}"
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
|
@ -65,7 +65,7 @@ class SseListenersDict(dict):
|
||||||
invoice_listeners: Dict[str, asyncio.Queue] = SseListenersDict("invoice_listeners")
|
invoice_listeners: Dict[str, asyncio.Queue] = SseListenersDict("invoice_listeners")
|
||||||
|
|
||||||
|
|
||||||
def register_invoice_listener(send_chan: asyncio.Queue, name: str = None):
|
def register_invoice_listener(send_chan: asyncio.Queue, name: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
A method intended for extensions (and core/tasks.py) to call when they want to be notified about
|
A method intended for extensions (and core/tasks.py) to call when they want to be notified about
|
||||||
new invoice payments incoming. Will emit all incoming payments.
|
new invoice payments incoming. Will emit all incoming payments.
|
||||||
|
@ -164,7 +164,7 @@ async def check_pending_payments():
|
||||||
async def perform_balance_checks():
|
async def perform_balance_checks():
|
||||||
while True:
|
while True:
|
||||||
for bc in await get_balance_checks():
|
for bc in await get_balance_checks():
|
||||||
redeem_lnurl_withdraw(bc.wallet, bc.url)
|
await redeem_lnurl_withdraw(bc.wallet, bc.url)
|
||||||
|
|
||||||
await asyncio.sleep(60 * 60 * 6) # every 6 hours
|
await asyncio.sleep(60 * 60 * 6) # every 6 hours
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
# flake8: noqa: F401
|
# flake8: noqa: F401
|
||||||
|
|
||||||
|
|
||||||
from .cliche import ClicheWallet
|
from .cliche import ClicheWallet
|
||||||
from .cln import CoreLightningWallet # legacy .env support
|
from .cln import CoreLightningWallet
|
||||||
from .cln import CoreLightningWallet as CLightningWallet
|
from .cln import CoreLightningWallet as CLightningWallet
|
||||||
from .eclair import EclairWallet
|
from .eclair import EclairWallet
|
||||||
from .fake import FakeWallet
|
from .fake import FakeWallet
|
||||||
|
|
|
@ -22,6 +22,8 @@ class ClicheWallet(Wallet):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.endpoint = settings.cliche_endpoint
|
self.endpoint = settings.cliche_endpoint
|
||||||
|
if not self.endpoint:
|
||||||
|
raise Exception("cannot initialize cliche")
|
||||||
|
|
||||||
async def status(self) -> StatusResponse:
|
async def status(self) -> StatusResponse:
|
||||||
try:
|
try:
|
||||||
|
@ -36,7 +38,7 @@ class ClicheWallet(Wallet):
|
||||||
data = json.loads(r)
|
data = json.loads(r)
|
||||||
except:
|
except:
|
||||||
return StatusResponse(
|
return StatusResponse(
|
||||||
f"Failed to connect to {self.endpoint}, got: '{r.text[:200]}...'", 0
|
f"Failed to connect to {self.endpoint}, got: '{r[:200]}...'", 0
|
||||||
)
|
)
|
||||||
|
|
||||||
return StatusResponse(None, data["result"]["wallets"][0]["balance"])
|
return StatusResponse(None, data["result"]["wallets"][0]["balance"])
|
||||||
|
@ -89,6 +91,13 @@ class ClicheWallet(Wallet):
|
||||||
async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse:
|
async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse:
|
||||||
ws = create_connection(self.endpoint)
|
ws = create_connection(self.endpoint)
|
||||||
ws.send(f"pay-invoice --invoice {bolt11}")
|
ws.send(f"pay-invoice --invoice {bolt11}")
|
||||||
|
checking_id, fee_msat, preimage, error_message, payment_ok = (
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
r = ws.recv()
|
r = ws.recv()
|
||||||
data = json.loads(r)
|
data = json.loads(r)
|
||||||
|
@ -151,9 +160,9 @@ class ClicheWallet(Wallet):
|
||||||
async def paid_invoices_stream(self) -> AsyncGenerator[str, None]:
|
async def paid_invoices_stream(self) -> AsyncGenerator[str, None]:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
ws = await create_connection(self.endpoint)
|
ws = create_connection(self.endpoint)
|
||||||
while True:
|
while True:
|
||||||
r = await ws.recv()
|
r = ws.recv()
|
||||||
data = json.loads(r)
|
data = json.loads(r)
|
||||||
print(data)
|
print(data)
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -7,10 +7,7 @@ from typing import AsyncGenerator, Dict, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from websockets.client import connect
|
||||||
# TODO: https://github.com/lnbits/lnbits/issues/764
|
|
||||||
# mypy https://github.com/aaugustin/websockets/issues/940
|
|
||||||
from websockets import connect # type: ignore
|
|
||||||
|
|
||||||
from lnbits.settings import settings
|
from lnbits.settings import settings
|
||||||
|
|
||||||
|
@ -34,11 +31,13 @@ class UnknownError(Exception):
|
||||||
class EclairWallet(Wallet):
|
class EclairWallet(Wallet):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
url = settings.eclair_url
|
url = settings.eclair_url
|
||||||
self.url = url[:-1] if url.endswith("/") else url
|
passw = settings.eclair_pass
|
||||||
|
if not url or not passw:
|
||||||
|
raise Exception("cannot initialize eclair")
|
||||||
|
|
||||||
|
self.url = url[:-1] if url.endswith("/") else url
|
||||||
self.ws_url = f"ws://{urllib.parse.urlsplit(self.url).netloc}/ws"
|
self.ws_url = f"ws://{urllib.parse.urlsplit(self.url).netloc}/ws"
|
||||||
|
|
||||||
passw = settings.eclair_pass
|
|
||||||
encodedAuth = base64.b64encode(f":{passw}".encode())
|
encodedAuth = base64.b64encode(f":{passw}".encode())
|
||||||
auth = str(encodedAuth, "utf-8")
|
auth = str(encodedAuth, "utf-8")
|
||||||
self.auth = {"Authorization": f"Basic {auth}"}
|
self.auth = {"Authorization": f"Basic {auth}"}
|
||||||
|
@ -71,7 +70,11 @@ class EclairWallet(Wallet):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> InvoiceResponse:
|
) -> InvoiceResponse:
|
||||||
|
|
||||||
data: Dict = {"amountMsat": amount * 1000}
|
data: Dict = {
|
||||||
|
"amountMsat": amount * 1000,
|
||||||
|
"description_hash": b"",
|
||||||
|
"description": memo,
|
||||||
|
}
|
||||||
if kwargs.get("expiry"):
|
if kwargs.get("expiry"):
|
||||||
data["expireIn"] = kwargs["expiry"]
|
data["expireIn"] = kwargs["expiry"]
|
||||||
|
|
||||||
|
@ -79,8 +82,6 @@ class EclairWallet(Wallet):
|
||||||
data["descriptionHash"] = description_hash.hex()
|
data["descriptionHash"] = description_hash.hex()
|
||||||
elif unhashed_description:
|
elif unhashed_description:
|
||||||
data["descriptionHash"] = hashlib.sha256(unhashed_description).hexdigest()
|
data["descriptionHash"] = hashlib.sha256(unhashed_description).hexdigest()
|
||||||
else:
|
|
||||||
data["description"] = memo or ""
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
r = await client.post(
|
r = await client.post(
|
||||||
|
@ -149,6 +150,7 @@ class EclairWallet(Wallet):
|
||||||
}
|
}
|
||||||
|
|
||||||
data = r.json()[-1]
|
data = r.json()[-1]
|
||||||
|
fee_msat = 0
|
||||||
if data["status"]["type"] == "sent":
|
if data["status"]["type"] == "sent":
|
||||||
fee_msat = -data["status"]["feesPaid"]
|
fee_msat = -data["status"]["feesPaid"]
|
||||||
preimage = data["status"]["paymentPreimage"]
|
preimage = data["status"]["paymentPreimage"]
|
||||||
|
@ -223,10 +225,10 @@ class EclairWallet(Wallet):
|
||||||
) as ws:
|
) as ws:
|
||||||
while True:
|
while True:
|
||||||
message = await ws.recv()
|
message = await ws.recv()
|
||||||
message = json.loads(message)
|
message_json = json.loads(message)
|
||||||
|
|
||||||
if message and message["type"] == "payment-received":
|
if message_json and message_json["type"] == "payment-received":
|
||||||
yield message["paymentHash"]
|
yield message_json["paymentHash"]
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
|
@ -48,16 +48,15 @@ class FakeWallet(Wallet):
|
||||||
"amount": amount,
|
"amount": amount,
|
||||||
"currency": "bc",
|
"currency": "bc",
|
||||||
"privkey": self.privkey,
|
"privkey": self.privkey,
|
||||||
"memo": None,
|
"memo": memo,
|
||||||
"description_hash": None,
|
"description_hash": b"",
|
||||||
"description": "",
|
"description": "",
|
||||||
"fallback": None,
|
"fallback": None,
|
||||||
"expires": None,
|
"expires": kwargs.get("expiry"),
|
||||||
|
"timestamp": datetime.now().timestamp(),
|
||||||
"route": None,
|
"route": None,
|
||||||
|
"tags_set": [],
|
||||||
}
|
}
|
||||||
data["expires"] = kwargs.get("expiry")
|
|
||||||
data["amount"] = amount * 1000
|
|
||||||
data["timestamp"] = datetime.now().timestamp()
|
|
||||||
if description_hash:
|
if description_hash:
|
||||||
data["tags_set"] = ["h"]
|
data["tags_set"] = ["h"]
|
||||||
data["description_hash"] = description_hash
|
data["description_hash"] = description_hash
|
||||||
|
@ -69,7 +68,7 @@ class FakeWallet(Wallet):
|
||||||
data["memo"] = memo
|
data["memo"] = memo
|
||||||
data["description"] = memo
|
data["description"] = memo
|
||||||
randomHash = (
|
randomHash = (
|
||||||
data["privkey"][:6]
|
self.privkey[:6]
|
||||||
+ hashlib.sha256(str(random.getrandbits(256)).encode()).hexdigest()[6:]
|
+ hashlib.sha256(str(random.getrandbits(256)).encode()).hexdigest()[6:]
|
||||||
)
|
)
|
||||||
data["paymenthash"] = randomHash
|
data["paymenthash"] = randomHash
|
||||||
|
@ -78,12 +77,10 @@ class FakeWallet(Wallet):
|
||||||
|
|
||||||
return InvoiceResponse(True, checking_id, payment_request)
|
return InvoiceResponse(True, checking_id, payment_request)
|
||||||
|
|
||||||
async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse:
|
async def pay_invoice(self, bolt11: str, _: int) -> PaymentResponse:
|
||||||
invoice = decode(bolt11)
|
invoice = decode(bolt11)
|
||||||
if (
|
|
||||||
hasattr(invoice, "checking_id")
|
if invoice.payment_hash[:6] == self.privkey[:6]:
|
||||||
and invoice.checking_id[:6] == self.privkey[:6] # type: ignore
|
|
||||||
):
|
|
||||||
await self.queue.put(invoice)
|
await self.queue.put(invoice)
|
||||||
return PaymentResponse(True, invoice.payment_hash, 0)
|
return PaymentResponse(True, invoice.payment_hash, 0)
|
||||||
else:
|
else:
|
||||||
|
@ -91,10 +88,10 @@ class FakeWallet(Wallet):
|
||||||
ok=False, error_message="Only internal invoices can be used!"
|
ok=False, error_message="Only internal invoices can be used!"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_invoice_status(self, checking_id: str) -> PaymentStatus:
|
async def get_invoice_status(self, _: str) -> PaymentStatus:
|
||||||
return PaymentStatus(None)
|
return PaymentStatus(None)
|
||||||
|
|
||||||
async def get_payment_status(self, checking_id: str) -> PaymentStatus:
|
async def get_payment_status(self, _: str) -> PaymentStatus:
|
||||||
return PaymentStatus(None)
|
return PaymentStatus(None)
|
||||||
|
|
||||||
async def paid_invoices_stream(self) -> AsyncGenerator[str, None]:
|
async def paid_invoices_stream(self) -> AsyncGenerator[str, None]:
|
||||||
|
|
|
@ -21,12 +21,13 @@ class LNbitsWallet(Wallet):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.endpoint = settings.lnbits_endpoint
|
self.endpoint = settings.lnbits_endpoint
|
||||||
|
|
||||||
key = (
|
key = (
|
||||||
settings.lnbits_key
|
settings.lnbits_key
|
||||||
or settings.lnbits_admin_key
|
or settings.lnbits_admin_key
|
||||||
or settings.lnbits_invoice_key
|
or settings.lnbits_invoice_key
|
||||||
)
|
)
|
||||||
|
if not self.endpoint or not key:
|
||||||
|
raise Exception("cannot initialize lnbits wallet")
|
||||||
self.key = {"X-Api-Key": key}
|
self.key = {"X-Api-Key": key}
|
||||||
|
|
||||||
async def status(self) -> StatusResponse:
|
async def status(self) -> StatusResponse:
|
||||||
|
@ -60,7 +61,7 @@ class LNbitsWallet(Wallet):
|
||||||
unhashed_description: Optional[bytes] = None,
|
unhashed_description: Optional[bytes] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> InvoiceResponse:
|
) -> InvoiceResponse:
|
||||||
data: Dict = {"out": False, "amount": amount}
|
data: Dict = {"out": False, "amount": amount, "memo": memo or ""}
|
||||||
if kwargs.get("expiry"):
|
if kwargs.get("expiry"):
|
||||||
data["expiry"] = kwargs["expiry"]
|
data["expiry"] = kwargs["expiry"]
|
||||||
if description_hash:
|
if description_hash:
|
||||||
|
@ -68,8 +69,6 @@ class LNbitsWallet(Wallet):
|
||||||
if unhashed_description:
|
if unhashed_description:
|
||||||
data["unhashed_description"] = unhashed_description.hex()
|
data["unhashed_description"] = unhashed_description.hex()
|
||||||
|
|
||||||
data["memo"] = memo or ""
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
r = await client.post(
|
r = await client.post(
|
||||||
url=f"{self.endpoint}/api/v1/payments", headers=self.key, json=data
|
url=f"{self.endpoint}/api/v1/payments", headers=self.key, json=data
|
||||||
|
|
|
@ -105,9 +105,6 @@ class LndWallet(Wallet):
|
||||||
)
|
)
|
||||||
|
|
||||||
endpoint = settings.lnd_grpc_endpoint
|
endpoint = settings.lnd_grpc_endpoint
|
||||||
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
|
||||||
self.port = int(settings.lnd_grpc_port)
|
|
||||||
self.cert_path = settings.lnd_grpc_cert or settings.lnd_cert
|
|
||||||
|
|
||||||
macaroon = (
|
macaroon = (
|
||||||
settings.lnd_grpc_macaroon
|
settings.lnd_grpc_macaroon
|
||||||
|
@ -122,8 +119,17 @@ class LndWallet(Wallet):
|
||||||
macaroon = AESCipher(description="macaroon decryption").decrypt(
|
macaroon = AESCipher(description="macaroon decryption").decrypt(
|
||||||
encrypted_macaroon
|
encrypted_macaroon
|
||||||
)
|
)
|
||||||
self.macaroon = load_macaroon(macaroon)
|
|
||||||
|
|
||||||
|
cert_path = settings.lnd_grpc_cert or settings.lnd_cert
|
||||||
|
if not endpoint or not macaroon or not cert_path or not settings.lnd_grpc_port:
|
||||||
|
raise Exception("cannot initialize lndrest")
|
||||||
|
|
||||||
|
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
||||||
|
self.port = int(settings.lnd_grpc_port)
|
||||||
|
self.cert_path = settings.lnd_grpc_cert or settings.lnd_cert
|
||||||
|
|
||||||
|
self.macaroon = load_macaroon(macaroon)
|
||||||
|
self.cert_path = cert_path
|
||||||
cert = open(self.cert_path, "rb").read()
|
cert = open(self.cert_path, "rb").read()
|
||||||
creds = grpc.ssl_channel_credentials(cert)
|
creds = grpc.ssl_channel_credentials(cert)
|
||||||
auth_creds = grpc.metadata_call_credentials(self.metadata_callback)
|
auth_creds = grpc.metadata_call_credentials(self.metadata_callback)
|
||||||
|
@ -140,8 +146,6 @@ class LndWallet(Wallet):
|
||||||
async def status(self) -> StatusResponse:
|
async def status(self) -> StatusResponse:
|
||||||
try:
|
try:
|
||||||
resp = await self.rpc.ChannelBalance(ln.ChannelBalanceRequest())
|
resp = await self.rpc.ChannelBalance(ln.ChannelBalanceRequest())
|
||||||
except RpcError as exc:
|
|
||||||
return StatusResponse(str(exc._details), 0)
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
return StatusResponse(str(exc), 0)
|
return StatusResponse(str(exc), 0)
|
||||||
|
|
||||||
|
@ -155,20 +159,23 @@ class LndWallet(Wallet):
|
||||||
unhashed_description: Optional[bytes] = None,
|
unhashed_description: Optional[bytes] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> InvoiceResponse:
|
) -> InvoiceResponse:
|
||||||
params: Dict = {"value": amount, "private": True}
|
data: Dict = {
|
||||||
|
"description_hash": b"",
|
||||||
|
"value": amount,
|
||||||
|
"private": True,
|
||||||
|
"memo": memo or "",
|
||||||
|
}
|
||||||
if kwargs.get("expiry"):
|
if kwargs.get("expiry"):
|
||||||
params["expiry"] = kwargs["expiry"]
|
data["expiry"] = kwargs["expiry"]
|
||||||
if description_hash:
|
if description_hash:
|
||||||
params["description_hash"] = description_hash
|
data["description_hash"] = description_hash
|
||||||
elif unhashed_description:
|
elif unhashed_description:
|
||||||
params["description_hash"] = hashlib.sha256(
|
data["description_hash"] = hashlib.sha256(
|
||||||
unhashed_description
|
unhashed_description
|
||||||
).digest() # as bytes directly
|
).digest() # as bytes directly
|
||||||
else:
|
|
||||||
params["memo"] = memo or ""
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
req = ln.Invoice(**params)
|
req = ln.Invoice(**data)
|
||||||
resp = await self.rpc.AddInvoice(req)
|
resp = await self.rpc.AddInvoice(req)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
error_message = str(exc)
|
error_message = str(exc)
|
||||||
|
@ -188,8 +195,6 @@ class LndWallet(Wallet):
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
resp = await self.routerpc.SendPaymentV2(req).read()
|
resp = await self.routerpc.SendPaymentV2(req).read()
|
||||||
except RpcError as exc:
|
|
||||||
return PaymentResponse(False, None, None, None, exc._details)
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
return PaymentResponse(False, None, None, None, str(exc))
|
return PaymentResponse(False, None, None, None, str(exc))
|
||||||
|
|
||||||
|
|
|
@ -24,11 +24,6 @@ class LndRestWallet(Wallet):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
endpoint = settings.lnd_rest_endpoint
|
endpoint = settings.lnd_rest_endpoint
|
||||||
endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
|
||||||
endpoint = (
|
|
||||||
f"https://{endpoint}" if not endpoint.startswith("http") else endpoint
|
|
||||||
)
|
|
||||||
self.endpoint = endpoint
|
|
||||||
|
|
||||||
macaroon = (
|
macaroon = (
|
||||||
settings.lnd_rest_macaroon
|
settings.lnd_rest_macaroon
|
||||||
|
@ -43,6 +38,15 @@ class LndRestWallet(Wallet):
|
||||||
macaroon = AESCipher(description="macaroon decryption").decrypt(
|
macaroon = AESCipher(description="macaroon decryption").decrypt(
|
||||||
encrypted_macaroon
|
encrypted_macaroon
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not endpoint or not macaroon or not settings.lnd_rest_cert:
|
||||||
|
raise Exception("cannot initialize lndrest")
|
||||||
|
|
||||||
|
endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
||||||
|
endpoint = (
|
||||||
|
f"https://{endpoint}" if not endpoint.startswith("http") else endpoint
|
||||||
|
)
|
||||||
|
self.endpoint = endpoint
|
||||||
self.macaroon = load_macaroon(macaroon)
|
self.macaroon = load_macaroon(macaroon)
|
||||||
|
|
||||||
self.auth = {"Grpc-Metadata-macaroon": self.macaroon}
|
self.auth = {"Grpc-Metadata-macaroon": self.macaroon}
|
||||||
|
@ -74,7 +78,7 @@ class LndRestWallet(Wallet):
|
||||||
unhashed_description: Optional[bytes] = None,
|
unhashed_description: Optional[bytes] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> InvoiceResponse:
|
) -> InvoiceResponse:
|
||||||
data: Dict = {"value": amount, "private": True}
|
data: Dict = {"value": amount, "private": True, "memo": memo or ""}
|
||||||
if kwargs.get("expiry"):
|
if kwargs.get("expiry"):
|
||||||
data["expiry"] = kwargs["expiry"]
|
data["expiry"] = kwargs["expiry"]
|
||||||
if description_hash:
|
if description_hash:
|
||||||
|
@ -85,8 +89,6 @@ class LndRestWallet(Wallet):
|
||||||
data["description_hash"] = base64.b64encode(
|
data["description_hash"] = base64.b64encode(
|
||||||
hashlib.sha256(unhashed_description).digest()
|
hashlib.sha256(unhashed_description).digest()
|
||||||
).decode("ascii")
|
).decode("ascii")
|
||||||
else:
|
|
||||||
data["memo"] = memo or ""
|
|
||||||
|
|
||||||
async with httpx.AsyncClient(verify=self.cert) as client:
|
async with httpx.AsyncClient(verify=self.cert) as client:
|
||||||
r = await client.post(
|
r = await client.post(
|
||||||
|
|
|
@ -5,7 +5,7 @@ from http import HTTPStatus
|
||||||
from typing import AsyncGenerator, Dict, Optional
|
from typing import AsyncGenerator, Dict, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi.exceptions import HTTPException
|
from fastapi import HTTPException
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lnbits.settings import settings
|
from lnbits.settings import settings
|
||||||
|
@ -24,8 +24,13 @@ class LNPayWallet(Wallet):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
endpoint = settings.lnpay_api_endpoint
|
endpoint = settings.lnpay_api_endpoint
|
||||||
|
wallet_key = settings.lnpay_wallet_key or settings.lnpay_admin_key
|
||||||
|
|
||||||
|
if not endpoint or not wallet_key or not settings.lnpay_api_key:
|
||||||
|
raise Exception("cannot initialize lnpay")
|
||||||
|
|
||||||
|
self.wallet_key = wallet_key
|
||||||
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
||||||
self.wallet_key = settings.lnpay_wallet_key or settings.lnpay_admin_key
|
|
||||||
self.auth = {"X-Api-Key": settings.lnpay_api_key}
|
self.auth = {"X-Api-Key": settings.lnpay_api_key}
|
||||||
|
|
||||||
async def status(self) -> StatusResponse:
|
async def status(self) -> StatusResponse:
|
||||||
|
@ -134,7 +139,9 @@ class LNPayWallet(Wallet):
|
||||||
yield value
|
yield value
|
||||||
|
|
||||||
async def webhook_listener(self):
|
async def webhook_listener(self):
|
||||||
text: str = await request.get_data()
|
# TODO: request.get_data is undefined, was it something with Flask or quart?
|
||||||
|
# probably issue introduced when refactoring?
|
||||||
|
text: str = await request.get_data() # type: ignore
|
||||||
try:
|
try:
|
||||||
data = json.loads(text)
|
data = json.loads(text)
|
||||||
except json.decoder.JSONDecodeError:
|
except json.decoder.JSONDecodeError:
|
||||||
|
|
|
@ -21,13 +21,14 @@ from .base import (
|
||||||
class LnTipsWallet(Wallet):
|
class LnTipsWallet(Wallet):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
endpoint = settings.lntips_api_endpoint
|
endpoint = settings.lntips_api_endpoint
|
||||||
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
|
||||||
|
|
||||||
key = (
|
key = (
|
||||||
settings.lntips_api_key
|
settings.lntips_api_key
|
||||||
or settings.lntips_admin_key
|
or settings.lntips_admin_key
|
||||||
or settings.lntips_invoice_key
|
or settings.lntips_invoice_key
|
||||||
)
|
)
|
||||||
|
if not endpoint or not key:
|
||||||
|
raise Exception("cannot initialize lntxbod")
|
||||||
|
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
||||||
self.auth = {"Authorization": f"Basic {key}"}
|
self.auth = {"Authorization": f"Basic {key}"}
|
||||||
|
|
||||||
async def status(self) -> StatusResponse:
|
async def status(self) -> StatusResponse:
|
||||||
|
@ -55,13 +56,11 @@ class LnTipsWallet(Wallet):
|
||||||
unhashed_description: Optional[bytes] = None,
|
unhashed_description: Optional[bytes] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> InvoiceResponse:
|
) -> InvoiceResponse:
|
||||||
data: Dict = {"amount": amount}
|
data: Dict = {"amount": amount, "description_hash": "", "memo": memo or ""}
|
||||||
if description_hash:
|
if description_hash:
|
||||||
data["description_hash"] = description_hash.hex()
|
data["description_hash"] = description_hash.hex()
|
||||||
elif unhashed_description:
|
elif unhashed_description:
|
||||||
data["description_hash"] = hashlib.sha256(unhashed_description).hexdigest()
|
data["description_hash"] = hashlib.sha256(unhashed_description).hexdigest()
|
||||||
else:
|
|
||||||
data["memo"] = memo or ""
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
r = await client.post(
|
r = await client.post(
|
||||||
|
|
|
@ -4,7 +4,7 @@ from http import HTTPStatus
|
||||||
from typing import AsyncGenerator, Optional
|
from typing import AsyncGenerator, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi.exceptions import HTTPException
|
from fastapi import HTTPException
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lnbits.settings import settings
|
from lnbits.settings import settings
|
||||||
|
@ -24,13 +24,15 @@ class OpenNodeWallet(Wallet):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
endpoint = settings.opennode_api_endpoint
|
endpoint = settings.opennode_api_endpoint
|
||||||
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
|
||||||
|
|
||||||
key = (
|
key = (
|
||||||
settings.opennode_key
|
settings.opennode_key
|
||||||
or settings.opennode_admin_key
|
or settings.opennode_admin_key
|
||||||
or settings.opennode_invoice_key
|
or settings.opennode_invoice_key
|
||||||
)
|
)
|
||||||
|
if not endpoint or not key:
|
||||||
|
raise Exception("cannot initialize opennode")
|
||||||
|
|
||||||
|
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
|
||||||
self.auth = {"Authorization": key}
|
self.auth = {"Authorization": key}
|
||||||
|
|
||||||
async def status(self) -> StatusResponse:
|
async def status(self) -> StatusResponse:
|
||||||
|
@ -140,7 +142,9 @@ class OpenNodeWallet(Wallet):
|
||||||
yield value
|
yield value
|
||||||
|
|
||||||
async def webhook_listener(self):
|
async def webhook_listener(self):
|
||||||
data = await request.form
|
# TODO: request.form is undefined, was it something with Flask or quart?
|
||||||
|
# probably issue introduced when refactoring?
|
||||||
|
data = await request.form # type: ignore
|
||||||
if "status" not in data or data["status"] != "paid":
|
if "status" not in data or data["status"] != "paid":
|
||||||
raise HTTPException(status_code=HTTPStatus.NO_CONTENT)
|
raise HTTPException(status_code=HTTPStatus.NO_CONTENT)
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ class UnknownError(Exception):
|
||||||
|
|
||||||
class SparkWallet(Wallet):
|
class SparkWallet(Wallet):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
assert settings.spark_url, "spark url does not exist"
|
||||||
self.url = settings.spark_url.replace("/rpc", "")
|
self.url = settings.spark_url.replace("/rpc", "")
|
||||||
self.token = settings.spark_token
|
self.token = settings.spark_token
|
||||||
|
|
||||||
|
@ -46,6 +47,7 @@ class SparkWallet(Wallet):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
|
assert self.token, "spark wallet token does not exist"
|
||||||
r = await client.post(
|
r = await client.post(
|
||||||
self.url + "/rpc",
|
self.url + "/rpc",
|
||||||
headers={"X-Access": self.token},
|
headers={"X-Access": self.token},
|
||||||
|
@ -133,38 +135,49 @@ class SparkWallet(Wallet):
|
||||||
bolt11=bolt11,
|
bolt11=bolt11,
|
||||||
maxfee=fee_limit_msat,
|
maxfee=fee_limit_msat,
|
||||||
)
|
)
|
||||||
|
fee_msat = -int(r["msatoshi_sent"] - r["msatoshi"])
|
||||||
|
preimage = r["payment_preimage"]
|
||||||
|
return PaymentResponse(True, r["payment_hash"], fee_msat, preimage, None)
|
||||||
|
|
||||||
except (SparkError, UnknownError) as exc:
|
except (SparkError, UnknownError) as exc:
|
||||||
listpays = await self.listpays(bolt11)
|
listpays = await self.listpays(bolt11)
|
||||||
if listpays:
|
if not listpays:
|
||||||
pays = listpays["pays"]
|
return PaymentResponse(False, None, None, None, str(exc))
|
||||||
|
|
||||||
if len(pays) == 0:
|
pays = listpays["pays"]
|
||||||
return PaymentResponse(False, None, None, None, str(exc))
|
|
||||||
|
|
||||||
pay = pays[0]
|
if len(pays) == 0:
|
||||||
payment_hash = pay["payment_hash"]
|
return PaymentResponse(False, None, None, None, str(exc))
|
||||||
|
|
||||||
if len(pays) > 1:
|
pay = pays[0]
|
||||||
raise SparkError(
|
payment_hash = pay["payment_hash"]
|
||||||
f"listpays({payment_hash}) returned an unexpected response: {listpays}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if pay["status"] == "failed":
|
if len(pays) > 1:
|
||||||
return PaymentResponse(False, None, None, None, str(exc))
|
raise SparkError(
|
||||||
elif pay["status"] == "pending":
|
f"listpays({payment_hash}) returned an unexpected response: {listpays}"
|
||||||
return PaymentResponse(None, payment_hash, None, None, None)
|
)
|
||||||
elif pay["status"] == "complete":
|
|
||||||
r = pay
|
|
||||||
r["payment_preimage"] = pay["preimage"]
|
|
||||||
r["msatoshi"] = int(pay["amount_msat"][0:-4])
|
|
||||||
r["msatoshi_sent"] = int(pay["amount_sent_msat"][0:-4])
|
|
||||||
# this may result in an error if it was paid previously
|
|
||||||
# our database won't allow the same payment_hash to be added twice
|
|
||||||
# this is good
|
|
||||||
|
|
||||||
fee_msat = -int(r["msatoshi_sent"] - r["msatoshi"])
|
if pay["status"] == "failed":
|
||||||
preimage = r["payment_preimage"]
|
return PaymentResponse(False, None, None, None, str(exc))
|
||||||
return PaymentResponse(True, r["payment_hash"], fee_msat, preimage, None)
|
|
||||||
|
if pay["status"] == "pending":
|
||||||
|
return PaymentResponse(None, payment_hash, None, None, None)
|
||||||
|
|
||||||
|
if pay["status"] == "complete":
|
||||||
|
r = pay
|
||||||
|
r["payment_preimage"] = pay["preimage"]
|
||||||
|
r["msatoshi"] = int(pay["amount_msat"][0:-4])
|
||||||
|
r["msatoshi_sent"] = int(pay["amount_sent_msat"][0:-4])
|
||||||
|
# this may result in an error if it was paid previously
|
||||||
|
# our database won't allow the same payment_hash to be added twice
|
||||||
|
# this is good
|
||||||
|
fee_msat = -int(r["msatoshi_sent"] - r["msatoshi"])
|
||||||
|
preimage = r["payment_preimage"]
|
||||||
|
return PaymentResponse(
|
||||||
|
True, r["payment_hash"], fee_msat, preimage, None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return PaymentResponse(False, None, None, None, str(exc))
|
||||||
|
|
||||||
async def get_invoice_status(self, checking_id: str) -> PaymentStatus:
|
async def get_invoice_status(self, checking_id: str) -> PaymentStatus:
|
||||||
try:
|
try:
|
||||||
|
@ -205,7 +218,7 @@ class SparkWallet(Wallet):
|
||||||
- int(r["pays"][0]["amount_msat"][0:-4])
|
- int(r["pays"][0]["amount_msat"][0:-4])
|
||||||
)
|
)
|
||||||
return PaymentStatus(True, fee_msat, r["pays"][0]["preimage"])
|
return PaymentStatus(True, fee_msat, r["pays"][0]["preimage"])
|
||||||
elif status == "failed":
|
if status == "failed":
|
||||||
return PaymentStatus(False)
|
return PaymentStatus(False)
|
||||||
return PaymentStatus(None)
|
return PaymentStatus(None)
|
||||||
raise KeyError("supplied an invalid checking_id")
|
raise KeyError("supplied an invalid checking_id")
|
||||||
|
|
|
@ -69,9 +69,6 @@ include = [
|
||||||
]
|
]
|
||||||
exclude = [
|
exclude = [
|
||||||
"lnbits/wallets/lnd_grpc_files",
|
"lnbits/wallets/lnd_grpc_files",
|
||||||
"lnbits/wallets",
|
|
||||||
"lnbits/core",
|
|
||||||
"lnbits/*.py",
|
|
||||||
"lnbits/extensions",
|
"lnbits/extensions",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user