Merge pull request #1290 from lnbits/fix/mypy-withdraw

fix mypy issue on withdraw + some refactoring
This commit is contained in:
Arc 2023-01-04 19:31:27 +00:00 committed by GitHub
commit 76eb5a9309
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 117 additions and 150 deletions

View File

@ -1,6 +1,8 @@
from datetime import datetime
from typing import List, Optional, Union
import shortuuid
from lnbits.helpers import urlsafe_short_hash
from . import db
@ -8,9 +10,10 @@ from .models import CreateWithdrawData, HashCheck, WithdrawLink
async def create_withdraw_link(
data: CreateWithdrawData, wallet_id: str, usescsv: str
data: CreateWithdrawData, wallet_id: str
) -> WithdrawLink:
link_id = urlsafe_short_hash()
available_links = ",".join([str(i) for i in range(data.uses)])
await db.execute(
"""
INSERT INTO withdraw.withdraw_link (
@ -45,7 +48,7 @@ async def create_withdraw_link(
urlsafe_short_hash(),
urlsafe_short_hash(),
int(datetime.now().timestamp()) + data.wait_time,
usescsv,
available_links,
data.webhook_url,
data.webhook_headers,
data.webhook_body,
@ -94,6 +97,26 @@ async def get_withdraw_links(wallet_ids: Union[str, List[str]]) -> List[Withdraw
return [WithdrawLink(**row) for row in rows]
async def remove_unique_withdraw_link(link: WithdrawLink, unique_hash: str) -> None:
unique_links = [
x.strip()
for x in link.usescsv.split(",")
if unique_hash != shortuuid.uuid(name=link.id + link.unique_hash + x.strip())
]
await update_withdraw_link(
link.id,
usescsv=",".join(unique_links),
)
async def increment_withdraw_link(link: WithdrawLink) -> None:
await update_withdraw_link(
link.id,
used=link.used + 1,
open_time=link.wait_time + int(datetime.now().timestamp()),
)
async def update_withdraw_link(link_id: str, **kwargs) -> Optional[WithdrawLink]:
if "is_unique" in kwargs:
kwargs["is_unique"] = int(kwargs["is_unique"])
@ -132,7 +155,7 @@ async def create_hash_check(the_hash: str, lnurl_id: str) -> HashCheck:
return hashCheck
async def get_hash_check(the_hash: str, lnurl_id: str) -> Optional[HashCheck]:
async def get_hash_check(the_hash: str, lnurl_id: str) -> HashCheck:
rowid = await db.fetchone(
"SELECT * FROM withdraw.hash_check WHERE id = ?", (the_hash,)
)
@ -141,10 +164,10 @@ async def get_hash_check(the_hash: str, lnurl_id: str) -> Optional[HashCheck]:
)
if not rowlnurl:
await create_hash_check(the_hash, lnurl_id)
return {"lnurl": True, "hash": False}
return HashCheck(lnurl=True, hash=False)
else:
if not rowid:
await create_hash_check(the_hash, lnurl_id)
return {"lnurl": True, "hash": False}
return HashCheck(lnurl=True, hash=False)
else:
return {"lnurl": True, "hash": True}
return HashCheck(lnurl=True, hash=True)

View File

@ -1,28 +1,27 @@
import json
import traceback
from datetime import datetime
from http import HTTPStatus
import httpx
import shortuuid # type: ignore
from fastapi import HTTPException
from fastapi.param_functions import Query
import shortuuid
from fastapi import HTTPException, Query, Request, Response
from loguru import logger
from starlette.requests import Request
from starlette.responses import HTMLResponse
from lnbits.core.crud import update_payment_extra
from lnbits.core.services import pay_invoice
from . import withdraw_ext
from .crud import get_withdraw_link_by_hash, update_withdraw_link
# FOR LNURLs WHICH ARE NOT UNIQUE
from .crud import (
get_withdraw_link_by_hash,
increment_withdraw_link,
remove_unique_withdraw_link,
)
from .models import WithdrawLink
@withdraw_ext.get(
"/api/v1/lnurl/{unique_hash}",
response_class=HTMLResponse,
response_class=Response,
name="withdraw.api_lnurl_response",
)
async def api_lnurl_response(request: Request, unique_hash):
@ -53,9 +52,6 @@ async def api_lnurl_response(request: Request, unique_hash):
return json.dumps(withdrawResponse)
# CALLBACK
@withdraw_ext.get(
"/api/v1/lnurl/cb/{unique_hash}",
name="withdraw.api_lnurl_callback",
@ -99,105 +95,79 @@ async def api_lnurl_callback(
detail=f"wait link open_time {link.open_time - now} seconds.",
)
usescsv = ""
for x in range(1, link.uses - link.used):
usecv = link.usescsv.split(",")
usescsv += "," + str(usecv[x])
usecsvback = usescsv
found = False
if id_unique_hash is not None:
useslist = link.usescsv.split(",")
for ind, x in enumerate(useslist):
tohash = link.id + link.unique_hash + str(x)
if id_unique_hash == shortuuid.uuid(name=tohash):
found = True
useslist.pop(ind)
usescsv = ",".join(useslist)
if not found:
if id_unique_hash:
if check_unique_link(link, id_unique_hash):
await remove_unique_withdraw_link(link, id_unique_hash)
else:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, detail="withdraw not found."
)
else:
usescsv = usescsv[1:]
changesback = {
"open_time": link.wait_time,
"used": link.used,
"usescsv": usecsvback,
}
try:
changes = {
"open_time": link.wait_time + now,
"used": link.used + 1,
"usescsv": usescsv,
}
await update_withdraw_link(link.id, **changes)
payment_request = pr
payment_hash = await pay_invoice(
wallet_id=link.wallet,
payment_request=payment_request,
payment_request=pr,
max_sat=link.max_withdrawable,
extra={"tag": "withdraw"},
)
await increment_withdraw_link(link)
if link.webhook_url:
async with httpx.AsyncClient() as client:
try:
kwargs = {
"json": {
"payment_hash": payment_hash,
"payment_request": payment_request,
"lnurlw": link.id,
},
"timeout": 40,
}
if link.webhook_body:
kwargs["json"]["body"] = json.loads(link.webhook_body)
if link.webhook_headers:
kwargs["headers"] = json.loads(link.webhook_headers)
r: httpx.Response = await client.post(link.webhook_url, **kwargs)
await update_payment_extra(
payment_hash=payment_hash,
extra={
"wh_success": r.is_success,
"wh_message": r.reason_phrase,
"wh_response": r.text,
},
outgoing=True,
)
except Exception as exc:
# webhook fails shouldn't cause the lnurlw to fail since invoice is already paid
logger.error(
"Caught exception when dispatching webhook url: " + str(exc)
)
await update_payment_extra(
payment_hash=payment_hash,
extra={"wh_success": False, "wh_message": str(exc)},
outgoing=True,
)
await dispatch_webhook(link, payment_hash, pr)
return {"status": "OK"}
except Exception as e:
await update_withdraw_link(link.id, **changesback)
logger.error(traceback.format_exc())
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, detail=f"withdraw not working. {str(e)}"
)
def check_unique_link(link: WithdrawLink, unique_hash: str) -> bool:
return any(
unique_hash == shortuuid.uuid(name=link.id + link.unique_hash + x.strip())
for x in link.usescsv.split(",")
)
async def dispatch_webhook(
link: WithdrawLink, payment_hash: str, payment_request: str
) -> None:
async with httpx.AsyncClient() as client:
try:
r: httpx.Response = await client.post(
link.webhook_url,
json={
"payment_hash": payment_hash,
"payment_request": payment_request,
"lnurlw": link.id,
"body": json.loads(link.webhook_body) if link.webhook_body else "",
},
headers=json.loads(link.webhook_headers)
if link.webhook_headers
else None,
timeout=40,
)
await update_payment_extra(
payment_hash=payment_hash,
extra={
"wh_success": r.is_success,
"wh_message": r.reason_phrase,
"wh_response": r.text,
},
outgoing=True,
)
except Exception as exc:
# webhook fails shouldn't cause the lnurlw to fail since invoice is already paid
logger.error("Caught exception when dispatching webhook url: " + str(exc))
await update_payment_extra(
payment_hash=payment_hash,
extra={"wh_success": False, "wh_message": str(exc)},
outgoing=True,
)
# FOR LNURLs WHICH ARE UNIQUE
@withdraw_ext.get(
"/api/v1/lnurl/{unique_hash}/{id_unique_hash}",
response_class=HTMLResponse,
response_class=Response,
name="withdraw.api_lnurl_multi_response",
)
async def api_lnurl_multi_response(request: Request, unique_hash, id_unique_hash):
@ -213,14 +183,7 @@ async def api_lnurl_multi_response(request: Request, unique_hash, id_unique_hash
status_code=HTTPStatus.NOT_FOUND, detail="Withdraw is spent."
)
useslist = link.usescsv.split(",")
found = False
for x in useslist:
tohash = link.id + link.unique_hash + str(x)
if id_unique_hash == shortuuid.uuid(name=tohash):
found = True
if not found:
if not check_unique_link(link, id_unique_hash):
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, detail="LNURL-withdraw not found."
)

View File

@ -1,9 +1,8 @@
from sqlite3 import Row
import shortuuid # type: ignore
from fastapi.param_functions import Query
import shortuuid
from fastapi import Query
from lnurl import Lnurl, LnurlWithdrawResponse
from lnurl import encode as lnurl_encode # type: ignore
from lnurl import encode as lnurl_encode
from lnurl.models import ClearnetUrl, MilliSatoshi
from pydantic import BaseModel
from starlette.requests import Request
@ -67,18 +66,14 @@ class WithdrawLink(BaseModel):
name="withdraw.api_lnurl_callback", unique_hash=self.unique_hash
)
return LnurlWithdrawResponse(
callback=url,
callback=ClearnetUrl(url, scheme="https"),
k1=self.k1,
min_withdrawable=self.min_withdrawable * 1000,
max_withdrawable=self.max_withdrawable * 1000,
default_description=self.title,
minWithdrawable=MilliSatoshi(self.min_withdrawable * 1000),
maxWithdrawable=MilliSatoshi(self.max_withdrawable * 1000),
defaultDescription=self.title,
)
class HashCheck(BaseModel):
id: str
lnurl_id: str
@classmethod
def from_row(cls, row: Row) -> "Hash":
return cls(**dict(row))
hash: bool
lnurl: bool

View File

@ -2,10 +2,8 @@ from http import HTTPStatus
from io import BytesIO
import pyqrcode
from fastapi import Request
from fastapi.params import Depends
from fastapi import Depends, HTTPException, Request
from fastapi.templating import Jinja2Templates
from starlette.exceptions import HTTPException
from starlette.responses import HTMLResponse, StreamingResponse
from lnbits.core.models import User

View File

@ -1,10 +1,7 @@
from http import HTTPStatus
from fastapi.param_functions import Query
from fastapi.params import Depends
from lnurl.exceptions import InvalidUrl as LnurlInvalidUrl # type: ignore
from starlette.exceptions import HTTPException
from starlette.requests import Request
from fastapi import Depends, HTTPException, Query, Request
from lnurl.exceptions import InvalidUrl as LnurlInvalidUrl
from lnbits.core.crud import get_user
from lnbits.decorators import WalletTypeInfo, get_key_type, require_admin_key
@ -30,7 +27,8 @@ async def api_links(
wallet_ids = [wallet.wallet.id]
if all_wallets:
wallet_ids = (await get_user(wallet.wallet.user)).wallet_ids
user = await get_user(wallet.wallet.user)
wallet_ids = user.wallet_ids if user else []
try:
return [
@ -47,7 +45,7 @@ async def api_links(
@withdraw_ext.get("/api/v1/links/{link_id}", status_code=HTTPStatus.OK)
async def api_link_retrieve(
link_id, request: Request, wallet: WalletTypeInfo = Depends(get_key_type)
link_id: str, request: Request, wallet: WalletTypeInfo = Depends(get_key_type)
):
link = await get_withdraw_link(link_id, 0)
@ -68,7 +66,7 @@ async def api_link_retrieve(
async def api_link_create_or_update(
req: Request,
data: CreateWithdrawData,
link_id: str = None,
link_id: str = Query(None),
wallet: WalletTypeInfo = Depends(require_admin_key),
):
if data.uses > 250:
@ -85,14 +83,6 @@ async def api_link_create_or_update(
status_code=HTTPStatus.BAD_REQUEST,
)
usescsv = ""
for i in range(data.uses):
if data.is_unique:
usescsv += "," + str(i + 1)
else:
usescsv += "," + str(1)
usescsv = usescsv[1:]
if link_id:
link = await get_withdraw_link(link_id, 0)
if not link:
@ -103,13 +93,10 @@ async def api_link_create_or_update(
raise HTTPException(
detail="Not your withdraw link.", status_code=HTTPStatus.FORBIDDEN
)
link = await update_withdraw_link(
link_id, **data.dict(), usescsv=usescsv, used=0
)
link = await update_withdraw_link(link_id, **data.dict())
else:
link = await create_withdraw_link(
wallet_id=wallet.wallet.id, data=data, usescsv=usescsv
)
link = await create_withdraw_link(wallet_id=wallet.wallet.id, data=data)
assert link
return {**link.dict(), **{"lnurl": link.lnurl(req)}}
@ -131,9 +118,11 @@ async def api_link_delete(link_id, wallet: WalletTypeInfo = Depends(require_admi
return {"success": True}
@withdraw_ext.get("/api/v1/links/{the_hash}/{lnurl_id}", status_code=HTTPStatus.OK)
async def api_hash_retrieve(
the_hash, lnurl_id, wallet: WalletTypeInfo = Depends(get_key_type)
):
@withdraw_ext.get(
"/api/v1/links/{the_hash}/{lnurl_id}",
status_code=HTTPStatus.OK,
dependencies=[Depends(get_key_type)],
)
async def api_hash_retrieve(the_hash, lnurl_id):
hashCheck = await get_hash_check(the_hash, lnurl_id)
return hashCheck

View File

@ -101,7 +101,6 @@ exclude = """(?x)(
| ^lnbits/extensions/satspay.
| ^lnbits/extensions/streamalerts.
| ^lnbits/extensions/watchonly.
| ^lnbits/extensions/withdraw.
| ^lnbits/wallets/lnd_grpc_files.
)"""