fix invoices mypy issues

This commit is contained in:
dni ⚡ 2023-01-04 22:49:10 +01:00
parent 7e4a3a6831
commit 30eccab53c
6 changed files with 33 additions and 31 deletions

View File

@ -6,7 +6,6 @@ from . import db
from .models import ( from .models import (
CreateInvoiceData, CreateInvoiceData,
CreateInvoiceItemData, CreateInvoiceItemData,
CreatePaymentData,
Invoice, Invoice,
InvoiceItem, InvoiceItem,
Payment, Payment,
@ -30,7 +29,7 @@ async def get_invoice_items(invoice_id: str) -> List[InvoiceItem]:
return [InvoiceItem.from_row(row) for row in rows] return [InvoiceItem.from_row(row) for row in rows]
async def get_invoice_item(item_id: str) -> InvoiceItem: async def get_invoice_item(item_id: str) -> Optional[InvoiceItem]:
row = await db.fetchone( row = await db.fetchone(
"SELECT * FROM invoices.invoice_items WHERE id = ?", (item_id,) "SELECT * FROM invoices.invoice_items WHERE id = ?", (item_id,)
) )
@ -61,7 +60,7 @@ async def get_invoice_payments(invoice_id: str) -> List[Payment]:
return [Payment.from_row(row) for row in rows] return [Payment.from_row(row) for row in rows]
async def get_invoice_payment(payment_id: str) -> Payment: async def get_invoice_payment(payment_id: str) -> Optional[Payment]:
row = await db.fetchone( row = await db.fetchone(
"SELECT * FROM invoices.payments WHERE id = ?", (payment_id,) "SELECT * FROM invoices.payments WHERE id = ?", (payment_id,)
) )
@ -120,7 +119,9 @@ async def create_invoice_items(
return invoice_items return invoice_items
async def update_invoice_internal(wallet_id: str, data: UpdateInvoiceData) -> Invoice: async def update_invoice_internal(
wallet_id: str, data: Union[UpdateInvoiceData, Invoice]
) -> Invoice:
await db.execute( await db.execute(
""" """
UPDATE invoices.invoices UPDATE invoices.invoices
@ -155,21 +156,21 @@ async def update_invoice_items(
updated_items.append(item.id) updated_items.append(item.id)
await db.execute( await db.execute(
""" """
UPDATE invoices.invoice_items UPDATE invoices.invoice_items
SET description = ?, amount = ? SET description = ?, amount = ?
WHERE id = ? WHERE id = ?
""", """,
(item.description, int(item.amount * 100), item.id), (item.description, int(item.amount * 100), item.id),
) )
placeholders = ",".join("?" for i in range(len(updated_items))) placeholders = ",".join("?" for _ in range(len(updated_items)))
if not placeholders: if not placeholders:
placeholders = "?" placeholders = "?"
updated_items = ("skip",) updated_items = ["skip"]
await db.execute( await db.execute(
f""" f"""
DELETE FROM invoices.invoice_items DELETE FROM invoices.invoice_items
WHERE invoice_id = ? WHERE invoice_id = ?
AND id NOT IN ({placeholders}) AND id NOT IN ({placeholders})
""", """,
@ -180,8 +181,11 @@ async def update_invoice_items(
) )
for item in data: for item in data:
if not item.id: if not item:
await create_invoice_items(invoice_id=invoice_id, data=[item]) await create_invoice_items(
invoice_id=invoice_id,
data=[CreateInvoiceItemData(description=item.description)],
)
invoice_items = await get_invoice_items(invoice_id) invoice_items = await get_invoice_items(invoice_id)
return invoice_items return invoice_items

View File

@ -2,7 +2,7 @@ from enum import Enum
from sqlite3 import Row from sqlite3 import Row
from typing import List, Optional from typing import List, Optional
from fastapi.param_functions import Query from fastapi import Query
from pydantic import BaseModel from pydantic import BaseModel

View File

@ -1,9 +1,7 @@
import asyncio import asyncio
import json
from lnbits.core.models import Payment from lnbits.core.models import Payment
from lnbits.helpers import urlsafe_short_hash from lnbits.tasks import register_invoice_listener
from lnbits.tasks import internal_invoice_queue, register_invoice_listener
from .crud import ( from .crud import (
create_invoice_payment, create_invoice_payment,
@ -14,6 +12,7 @@ from .crud import (
get_payments_total, get_payments_total,
update_invoice_internal, update_invoice_internal,
) )
from .models import InvoiceStatusEnum
async def wait_for_paid_invoices(): async def wait_for_paid_invoices():
@ -26,17 +25,22 @@ async def wait_for_paid_invoices():
async def on_invoice_paid(payment: Payment) -> None: async def on_invoice_paid(payment: Payment) -> None:
if not payment.extra:
return
if payment.extra.get("tag") != "invoices": if payment.extra.get("tag") != "invoices":
# not relevant
return return
invoice_id = payment.extra.get("invoice_id") invoice_id = payment.extra.get("invoice_id")
assert invoice_id
payment = await create_invoice_payment( amount = payment.extra.get("famount")
invoice_id=invoice_id, amount=payment.extra.get("famount") assert amount
)
await create_invoice_payment(invoice_id=invoice_id, amount=amount)
invoice = await get_invoice(invoice_id) invoice = await get_invoice(invoice_id)
assert invoice
invoice_items = await get_invoice_items(invoice_id) invoice_items = await get_invoice_items(invoice_id)
invoice_total = await get_invoice_total(invoice_items) invoice_total = await get_invoice_total(invoice_items)
@ -45,7 +49,7 @@ async def on_invoice_paid(payment: Payment) -> None:
payments_total = await get_payments_total(invoice_payments) payments_total = await get_payments_total(invoice_payments)
if payments_total >= invoice_total: if payments_total >= invoice_total:
invoice.status = "paid" invoice.status = InvoiceStatusEnum.paid
await update_invoice_internal(invoice.wallet, invoice) await update_invoice_internal(invoice.wallet, invoice)
return return

View File

@ -1,10 +1,8 @@
from datetime import datetime from datetime import datetime
from http import HTTPStatus from http import HTTPStatus
from fastapi import FastAPI, Request from fastapi import Depends, HTTPException, Request
from fastapi.params import Depends
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from starlette.exceptions import HTTPException
from starlette.responses import HTMLResponse from starlette.responses import HTMLResponse
from lnbits.core.models import User from lnbits.core.models import User

View File

@ -1,14 +1,12 @@
from http import HTTPStatus from http import HTTPStatus
from fastapi import Query from fastapi import Depends, HTTPException, Query
from fastapi.params import Depends
from loguru import logger from loguru import logger
from starlette.exceptions import HTTPException
from lnbits.core.crud import get_user from lnbits.core.crud import get_user
from lnbits.core.services import create_invoice from lnbits.core.services import create_invoice
from lnbits.core.views.api import api_payment from lnbits.core.views.api import api_payment
from lnbits.decorators import WalletTypeInfo, get_key_type, require_admin_key from lnbits.decorators import WalletTypeInfo, get_key_type
from lnbits.utils.exchange_rates import fiat_amount_as_satoshis from lnbits.utils.exchange_rates import fiat_amount_as_satoshis
from . import invoices_ext from . import invoices_ext
@ -33,7 +31,8 @@ async def api_invoices(
): ):
wallet_ids = [wallet.wallet.id] wallet_ids = [wallet.wallet.id]
if all_wallets: if all_wallets:
wallet_ids = (await get_user(wallet.wallet.user)).wallet_ids user = await get_user(wallet.wallet.user)
wallet_ids = user.wallet_ids if user else []
return [invoice.dict() for invoice in await get_invoices(wallet_ids)] return [invoice.dict() for invoice in await get_invoices(wallet_ids)]
@ -83,9 +82,7 @@ async def api_invoice_update(
@invoices_ext.post( @invoices_ext.post(
"/api/v1/invoice/{invoice_id}/payments", status_code=HTTPStatus.CREATED "/api/v1/invoice/{invoice_id}/payments", status_code=HTTPStatus.CREATED
) )
async def api_invoices_create_payment( async def api_invoices_create_payment(invoice_id: str, famount: int = Query(..., ge=1)):
famount: int = Query(..., ge=1), invoice_id: str = None
):
invoice = await get_invoice(invoice_id) invoice = await get_invoice(invoice_id)
invoice_items = await get_invoice_items(invoice_id) invoice_items = await get_invoice_items(invoice_id)
invoice_total = await get_invoice_total(invoice_items) invoice_total = await get_invoice_total(invoice_items)

View File

@ -92,7 +92,6 @@ exclude = """(?x)(
^lnbits/extensions/bleskomat. ^lnbits/extensions/bleskomat.
| ^lnbits/extensions/boltz. | ^lnbits/extensions/boltz.
| ^lnbits/extensions/boltcards. | ^lnbits/extensions/boltcards.
| ^lnbits/extensions/invoices.
| ^lnbits/extensions/livestream. | ^lnbits/extensions/livestream.
| ^lnbits/extensions/lnaddress. | ^lnbits/extensions/lnaddress.
| ^lnbits/extensions/lnurldevice. | ^lnbits/extensions/lnurldevice.