Serverside Pagination for payments (#1613)
* initial backend support
* implement payments pagination on frontend
* implement search for payments api
* fix pyright issues
* sqlite support for searching
* backwards compatability
* formatting, small fixes
* small optimization
* fix sorting issue, add error handling
* GET payments test
* filter by dates, use List instead of list
* fix sqlite
* update bundle
* test old payments endpoint aswell
* refactor for easier review
* optimise test
* revert unnecessary change
---------
Co-authored-by: dni ⚡ <office@dnilabs.com>
This commit is contained in:
parent
45b199a8ef
commit
c0f66989cb
|
@ -7,12 +7,12 @@ from uuid import uuid4
|
|||
import shortuuid
|
||||
|
||||
from lnbits import bolt11
|
||||
from lnbits.db import COCKROACH, POSTGRES, Connection, Filters
|
||||
from lnbits.db import Connection, Filters, Page
|
||||
from lnbits.extension_manager import InstallableExtension
|
||||
from lnbits.settings import AdminSettings, EditableSettings, SuperSettings, settings
|
||||
|
||||
from . import db
|
||||
from .models import BalanceCheck, Payment, TinyURL, User, Wallet
|
||||
from .models import BalanceCheck, Payment, PaymentFilters, TinyURL, User, Wallet
|
||||
|
||||
# accounts
|
||||
# --------
|
||||
|
@ -343,7 +343,7 @@ async def get_latest_payments_by_extension(ext_name: str, ext_id: str, limit: in
|
|||
return rows
|
||||
|
||||
|
||||
async def get_payments(
|
||||
async def get_payments_paginated(
|
||||
*,
|
||||
wallet_id: Optional[str] = None,
|
||||
complete: bool = False,
|
||||
|
@ -352,28 +352,23 @@ async def get_payments(
|
|||
incoming: bool = False,
|
||||
since: Optional[int] = None,
|
||||
exclude_uncheckable: bool = False,
|
||||
filters: Optional[Filters[Payment]] = None,
|
||||
filters: Optional[Filters[PaymentFilters]] = None,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> List[Payment]:
|
||||
) -> Page[Payment]:
|
||||
"""
|
||||
Filters payments to be returned by complete | pending | outgoing | incoming.
|
||||
"""
|
||||
|
||||
args: List[Any] = []
|
||||
values: List[Any] = []
|
||||
clause: List[str] = []
|
||||
|
||||
if since is not None:
|
||||
if db.type == POSTGRES:
|
||||
clause.append("time > to_timestamp(?)")
|
||||
elif db.type == COCKROACH:
|
||||
clause.append("time > cast(? AS timestamp)")
|
||||
else:
|
||||
clause.append("time > ?")
|
||||
args.append(since)
|
||||
clause.append(f"time > {db.timestamp_placeholder}")
|
||||
values.append(since)
|
||||
|
||||
if wallet_id:
|
||||
clause.append("wallet = ?")
|
||||
args.append(wallet_id)
|
||||
values.append(wallet_id)
|
||||
|
||||
if complete and pending:
|
||||
pass
|
||||
|
@ -397,21 +392,54 @@ async def get_payments(
|
|||
clause.append("checking_id NOT LIKE 'temp_%'")
|
||||
clause.append("checking_id NOT LIKE 'internal_%'")
|
||||
|
||||
if not filters:
|
||||
filters = Filters(limit=None, offset=None)
|
||||
|
||||
rows = await (conn or db).fetchall(
|
||||
f"""
|
||||
SELECT *
|
||||
FROM apipayments
|
||||
{filters.where(clause)}
|
||||
ORDER BY time DESC
|
||||
{filters.pagination()}
|
||||
""",
|
||||
filters.values(args),
|
||||
return await (conn or db).fetch_page(
|
||||
"SELECT * FROM apipayments",
|
||||
clause,
|
||||
values,
|
||||
filters=filters,
|
||||
model=Payment,
|
||||
)
|
||||
|
||||
return [Payment.from_row(row) for row in rows]
|
||||
|
||||
async def get_payments(
|
||||
*,
|
||||
wallet_id: Optional[str] = None,
|
||||
complete: bool = False,
|
||||
pending: bool = False,
|
||||
outgoing: bool = False,
|
||||
incoming: bool = False,
|
||||
since: Optional[int] = None,
|
||||
exclude_uncheckable: bool = False,
|
||||
filters: Optional[Filters[PaymentFilters]] = None,
|
||||
conn: Optional[Connection] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
) -> list[Payment]:
|
||||
"""
|
||||
Filters payments to be returned by complete | pending | outgoing | incoming.
|
||||
"""
|
||||
|
||||
if not filters:
|
||||
filters = Filters()
|
||||
|
||||
if limit:
|
||||
filters.limit = limit
|
||||
if offset:
|
||||
filters.offset = offset
|
||||
|
||||
page = await get_payments_paginated(
|
||||
wallet_id=wallet_id,
|
||||
complete=complete,
|
||||
pending=pending,
|
||||
outgoing=outgoing,
|
||||
incoming=incoming,
|
||||
since=since,
|
||||
exclude_uncheckable=exclude_uncheckable,
|
||||
filters=filters,
|
||||
conn=conn,
|
||||
)
|
||||
|
||||
return page.data
|
||||
|
||||
|
||||
async def delete_expired_invoices(
|
||||
|
@ -454,7 +482,6 @@ async def create_payment(
|
|||
webhook: Optional[str] = None,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Payment:
|
||||
|
||||
# todo: add this when tests are fixed
|
||||
# previous_payment = await get_wallet_payment(wallet_id, payment_hash, conn=conn)
|
||||
# assert previous_payment is None, "Payment already exists"
|
||||
|
@ -514,7 +541,6 @@ async def update_payment_details(
|
|||
new_checking_id: Optional[str] = None,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
|
||||
set_clause: List[str] = []
|
||||
set_variables: List[Any] = []
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ from lnurl import encode as lnurl_encode
|
|||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from lnbits.db import Connection
|
||||
from lnbits.db import Connection, FilterModel, FromRowModel
|
||||
from lnbits.helpers import url_for
|
||||
from lnbits.settings import get_wallet_class, settings
|
||||
from lnbits.wallets.base import PaymentStatus
|
||||
|
@ -86,7 +86,7 @@ class User(BaseModel):
|
|||
return False
|
||||
|
||||
|
||||
class Payment(BaseModel):
|
||||
class Payment(FromRowModel):
|
||||
checking_id: str
|
||||
pending: bool
|
||||
amount: int
|
||||
|
@ -214,6 +214,24 @@ class Payment(BaseModel):
|
|||
await delete_payment(self.checking_id, conn=conn)
|
||||
|
||||
|
||||
class PaymentFilters(FilterModel):
|
||||
__search_fields__ = ["memo", "amount"]
|
||||
|
||||
checking_id: str
|
||||
amount: int
|
||||
fee: int
|
||||
memo: Optional[str]
|
||||
time: datetime.datetime
|
||||
bolt11: str
|
||||
preimage: str
|
||||
payment_hash: str
|
||||
expiry: Optional[datetime.datetime]
|
||||
extra: Dict = {}
|
||||
wallet_id: str
|
||||
webhook: Optional[str]
|
||||
webhook_status: Optional[int]
|
||||
|
||||
|
||||
class BalanceCheck(BaseModel):
|
||||
wallet: str
|
||||
service: str
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// update cache version every time there is a new deployment
|
||||
// so the service worker reinitializes the cache
|
||||
const CACHE_VERSION = 5
|
||||
const CACHE_VERSION = 6
|
||||
const CURRENT_CACHE = `lnbits-${CACHE_VERSION}-`
|
||||
|
||||
const getApiKey = request => {
|
||||
|
|
|
@ -152,14 +152,14 @@ new Vue({
|
|||
field: 'memo'
|
||||
},
|
||||
{
|
||||
name: 'date',
|
||||
name: 'time',
|
||||
align: 'left',
|
||||
label: this.$t('date'),
|
||||
field: 'date',
|
||||
sortable: true
|
||||
},
|
||||
{
|
||||
name: 'sat',
|
||||
name: 'amount',
|
||||
align: 'right',
|
||||
label: this.$t('amount') + ' (' + LNBITS_DENOMINATION + ')',
|
||||
field: 'sat',
|
||||
|
@ -173,9 +173,14 @@ new Vue({
|
|||
}
|
||||
],
|
||||
pagination: {
|
||||
rowsPerPage: 10
|
||||
rowsPerPage: 10,
|
||||
page: 1,
|
||||
sortBy: 'time',
|
||||
descending: true,
|
||||
rowsNumber: 10
|
||||
},
|
||||
filter: null
|
||||
filter: null,
|
||||
loading: false
|
||||
},
|
||||
paymentsChart: {
|
||||
show: false
|
||||
|
@ -695,15 +700,34 @@ new Vue({
|
|||
LNbits.href.deleteWallet(walletId, user)
|
||||
})
|
||||
},
|
||||
fetchPayments: function () {
|
||||
return LNbits.api.getPayments(this.g.wallet).then(response => {
|
||||
this.payments = response.data
|
||||
.map(obj => {
|
||||
fetchPayments: function (props) {
|
||||
// Props are passed by qasar when pagination or sorting changes
|
||||
if (props) {
|
||||
this.paymentsTable.pagination = props.pagination
|
||||
}
|
||||
let pagination = this.paymentsTable.pagination
|
||||
this.paymentsTable.loading = true
|
||||
const query = {
|
||||
limit: pagination.rowsPerPage,
|
||||
offset: (pagination.page - 1) * pagination.rowsPerPage,
|
||||
sortby: pagination.sortBy ?? 'time',
|
||||
direction: pagination.descending ? 'desc' : 'asc'
|
||||
}
|
||||
if (this.paymentsTable.filter) {
|
||||
query.search = this.paymentsTable.filter
|
||||
}
|
||||
return LNbits.api
|
||||
.getPayments(this.g.wallet, query)
|
||||
.then(response => {
|
||||
this.paymentsTable.loading = false
|
||||
this.paymentsTable.pagination.rowsNumber = response.data.total
|
||||
this.payments = response.data.data.map(obj => {
|
||||
return LNbits.map.payment(obj)
|
||||
})
|
||||
.sort((a, b) => {
|
||||
return b.time - a.time
|
||||
})
|
||||
.catch(err => {
|
||||
this.paymentsTable.loading = false
|
||||
LNbits.utils.notifyApiError(err)
|
||||
})
|
||||
},
|
||||
fetchBalance: function () {
|
||||
|
|
|
@ -125,7 +125,6 @@
|
|||
</div>
|
||||
</div>
|
||||
<q-input
|
||||
v-if="payments.length > 10"
|
||||
filled
|
||||
dense
|
||||
clearable
|
||||
|
@ -138,12 +137,14 @@
|
|||
<q-table
|
||||
dense
|
||||
flat
|
||||
:data="filteredPayments"
|
||||
:data="payments"
|
||||
:row-key="paymentTableRowKey"
|
||||
:columns="paymentsTable.columns"
|
||||
:pagination.sync="paymentsTable.pagination"
|
||||
:no-data-label="$t('no_transactions')"
|
||||
:filter="paymentsTable.filter"
|
||||
:loading="paymentsTable.loading"
|
||||
@request="fetchPayments"
|
||||
>
|
||||
{% raw %}
|
||||
<template v-slot:header="props">
|
||||
|
@ -192,14 +193,14 @@
|
|||
</q-badge>
|
||||
{{ props.row.memo }}
|
||||
</q-td>
|
||||
<q-td auto-width key="date" :props="props">
|
||||
<q-td auto-width key="time" :props="props">
|
||||
<q-tooltip>{{ props.row.date }}</q-tooltip>
|
||||
{{ props.row.dateFrom }}
|
||||
</q-td>
|
||||
{% endraw %}
|
||||
<q-td
|
||||
auto-width
|
||||
key="sat"
|
||||
key="amount"
|
||||
v-if="'{{LNBITS_DENOMINATION}}' != 'sats'"
|
||||
:props="props"
|
||||
>{% raw %} {{
|
||||
|
@ -207,7 +208,7 @@
|
|||
}}
|
||||
</q-td>
|
||||
|
||||
<q-td auto-width key="sat" v-else :props="props">
|
||||
<q-td auto-width key="amount" v-else :props="props">
|
||||
{{ props.row.fsat }}
|
||||
</q-td>
|
||||
<q-td auto-width key="fee" :props="props">
|
||||
|
|
|
@ -33,8 +33,8 @@ from lnbits.core.helpers import (
|
|||
migrate_extension_database,
|
||||
stop_extension_background_work,
|
||||
)
|
||||
from lnbits.core.models import Payment, User, Wallet
|
||||
from lnbits.db import Filters
|
||||
from lnbits.core.models import Payment, PaymentFilters, User, Wallet
|
||||
from lnbits.db import Filters, Page
|
||||
from lnbits.decorators import (
|
||||
WalletTypeInfo,
|
||||
check_admin,
|
||||
|
@ -66,6 +66,7 @@ from ..crud import (
|
|||
delete_tinyurl,
|
||||
get_dbversions,
|
||||
get_payments,
|
||||
get_payments_paginated,
|
||||
get_standalone_payment,
|
||||
get_tinyurl,
|
||||
get_tinyurl_by_url,
|
||||
|
@ -122,19 +123,19 @@ async def api_update_wallet(
|
|||
summary="get list of payments",
|
||||
response_description="list of payments",
|
||||
response_model=List[Payment],
|
||||
openapi_extra=generate_filter_params_openapi(Payment),
|
||||
openapi_extra=generate_filter_params_openapi(PaymentFilters),
|
||||
)
|
||||
async def api_payments(
|
||||
wallet: WalletTypeInfo = Depends(get_key_type),
|
||||
filters: Filters = Depends(parse_filters(Payment)),
|
||||
filters: Filters = Depends(parse_filters(PaymentFilters)),
|
||||
):
|
||||
pendingPayments = await get_payments(
|
||||
pending_payments = await get_payments(
|
||||
wallet_id=wallet.wallet.id,
|
||||
pending=True,
|
||||
exclude_uncheckable=True,
|
||||
filters=filters,
|
||||
)
|
||||
for payment in pendingPayments:
|
||||
for payment in pending_payments:
|
||||
await check_transaction_status(
|
||||
wallet_id=payment.wallet_id, payment_hash=payment.payment_hash
|
||||
)
|
||||
|
@ -146,6 +147,37 @@ async def api_payments(
|
|||
)
|
||||
|
||||
|
||||
@core_app.get(
|
||||
"/api/v1/payments/paginated",
|
||||
name="Payment List",
|
||||
summary="get paginated list of payments",
|
||||
response_description="list of payments",
|
||||
response_model=Page[Payment],
|
||||
openapi_extra=generate_filter_params_openapi(PaymentFilters),
|
||||
)
|
||||
async def api_payments_paginated(
|
||||
wallet: WalletTypeInfo = Depends(get_key_type),
|
||||
filters: Filters = Depends(parse_filters(PaymentFilters)),
|
||||
):
|
||||
pending = await get_payments_paginated(
|
||||
wallet_id=wallet.wallet.id,
|
||||
pending=True,
|
||||
exclude_uncheckable=True,
|
||||
filters=filters,
|
||||
)
|
||||
for payment in pending.data:
|
||||
await check_transaction_status(
|
||||
wallet_id=payment.wallet_id, payment_hash=payment.payment_hash
|
||||
)
|
||||
page = await get_payments_paginated(
|
||||
wallet_id=wallet.wallet.id,
|
||||
pending=True,
|
||||
complete=True,
|
||||
filters=filters,
|
||||
)
|
||||
return page
|
||||
|
||||
|
||||
class CreateInvoiceData(BaseModel):
|
||||
out: Optional[bool] = True
|
||||
amount: float = Query(None, ge=0)
|
||||
|
@ -788,7 +820,6 @@ async def api_install_extension(
|
|||
|
||||
@core_app.delete("/api/v1/extension/{ext_id}")
|
||||
async def api_uninstall_extension(ext_id: str, user: User = Depends(check_admin)):
|
||||
|
||||
installable_extensions = await InstallableExtension.get_installable_extensions()
|
||||
|
||||
extensions = [e for e in installable_extensions if e.id == ext_id]
|
||||
|
|
273
lnbits/db.py
273
lnbits/db.py
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import os
|
||||
|
@ -5,7 +7,8 @@ import re
|
|||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from enum import Enum
|
||||
from typing import Any, Generic, List, Optional, Tuple, Type, TypeVar
|
||||
from sqlite3 import Row
|
||||
from typing import Any, Generic, List, Literal, Optional, Type, TypeVar
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
@ -19,6 +22,51 @@ POSTGRES = "POSTGRES"
|
|||
COCKROACH = "COCKROACH"
|
||||
SQLITE = "SQLITE"
|
||||
|
||||
if settings.lnbits_database_url:
|
||||
database_uri = settings.lnbits_database_url
|
||||
|
||||
if database_uri.startswith("cockroachdb://"):
|
||||
DB_TYPE = COCKROACH
|
||||
else:
|
||||
DB_TYPE = POSTGRES
|
||||
|
||||
from psycopg2.extensions import DECIMAL, new_type, register_type
|
||||
|
||||
def _parse_timestamp(value, _):
|
||||
if value is None:
|
||||
return None
|
||||
f = "%Y-%m-%d %H:%M:%S.%f"
|
||||
if "." not in value:
|
||||
f = "%Y-%m-%d %H:%M:%S"
|
||||
return time.mktime(datetime.datetime.strptime(value, f).timetuple())
|
||||
|
||||
register_type(
|
||||
new_type(
|
||||
DECIMAL.values,
|
||||
"DEC2FLOAT",
|
||||
lambda value, curs: float(value) if value is not None else None,
|
||||
)
|
||||
)
|
||||
register_type(
|
||||
new_type(
|
||||
(1082, 1083, 1266),
|
||||
"DATE2INT",
|
||||
lambda value, curs: time.mktime(value.timetuple())
|
||||
if value is not None
|
||||
else None,
|
||||
)
|
||||
)
|
||||
|
||||
register_type(new_type((1184, 1114), "TIMESTAMP2INT", _parse_timestamp))
|
||||
else:
|
||||
if os.path.isdir(settings.lnbits_data_folder):
|
||||
DB_TYPE = SQLITE
|
||||
else:
|
||||
raise NotADirectoryError(
|
||||
f"LNBITS_DATA_FOLDER named {settings.lnbits_data_folder} was not created"
|
||||
f" - please 'mkdir {settings.lnbits_data_folder}' and try again"
|
||||
)
|
||||
|
||||
|
||||
class Compat:
|
||||
type: Optional[str] = "<inherited>"
|
||||
|
@ -68,6 +116,16 @@ class Compat:
|
|||
return "BIGINT"
|
||||
return "INT"
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def timestamp_placeholder(cls):
|
||||
if DB_TYPE == POSTGRES:
|
||||
return "to_timestamp(?)"
|
||||
elif DB_TYPE == COCKROACH:
|
||||
return "cast(? AS timestamp)"
|
||||
else:
|
||||
return "?"
|
||||
|
||||
|
||||
class Connection(Compat):
|
||||
def __init__(self, conn: AsyncConnection, txn, typ, name, schema):
|
||||
|
@ -87,17 +145,21 @@ class Connection(Compat):
|
|||
# strip html
|
||||
CLEANR = re.compile("<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});")
|
||||
|
||||
def cleanhtml(raw_html):
|
||||
if isinstance(raw_html, str):
|
||||
cleantext = re.sub(CLEANR, "", raw_html)
|
||||
return cleantext
|
||||
else:
|
||||
return raw_html
|
||||
|
||||
# tuple to list and back to tuple
|
||||
value_list = [values] if isinstance(values, str) else list(values)
|
||||
values = tuple([cleanhtml(val) for val in value_list])
|
||||
return values
|
||||
raw_values = [values] if isinstance(values, str) else list(values)
|
||||
values = []
|
||||
for raw_value in raw_values:
|
||||
if isinstance(raw_value, str):
|
||||
values.append(re.sub(CLEANR, "", raw_value))
|
||||
elif isinstance(raw_value, datetime.datetime):
|
||||
ts = raw_value.timestamp()
|
||||
if self.type == SQLITE:
|
||||
values.append(int(ts))
|
||||
else:
|
||||
values.append(ts)
|
||||
else:
|
||||
values.append(raw_value)
|
||||
return tuple(values)
|
||||
|
||||
async def fetchall(self, query: str, values: tuple = ()) -> list:
|
||||
result = await self.conn.execute(
|
||||
|
@ -113,6 +175,51 @@ class Connection(Compat):
|
|||
await result.close()
|
||||
return row
|
||||
|
||||
async def fetch_page(
|
||||
self,
|
||||
query: str,
|
||||
where: Optional[List[str]] = None,
|
||||
values: Optional[List[str]] = None,
|
||||
filters: Optional[Filters] = None,
|
||||
model: Optional[Type[TRowModel]] = None,
|
||||
) -> Page[TRowModel]:
|
||||
if not filters:
|
||||
filters = Filters()
|
||||
clause = filters.where(where)
|
||||
parsed_values = filters.values(values)
|
||||
|
||||
rows = await self.fetchall(
|
||||
f"""
|
||||
{query}
|
||||
{clause}
|
||||
{filters.order_by()}
|
||||
{filters.pagination()}
|
||||
""",
|
||||
parsed_values,
|
||||
)
|
||||
if rows:
|
||||
# no need for extra query if no pagination is specified
|
||||
if filters.offset or filters.limit:
|
||||
count = await self.fetchone(
|
||||
f"""
|
||||
SELECT COUNT(*) FROM (
|
||||
{query}
|
||||
{clause}
|
||||
) as count
|
||||
""",
|
||||
parsed_values,
|
||||
)
|
||||
count = int(count[0])
|
||||
else:
|
||||
count = len(rows)
|
||||
else:
|
||||
count = 0
|
||||
|
||||
return Page(
|
||||
data=[model.from_row(row) for row in rows] if model else rows,
|
||||
total=count,
|
||||
)
|
||||
|
||||
async def execute(self, query: str, values: tuple = ()):
|
||||
return await self.conn.execute(
|
||||
self.rewrite_query(query), self.rewrite_values(values)
|
||||
|
@ -122,57 +229,17 @@ class Connection(Compat):
|
|||
class Database(Compat):
|
||||
def __init__(self, db_name: str):
|
||||
self.name = db_name
|
||||
self.schema = self.name
|
||||
self.type = DB_TYPE
|
||||
|
||||
if settings.lnbits_database_url:
|
||||
database_uri = settings.lnbits_database_url
|
||||
|
||||
if database_uri.startswith("cockroachdb://"):
|
||||
self.type = COCKROACH
|
||||
else:
|
||||
self.type = POSTGRES
|
||||
|
||||
from psycopg2.extensions import DECIMAL, new_type, register_type
|
||||
|
||||
def _parse_timestamp(value, _):
|
||||
if value is None:
|
||||
return None
|
||||
f = "%Y-%m-%d %H:%M:%S.%f"
|
||||
if "." not in value:
|
||||
f = "%Y-%m-%d %H:%M:%S"
|
||||
return time.mktime(datetime.datetime.strptime(value, f).timetuple())
|
||||
|
||||
register_type(
|
||||
new_type(
|
||||
DECIMAL.values,
|
||||
"DEC2FLOAT",
|
||||
lambda value, curs: float(value) if value is not None else None,
|
||||
)
|
||||
)
|
||||
register_type(
|
||||
new_type(
|
||||
(1082, 1083, 1266),
|
||||
"DATE2INT",
|
||||
lambda value, curs: time.mktime(value.timetuple())
|
||||
if value is not None
|
||||
else None,
|
||||
)
|
||||
)
|
||||
|
||||
register_type(new_type((1184, 1114), "TIMESTAMP2INT", _parse_timestamp))
|
||||
else:
|
||||
if os.path.isdir(settings.lnbits_data_folder):
|
||||
if DB_TYPE == SQLITE:
|
||||
self.path = os.path.join(
|
||||
settings.lnbits_data_folder, f"{self.name}.sqlite3"
|
||||
)
|
||||
database_uri = f"sqlite:///{self.path}"
|
||||
self.type = SQLITE
|
||||
else:
|
||||
raise NotADirectoryError(
|
||||
f"LNBITS_DATA_FOLDER named {settings.lnbits_data_folder} was not created"
|
||||
f" - please 'mkdir {settings.lnbits_data_folder}' and try again"
|
||||
)
|
||||
logger.trace(f"database {self.type} added for {self.name}")
|
||||
self.schema = self.name
|
||||
database_uri = settings.lnbits_database_url
|
||||
|
||||
if self.name.startswith("ext_"):
|
||||
self.schema = self.name[4:]
|
||||
else:
|
||||
|
@ -181,6 +248,8 @@ class Database(Compat):
|
|||
self.engine = create_engine(database_uri, strategy=ASYNCIO_STRATEGY)
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
logger.trace(f"database {self.type} added for {self.name}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect(self):
|
||||
await self.lock.acquire()
|
||||
|
@ -215,6 +284,17 @@ class Database(Compat):
|
|||
await result.close()
|
||||
return row
|
||||
|
||||
async def fetch_page(
|
||||
self,
|
||||
query: str,
|
||||
where: Optional[List[str]] = None,
|
||||
values: Optional[List[str]] = None,
|
||||
filters: Optional[Filters] = None,
|
||||
model: Optional[Type[TRowModel]] = None,
|
||||
) -> Page[TRowModel]:
|
||||
async with self.connect() as conn:
|
||||
return await conn.fetch_page(query, where, values, filters, model)
|
||||
|
||||
async def execute(self, query: str, values: tuple = ()):
|
||||
async with self.connect() as conn:
|
||||
return await conn.execute(query, values)
|
||||
|
@ -229,6 +309,8 @@ class Operator(Enum):
|
|||
LT = "lt"
|
||||
EQ = "eq"
|
||||
NE = "ne"
|
||||
GE = "ge"
|
||||
LE = "le"
|
||||
INCLUDE = "in"
|
||||
EXCLUDE = "ex"
|
||||
|
||||
|
@ -246,21 +328,45 @@ class Operator(Enum):
|
|||
return ">"
|
||||
elif self == Operator.LT:
|
||||
return "<"
|
||||
elif self == Operator.GE:
|
||||
return ">="
|
||||
elif self == Operator.LE:
|
||||
return "<="
|
||||
else:
|
||||
raise ValueError("Unknown SQL Operator")
|
||||
|
||||
|
||||
class FromRowModel(BaseModel):
|
||||
@classmethod
|
||||
def from_row(cls, row: Row):
|
||||
return cls(**dict(row))
|
||||
|
||||
|
||||
class FilterModel(BaseModel):
|
||||
__search_fields__: List[str] = []
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
TModel = TypeVar("TModel", bound=BaseModel)
|
||||
TRowModel = TypeVar("TRowModel", bound=FromRowModel)
|
||||
TFilterModel = TypeVar("TFilterModel", bound=FilterModel)
|
||||
|
||||
|
||||
class Filter(BaseModel, Generic[TModel]):
|
||||
class Page(BaseModel, Generic[T]):
|
||||
data: list[T]
|
||||
total: int
|
||||
|
||||
|
||||
class Filter(BaseModel, Generic[TFilterModel]):
|
||||
field: str
|
||||
nested: Optional[list[str]]
|
||||
nested: Optional[List[str]]
|
||||
op: Operator = Operator.EQ
|
||||
values: list[Any]
|
||||
|
||||
model: Optional[Type[TFilterModel]]
|
||||
|
||||
@classmethod
|
||||
def parse_query(cls, key: str, raw_values: list[Any], model: Type[TModel]):
|
||||
def parse_query(cls, key: str, raw_values: list[Any], model: Type[TFilterModel]):
|
||||
# Key format:
|
||||
# key[operator]
|
||||
# e.g. name[eq]
|
||||
|
@ -300,7 +406,7 @@ class Filter(BaseModel, Generic[TModel]):
|
|||
else:
|
||||
raise ValueError("Unknown filter field")
|
||||
|
||||
return cls(field=field, op=op, nested=nested, values=values)
|
||||
return cls(field=field, op=op, nested=nested, values=values, model=model)
|
||||
|
||||
@property
|
||||
def statement(self):
|
||||
|
@ -308,18 +414,29 @@ class Filter(BaseModel, Generic[TModel]):
|
|||
if self.nested:
|
||||
for name in self.nested:
|
||||
accessor = f"({accessor} ->> '{name}')"
|
||||
if self.model and self.model.__fields__[self.field].type_ == datetime.datetime:
|
||||
placeholder = Compat.timestamp_placeholder
|
||||
else:
|
||||
placeholder = "?"
|
||||
if self.op in (Operator.INCLUDE, Operator.EXCLUDE):
|
||||
placeholders = ", ".join(["?"] * len(self.values))
|
||||
placeholders = ", ".join([placeholder] * len(self.values))
|
||||
stmt = [f"{accessor} {self.op.as_sql} ({placeholders})"]
|
||||
else:
|
||||
stmt = [f"{accessor} {self.op.as_sql} ?"] * len(self.values)
|
||||
stmt = [f"{accessor} {self.op.as_sql} {placeholder}"] * len(self.values)
|
||||
return " OR ".join(stmt)
|
||||
|
||||
|
||||
class Filters(BaseModel, Generic[TModel]):
|
||||
filters: List[Filter[TModel]] = []
|
||||
limit: Optional[int]
|
||||
offset: Optional[int]
|
||||
class Filters(BaseModel, Generic[TFilterModel]):
|
||||
filters: List[Filter[TFilterModel]] = []
|
||||
search: Optional[str] = None
|
||||
|
||||
offset: Optional[int] = None
|
||||
limit: Optional[int] = None
|
||||
|
||||
sortby: Optional[str] = None
|
||||
direction: Optional[Literal["asc", "desc"]] = None
|
||||
|
||||
model: Optional[Type[TFilterModel]] = None
|
||||
|
||||
def pagination(self) -> str:
|
||||
stmt = ""
|
||||
|
@ -329,16 +446,36 @@ class Filters(BaseModel, Generic[TModel]):
|
|||
stmt += f"OFFSET {self.offset}"
|
||||
return stmt
|
||||
|
||||
def where(self, where_stmts: List[str]) -> str:
|
||||
def where(self, where_stmts: Optional[List[str]] = None) -> str:
|
||||
if not where_stmts:
|
||||
where_stmts = []
|
||||
if self.filters:
|
||||
for filter in self.filters:
|
||||
where_stmts.append(filter.statement)
|
||||
if self.search and self.model:
|
||||
if DB_TYPE == POSTGRES:
|
||||
where_stmts.append(
|
||||
f"lower(concat({f', '.join(self.model.__search_fields__)})) LIKE ?"
|
||||
)
|
||||
elif DB_TYPE == SQLITE:
|
||||
where_stmts.append(
|
||||
f"lower({'||'.join(self.model.__search_fields__)}) LIKE ?"
|
||||
)
|
||||
if where_stmts:
|
||||
return "WHERE " + " AND ".join(where_stmts)
|
||||
return ""
|
||||
|
||||
def values(self, values: List[str]) -> Tuple:
|
||||
def order_by(self) -> str:
|
||||
if self.sortby:
|
||||
return f"ORDER BY {self.sortby} {self.direction or 'asc'}"
|
||||
return ""
|
||||
|
||||
def values(self, values: Optional[List[str]] = None) -> tuple:
|
||||
if not values:
|
||||
values = []
|
||||
if self.filters:
|
||||
for filter in self.filters:
|
||||
values.extend(filter.values)
|
||||
if self.search and self.model:
|
||||
values.append(f"%{self.search}%")
|
||||
return tuple(values)
|
||||
|
|
|
@ -1,16 +1,16 @@
|
|||
from http import HTTPStatus
|
||||
from typing import Optional, Type
|
||||
from typing import Literal, Optional, Type
|
||||
|
||||
from fastapi import HTTPException, Request, Security, status
|
||||
from fastapi import Query, Request, Security, status
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.openapi.models import APIKey, APIKeyIn
|
||||
from fastapi.security import APIKeyHeader, APIKeyQuery
|
||||
from fastapi.security.base import SecurityBase
|
||||
from pydantic import BaseModel
|
||||
from pydantic.types import UUID4
|
||||
|
||||
from lnbits.core.crud import get_user, get_wallet_for_key
|
||||
from lnbits.core.models import User, Wallet
|
||||
from lnbits.db import Filter, Filters
|
||||
from lnbits.db import Filter, Filters, TFilterModel
|
||||
from lnbits.requestvars import g
|
||||
from lnbits.settings import settings
|
||||
|
||||
|
@ -185,7 +185,6 @@ async def require_admin_key(
|
|||
api_key_header: str = Security(api_key_header),
|
||||
api_key_query: str = Security(api_key_query),
|
||||
):
|
||||
|
||||
token = api_key_header or api_key_query
|
||||
|
||||
if not token:
|
||||
|
@ -211,7 +210,6 @@ async def require_invoice_key(
|
|||
api_key_header: str = Security(api_key_header),
|
||||
api_key_query: str = Security(api_key_query),
|
||||
):
|
||||
|
||||
token = api_key_header or api_key_query
|
||||
|
||||
if not token:
|
||||
|
@ -279,14 +277,19 @@ async def check_super_user(usr: UUID4) -> User:
|
|||
return user
|
||||
|
||||
|
||||
def parse_filters(model: Type[BaseModel]):
|
||||
def parse_filters(model: Type[TFilterModel]):
|
||||
"""
|
||||
Parses the query params as filters.
|
||||
:param model: model used for validation of filter values
|
||||
"""
|
||||
|
||||
def dependency(
|
||||
request: Request, limit: Optional[int] = None, offset: Optional[int] = None
|
||||
request: Request,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
sortby: Optional[str] = None,
|
||||
direction: Optional[Literal["asc", "desc"]] = None,
|
||||
search: Optional[str] = Query(None, description="Text based search"),
|
||||
):
|
||||
params = request.query_params
|
||||
filters = []
|
||||
|
@ -300,6 +303,10 @@ def parse_filters(model: Type[BaseModel]):
|
|||
filters=filters,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sortby=sortby,
|
||||
direction=direction,
|
||||
search=search,
|
||||
model=model,
|
||||
)
|
||||
|
||||
return dependency
|
||||
|
|
|
@ -4,7 +4,6 @@ from typing import Any, List, Optional, Type
|
|||
|
||||
import jinja2
|
||||
import shortuuid
|
||||
from pydantic import BaseModel
|
||||
from pydantic.schema import (
|
||||
field_schema,
|
||||
get_flat_models_from_fields,
|
||||
|
@ -15,6 +14,7 @@ from lnbits.jinja2_templating import Jinja2Templates
|
|||
from lnbits.requestvars import g
|
||||
from lnbits.settings import settings
|
||||
|
||||
from .db import FilterModel
|
||||
from .extension_manager import get_valid_extensions
|
||||
|
||||
|
||||
|
@ -32,7 +32,6 @@ def url_for(endpoint: str, external: Optional[bool] = False, **params: Any) -> s
|
|||
|
||||
|
||||
def template_renderer(additional_folders: Optional[List] = None) -> Jinja2Templates:
|
||||
|
||||
folders = ["lnbits/templates", "lnbits/core/templates"]
|
||||
if additional_folders:
|
||||
folders.extend(additional_folders)
|
||||
|
@ -96,7 +95,7 @@ def get_current_extension_name() -> str:
|
|||
return ext_name
|
||||
|
||||
|
||||
def generate_filter_params_openapi(model: Type[BaseModel], keep_optional=False):
|
||||
def generate_filter_params_openapi(model: Type[FilterModel], keep_optional=False):
|
||||
"""
|
||||
Generate openapi documentation for Filters. This is intended to be used along parse_filters (see example)
|
||||
:param model: Filter model
|
||||
|
@ -117,6 +116,11 @@ def generate_filter_params_openapi(model: Type[BaseModel], keep_optional=False):
|
|||
description = "Supports Filtering"
|
||||
if schema["type"] == "object":
|
||||
description += f". Nested attributes can be filtered too, e.g. `{field.alias}.[additional].[attributes]`"
|
||||
if (
|
||||
hasattr(model, "__search_fields__")
|
||||
and field.name in model.__search_fields__
|
||||
):
|
||||
description += ". Supports Search"
|
||||
|
||||
parameter = {
|
||||
"name": field.alias,
|
||||
|
|
2
lnbits/static/bundle.min.js
vendored
2
lnbits/static/bundle.min.js
vendored
File diff suppressed because one or more lines are too long
|
@ -67,8 +67,13 @@ window.LNbits = {
|
|||
getWallet: function (wallet) {
|
||||
return this.request('get', '/api/v1/wallet', wallet.inkey)
|
||||
},
|
||||
getPayments: function (wallet) {
|
||||
return this.request('get', '/api/v1/payments', wallet.inkey)
|
||||
getPayments: function (wallet, query) {
|
||||
const params = new URLSearchParams(query)
|
||||
return this.request(
|
||||
'get',
|
||||
'/api/v1/payments/paginated?' + params,
|
||||
wallet.inkey
|
||||
)
|
||||
},
|
||||
getPayment: function (wallet, paymentHash) {
|
||||
return this.request(
|
||||
|
@ -185,7 +190,7 @@ window.LNbits = {
|
|||
},
|
||||
payment: function (data) {
|
||||
obj = {
|
||||
checking_id: data.id,
|
||||
checking_id: data.checking_id,
|
||||
pending: data.pending,
|
||||
amount: data.amount,
|
||||
fee: data.fee,
|
||||
|
|
|
@ -1,10 +1,15 @@
|
|||
import asyncio
|
||||
import hashlib
|
||||
from time import time
|
||||
|
||||
import pytest
|
||||
|
||||
from lnbits import bolt11
|
||||
from lnbits.core.models import Payment
|
||||
from lnbits.core.views.api import api_payment
|
||||
from lnbits.db import DB_TYPE, SQLITE
|
||||
from lnbits.settings import get_wallet_class
|
||||
from tests.conftest import CreateInvoiceData, api_payments_create_invoice
|
||||
|
||||
from ...helpers import get_random_invoice_data, is_fake
|
||||
|
||||
|
@ -181,6 +186,66 @@ async def test_pay_invoice_adminkey(client, invoice, adminkey_headers_from):
|
|||
assert response.status_code > 300 # should fail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_payments(client, from_wallet, adminkey_headers_from):
|
||||
# Because sqlite only stores timestamps with milliseconds we have to wait a second to ensure
|
||||
# a different timestamp than previous invoices
|
||||
# due to this limitation both payments (normal and paginated) are tested at the same time as they are almost
|
||||
# identical anyways
|
||||
if DB_TYPE == SQLITE:
|
||||
await asyncio.sleep(1)
|
||||
ts = time()
|
||||
|
||||
fake_data = [
|
||||
CreateInvoiceData(amount=10, memo="aaaa"),
|
||||
CreateInvoiceData(amount=100, memo="bbbb"),
|
||||
CreateInvoiceData(amount=1000, memo="aabb"),
|
||||
]
|
||||
|
||||
for invoice in fake_data:
|
||||
await api_payments_create_invoice(invoice, from_wallet)
|
||||
|
||||
async def get_payments(params: dict):
|
||||
params["time[ge]"] = ts
|
||||
response = await client.get(
|
||||
"/api/v1/payments",
|
||||
params=params,
|
||||
headers=adminkey_headers_from,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return [Payment(**payment) for payment in response.json()]
|
||||
|
||||
payments = await get_payments({"sortby": "amount", "direction": "desc", "limit": 2})
|
||||
assert payments[-1].amount < payments[0].amount
|
||||
assert len(payments) == 2
|
||||
|
||||
payments = await get_payments({"offset": 2, "limit": 2})
|
||||
assert len(payments) == 1
|
||||
|
||||
payments = await get_payments({"sortby": "amount", "direction": "asc"})
|
||||
assert payments[-1].amount > payments[0].amount
|
||||
|
||||
payments = await get_payments({"search": "aaa"})
|
||||
assert len(payments) == 1
|
||||
|
||||
payments = await get_payments({"search": "aa"})
|
||||
assert len(payments) == 2
|
||||
|
||||
# amount is in msat
|
||||
payments = await get_payments({"amount[gt]": 10000})
|
||||
assert len(payments) == 2
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/payments/paginated",
|
||||
params={"limit": 2, "time[ge]": ts},
|
||||
headers=adminkey_headers_from,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
paginated = response.json()
|
||||
assert len(paginated["data"]) == 2
|
||||
assert paginated["total"] == len(fake_data)
|
||||
|
||||
|
||||
# check POST /api/v1/payments/decode
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_invoice(client, invoice):
|
||||
|
|
Loading…
Reference in New Issue
Block a user