diff --git a/lnbits/extensions/invoices/crud.py b/lnbits/extensions/invoices/crud.py index 4fd055e9..9a05f9c5 100644 --- a/lnbits/extensions/invoices/crud.py +++ b/lnbits/extensions/invoices/crud.py @@ -6,7 +6,6 @@ from . import db from .models import ( CreateInvoiceData, CreateInvoiceItemData, - CreatePaymentData, Invoice, InvoiceItem, Payment, @@ -30,7 +29,7 @@ async def get_invoice_items(invoice_id: str) -> List[InvoiceItem]: return [InvoiceItem.from_row(row) for row in rows] -async def get_invoice_item(item_id: str) -> InvoiceItem: +async def get_invoice_item(item_id: str) -> Optional[InvoiceItem]: row = await db.fetchone( "SELECT * FROM invoices.invoice_items WHERE id = ?", (item_id,) ) @@ -61,7 +60,7 @@ async def get_invoice_payments(invoice_id: str) -> List[Payment]: return [Payment.from_row(row) for row in rows] -async def get_invoice_payment(payment_id: str) -> Payment: +async def get_invoice_payment(payment_id: str) -> Optional[Payment]: row = await db.fetchone( "SELECT * FROM invoices.payments WHERE id = ?", (payment_id,) ) @@ -120,7 +119,9 @@ async def create_invoice_items( return invoice_items -async def update_invoice_internal(wallet_id: str, data: UpdateInvoiceData) -> Invoice: +async def update_invoice_internal( + wallet_id: str, data: Union[UpdateInvoiceData, Invoice] +) -> Invoice: await db.execute( """ UPDATE invoices.invoices @@ -155,21 +156,21 @@ async def update_invoice_items( updated_items.append(item.id) await db.execute( """ - UPDATE invoices.invoice_items + UPDATE invoices.invoice_items SET description = ?, amount = ? WHERE id = ? """, (item.description, int(item.amount * 100), item.id), ) - placeholders = ",".join("?" for i in range(len(updated_items))) + placeholders = ",".join("?" for _ in range(len(updated_items))) if not placeholders: placeholders = "?" - updated_items = ("skip",) + updated_items = ["skip"] await db.execute( f""" - DELETE FROM invoices.invoice_items + DELETE FROM invoices.invoice_items WHERE invoice_id = ? AND id NOT IN ({placeholders}) """, @@ -180,8 +181,11 @@ async def update_invoice_items( ) for item in data: - if not item.id: - await create_invoice_items(invoice_id=invoice_id, data=[item]) + if not item: + await create_invoice_items( + invoice_id=invoice_id, + data=[CreateInvoiceItemData(description=item.description)], + ) invoice_items = await get_invoice_items(invoice_id) return invoice_items diff --git a/lnbits/extensions/invoices/models.py b/lnbits/extensions/invoices/models.py index adf03e46..6f0e63cb 100644 --- a/lnbits/extensions/invoices/models.py +++ b/lnbits/extensions/invoices/models.py @@ -2,7 +2,7 @@ from enum import Enum from sqlite3 import Row from typing import List, Optional -from fastapi.param_functions import Query +from fastapi import Query from pydantic import BaseModel diff --git a/lnbits/extensions/invoices/tasks.py b/lnbits/extensions/invoices/tasks.py index 61bcb7b4..ae76b9e3 100644 --- a/lnbits/extensions/invoices/tasks.py +++ b/lnbits/extensions/invoices/tasks.py @@ -1,9 +1,7 @@ import asyncio -import json from lnbits.core.models import Payment -from lnbits.helpers import urlsafe_short_hash -from lnbits.tasks import internal_invoice_queue, register_invoice_listener +from lnbits.tasks import register_invoice_listener from .crud import ( create_invoice_payment, @@ -14,6 +12,7 @@ from .crud import ( get_payments_total, update_invoice_internal, ) +from .models import InvoiceStatusEnum async def wait_for_paid_invoices(): @@ -26,17 +25,22 @@ async def wait_for_paid_invoices(): async def on_invoice_paid(payment: Payment) -> None: + if not payment.extra: + return + if payment.extra.get("tag") != "invoices": - # not relevant return invoice_id = payment.extra.get("invoice_id") + assert invoice_id - payment = await create_invoice_payment( - invoice_id=invoice_id, amount=payment.extra.get("famount") - ) + amount = payment.extra.get("famount") + assert amount + + await create_invoice_payment(invoice_id=invoice_id, amount=amount) invoice = await get_invoice(invoice_id) + assert invoice invoice_items = await get_invoice_items(invoice_id) invoice_total = await get_invoice_total(invoice_items) @@ -45,7 +49,7 @@ async def on_invoice_paid(payment: Payment) -> None: payments_total = await get_payments_total(invoice_payments) if payments_total >= invoice_total: - invoice.status = "paid" + invoice.status = InvoiceStatusEnum.paid await update_invoice_internal(invoice.wallet, invoice) return diff --git a/lnbits/extensions/invoices/views.py b/lnbits/extensions/invoices/views.py index b492a67c..cc35b351 100644 --- a/lnbits/extensions/invoices/views.py +++ b/lnbits/extensions/invoices/views.py @@ -1,10 +1,8 @@ from datetime import datetime from http import HTTPStatus -from fastapi import FastAPI, Request -from fastapi.params import Depends +from fastapi import Depends, HTTPException, Request from fastapi.templating import Jinja2Templates -from starlette.exceptions import HTTPException from starlette.responses import HTMLResponse from lnbits.core.models import User diff --git a/lnbits/extensions/invoices/views_api.py b/lnbits/extensions/invoices/views_api.py index 23a262e3..1a7762a8 100644 --- a/lnbits/extensions/invoices/views_api.py +++ b/lnbits/extensions/invoices/views_api.py @@ -1,14 +1,12 @@ from http import HTTPStatus -from fastapi import Query -from fastapi.params import Depends +from fastapi import Depends, HTTPException, Query from loguru import logger -from starlette.exceptions import HTTPException from lnbits.core.crud import get_user from lnbits.core.services import create_invoice from lnbits.core.views.api import api_payment -from lnbits.decorators import WalletTypeInfo, get_key_type, require_admin_key +from lnbits.decorators import WalletTypeInfo, get_key_type from lnbits.utils.exchange_rates import fiat_amount_as_satoshis from . import invoices_ext @@ -33,7 +31,8 @@ async def api_invoices( ): wallet_ids = [wallet.wallet.id] if all_wallets: - wallet_ids = (await get_user(wallet.wallet.user)).wallet_ids + user = await get_user(wallet.wallet.user) + wallet_ids = user.wallet_ids if user else [] return [invoice.dict() for invoice in await get_invoices(wallet_ids)] @@ -83,9 +82,7 @@ async def api_invoice_update( @invoices_ext.post( "/api/v1/invoice/{invoice_id}/payments", status_code=HTTPStatus.CREATED ) -async def api_invoices_create_payment( - famount: int = Query(..., ge=1), invoice_id: str = None -): +async def api_invoices_create_payment(invoice_id: str, famount: int = Query(..., ge=1)): invoice = await get_invoice(invoice_id) invoice_items = await get_invoice_items(invoice_id) invoice_total = await get_invoice_total(invoice_items) diff --git a/pyproject.toml b/pyproject.toml index 606420ac..5cfbf4dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,6 @@ exclude = """(?x)( ^lnbits/extensions/bleskomat. | ^lnbits/extensions/boltz. | ^lnbits/extensions/boltcards. - | ^lnbits/extensions/invoices. | ^lnbits/extensions/livestream. | ^lnbits/extensions/lnaddress. | ^lnbits/extensions/lnurldevice.