fix invoices mypy issues
This commit is contained in:
parent
7e4a3a6831
commit
30eccab53c
|
@ -6,7 +6,6 @@ from . import db
|
|||
from .models import (
|
||||
CreateInvoiceData,
|
||||
CreateInvoiceItemData,
|
||||
CreatePaymentData,
|
||||
Invoice,
|
||||
InvoiceItem,
|
||||
Payment,
|
||||
|
@ -30,7 +29,7 @@ async def get_invoice_items(invoice_id: str) -> List[InvoiceItem]:
|
|||
return [InvoiceItem.from_row(row) for row in rows]
|
||||
|
||||
|
||||
async def get_invoice_item(item_id: str) -> InvoiceItem:
|
||||
async def get_invoice_item(item_id: str) -> Optional[InvoiceItem]:
|
||||
row = await db.fetchone(
|
||||
"SELECT * FROM invoices.invoice_items WHERE id = ?", (item_id,)
|
||||
)
|
||||
|
@ -61,7 +60,7 @@ async def get_invoice_payments(invoice_id: str) -> List[Payment]:
|
|||
return [Payment.from_row(row) for row in rows]
|
||||
|
||||
|
||||
async def get_invoice_payment(payment_id: str) -> Payment:
|
||||
async def get_invoice_payment(payment_id: str) -> Optional[Payment]:
|
||||
row = await db.fetchone(
|
||||
"SELECT * FROM invoices.payments WHERE id = ?", (payment_id,)
|
||||
)
|
||||
|
@ -120,7 +119,9 @@ async def create_invoice_items(
|
|||
return invoice_items
|
||||
|
||||
|
||||
async def update_invoice_internal(wallet_id: str, data: UpdateInvoiceData) -> Invoice:
|
||||
async def update_invoice_internal(
|
||||
wallet_id: str, data: Union[UpdateInvoiceData, Invoice]
|
||||
) -> Invoice:
|
||||
await db.execute(
|
||||
"""
|
||||
UPDATE invoices.invoices
|
||||
|
@ -155,21 +156,21 @@ async def update_invoice_items(
|
|||
updated_items.append(item.id)
|
||||
await db.execute(
|
||||
"""
|
||||
UPDATE invoices.invoice_items
|
||||
UPDATE invoices.invoice_items
|
||||
SET description = ?, amount = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(item.description, int(item.amount * 100), item.id),
|
||||
)
|
||||
|
||||
placeholders = ",".join("?" for i in range(len(updated_items)))
|
||||
placeholders = ",".join("?" for _ in range(len(updated_items)))
|
||||
if not placeholders:
|
||||
placeholders = "?"
|
||||
updated_items = ("skip",)
|
||||
updated_items = ["skip"]
|
||||
|
||||
await db.execute(
|
||||
f"""
|
||||
DELETE FROM invoices.invoice_items
|
||||
DELETE FROM invoices.invoice_items
|
||||
WHERE invoice_id = ?
|
||||
AND id NOT IN ({placeholders})
|
||||
""",
|
||||
|
@ -180,8 +181,11 @@ async def update_invoice_items(
|
|||
)
|
||||
|
||||
for item in data:
|
||||
if not item.id:
|
||||
await create_invoice_items(invoice_id=invoice_id, data=[item])
|
||||
if not item:
|
||||
await create_invoice_items(
|
||||
invoice_id=invoice_id,
|
||||
data=[CreateInvoiceItemData(description=item.description)],
|
||||
)
|
||||
|
||||
invoice_items = await get_invoice_items(invoice_id)
|
||||
return invoice_items
|
||||
|
|
|
@ -2,7 +2,7 @@ from enum import Enum
|
|||
from sqlite3 import Row
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi.param_functions import Query
|
||||
from fastapi import Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
|
||||
from lnbits.core.models import Payment
|
||||
from lnbits.helpers import urlsafe_short_hash
|
||||
from lnbits.tasks import internal_invoice_queue, register_invoice_listener
|
||||
from lnbits.tasks import register_invoice_listener
|
||||
|
||||
from .crud import (
|
||||
create_invoice_payment,
|
||||
|
@ -14,6 +12,7 @@ from .crud import (
|
|||
get_payments_total,
|
||||
update_invoice_internal,
|
||||
)
|
||||
from .models import InvoiceStatusEnum
|
||||
|
||||
|
||||
async def wait_for_paid_invoices():
|
||||
|
@ -26,17 +25,22 @@ async def wait_for_paid_invoices():
|
|||
|
||||
|
||||
async def on_invoice_paid(payment: Payment) -> None:
|
||||
if not payment.extra:
|
||||
return
|
||||
|
||||
if payment.extra.get("tag") != "invoices":
|
||||
# not relevant
|
||||
return
|
||||
|
||||
invoice_id = payment.extra.get("invoice_id")
|
||||
assert invoice_id
|
||||
|
||||
payment = await create_invoice_payment(
|
||||
invoice_id=invoice_id, amount=payment.extra.get("famount")
|
||||
)
|
||||
amount = payment.extra.get("famount")
|
||||
assert amount
|
||||
|
||||
await create_invoice_payment(invoice_id=invoice_id, amount=amount)
|
||||
|
||||
invoice = await get_invoice(invoice_id)
|
||||
assert invoice
|
||||
|
||||
invoice_items = await get_invoice_items(invoice_id)
|
||||
invoice_total = await get_invoice_total(invoice_items)
|
||||
|
@ -45,7 +49,7 @@ async def on_invoice_paid(payment: Payment) -> None:
|
|||
payments_total = await get_payments_total(invoice_payments)
|
||||
|
||||
if payments_total >= invoice_total:
|
||||
invoice.status = "paid"
|
||||
invoice.status = InvoiceStatusEnum.paid
|
||||
await update_invoice_internal(invoice.wallet, invoice)
|
||||
|
||||
return
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.params import Depends
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.responses import HTMLResponse
|
||||
|
||||
from lnbits.core.models import User
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
from http import HTTPStatus
|
||||
|
||||
from fastapi import Query
|
||||
from fastapi.params import Depends
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from loguru import logger
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from lnbits.core.crud import get_user
|
||||
from lnbits.core.services import create_invoice
|
||||
from lnbits.core.views.api import api_payment
|
||||
from lnbits.decorators import WalletTypeInfo, get_key_type, require_admin_key
|
||||
from lnbits.decorators import WalletTypeInfo, get_key_type
|
||||
from lnbits.utils.exchange_rates import fiat_amount_as_satoshis
|
||||
|
||||
from . import invoices_ext
|
||||
|
@ -33,7 +31,8 @@ async def api_invoices(
|
|||
):
|
||||
wallet_ids = [wallet.wallet.id]
|
||||
if all_wallets:
|
||||
wallet_ids = (await get_user(wallet.wallet.user)).wallet_ids
|
||||
user = await get_user(wallet.wallet.user)
|
||||
wallet_ids = user.wallet_ids if user else []
|
||||
|
||||
return [invoice.dict() for invoice in await get_invoices(wallet_ids)]
|
||||
|
||||
|
@ -83,9 +82,7 @@ async def api_invoice_update(
|
|||
@invoices_ext.post(
|
||||
"/api/v1/invoice/{invoice_id}/payments", status_code=HTTPStatus.CREATED
|
||||
)
|
||||
async def api_invoices_create_payment(
|
||||
famount: int = Query(..., ge=1), invoice_id: str = None
|
||||
):
|
||||
async def api_invoices_create_payment(invoice_id: str, famount: int = Query(..., ge=1)):
|
||||
invoice = await get_invoice(invoice_id)
|
||||
invoice_items = await get_invoice_items(invoice_id)
|
||||
invoice_total = await get_invoice_total(invoice_items)
|
||||
|
|
|
@ -92,7 +92,6 @@ exclude = """(?x)(
|
|||
^lnbits/extensions/bleskomat.
|
||||
| ^lnbits/extensions/boltz.
|
||||
| ^lnbits/extensions/boltcards.
|
||||
| ^lnbits/extensions/invoices.
|
||||
| ^lnbits/extensions/livestream.
|
||||
| ^lnbits/extensions/lnaddress.
|
||||
| ^lnbits/extensions/lnurldevice.
|
||||
|
|
Loading…
Reference in New Issue
Block a user