Serverside Pagination for payments (#1613)

* initial backend support

* implement payments pagination on frontend

* implement search for payments api

* fix pyright issues

* sqlite support for searching

* backwards compatability

* formatting, small fixes

* small optimization

* fix sorting issue, add error handling

* GET payments test

* filter by dates, use List instead of list

* fix sqlite

* update bundle

* test old payments endpoint aswell

* refactor for easier review

* optimise test

* revert unnecessary change

---------

Co-authored-by: dni  <office@dnilabs.com>
This commit is contained in:
jackstar12 2023-05-09 10:18:53 +02:00 committed by GitHub
parent 45b199a8ef
commit c0f66989cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 460 additions and 142 deletions

View File

@ -7,12 +7,12 @@ from uuid import uuid4
import shortuuid
from lnbits import bolt11
from lnbits.db import COCKROACH, POSTGRES, Connection, Filters
from lnbits.db import Connection, Filters, Page
from lnbits.extension_manager import InstallableExtension
from lnbits.settings import AdminSettings, EditableSettings, SuperSettings, settings
from . import db
from .models import BalanceCheck, Payment, TinyURL, User, Wallet
from .models import BalanceCheck, Payment, PaymentFilters, TinyURL, User, Wallet
# accounts
# --------
@ -343,7 +343,7 @@ async def get_latest_payments_by_extension(ext_name: str, ext_id: str, limit: in
return rows
async def get_payments(
async def get_payments_paginated(
*,
wallet_id: Optional[str] = None,
complete: bool = False,
@ -352,28 +352,23 @@ async def get_payments(
incoming: bool = False,
since: Optional[int] = None,
exclude_uncheckable: bool = False,
filters: Optional[Filters[Payment]] = None,
filters: Optional[Filters[PaymentFilters]] = None,
conn: Optional[Connection] = None,
) -> List[Payment]:
) -> Page[Payment]:
"""
Filters payments to be returned by complete | pending | outgoing | incoming.
"""
args: List[Any] = []
values: List[Any] = []
clause: List[str] = []
if since is not None:
if db.type == POSTGRES:
clause.append("time > to_timestamp(?)")
elif db.type == COCKROACH:
clause.append("time > cast(? AS timestamp)")
else:
clause.append("time > ?")
args.append(since)
clause.append(f"time > {db.timestamp_placeholder}")
values.append(since)
if wallet_id:
clause.append("wallet = ?")
args.append(wallet_id)
values.append(wallet_id)
if complete and pending:
pass
@ -397,21 +392,54 @@ async def get_payments(
clause.append("checking_id NOT LIKE 'temp_%'")
clause.append("checking_id NOT LIKE 'internal_%'")
if not filters:
filters = Filters(limit=None, offset=None)
rows = await (conn or db).fetchall(
f"""
SELECT *
FROM apipayments
{filters.where(clause)}
ORDER BY time DESC
{filters.pagination()}
""",
filters.values(args),
return await (conn or db).fetch_page(
"SELECT * FROM apipayments",
clause,
values,
filters=filters,
model=Payment,
)
return [Payment.from_row(row) for row in rows]
async def get_payments(
*,
wallet_id: Optional[str] = None,
complete: bool = False,
pending: bool = False,
outgoing: bool = False,
incoming: bool = False,
since: Optional[int] = None,
exclude_uncheckable: bool = False,
filters: Optional[Filters[PaymentFilters]] = None,
conn: Optional[Connection] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> list[Payment]:
"""
Filters payments to be returned by complete | pending | outgoing | incoming.
"""
if not filters:
filters = Filters()
if limit:
filters.limit = limit
if offset:
filters.offset = offset
page = await get_payments_paginated(
wallet_id=wallet_id,
complete=complete,
pending=pending,
outgoing=outgoing,
incoming=incoming,
since=since,
exclude_uncheckable=exclude_uncheckable,
filters=filters,
conn=conn,
)
return page.data
async def delete_expired_invoices(
@ -454,7 +482,6 @@ async def create_payment(
webhook: Optional[str] = None,
conn: Optional[Connection] = None,
) -> Payment:
# todo: add this when tests are fixed
# previous_payment = await get_wallet_payment(wallet_id, payment_hash, conn=conn)
# assert previous_payment is None, "Payment already exists"
@ -514,7 +541,6 @@ async def update_payment_details(
new_checking_id: Optional[str] = None,
conn: Optional[Connection] = None,
) -> None:
set_clause: List[str] = []
set_variables: List[Any] = []

View File

@ -11,7 +11,7 @@ from lnurl import encode as lnurl_encode
from loguru import logger
from pydantic import BaseModel
from lnbits.db import Connection
from lnbits.db import Connection, FilterModel, FromRowModel
from lnbits.helpers import url_for
from lnbits.settings import get_wallet_class, settings
from lnbits.wallets.base import PaymentStatus
@ -86,7 +86,7 @@ class User(BaseModel):
return False
class Payment(BaseModel):
class Payment(FromRowModel):
checking_id: str
pending: bool
amount: int
@ -214,6 +214,24 @@ class Payment(BaseModel):
await delete_payment(self.checking_id, conn=conn)
class PaymentFilters(FilterModel):
__search_fields__ = ["memo", "amount"]
checking_id: str
amount: int
fee: int
memo: Optional[str]
time: datetime.datetime
bolt11: str
preimage: str
payment_hash: str
expiry: Optional[datetime.datetime]
extra: Dict = {}
wallet_id: str
webhook: Optional[str]
webhook_status: Optional[int]
class BalanceCheck(BaseModel):
wallet: str
service: str

View File

@ -1,6 +1,6 @@
// update cache version every time there is a new deployment
// so the service worker reinitializes the cache
const CACHE_VERSION = 5
const CACHE_VERSION = 6
const CURRENT_CACHE = `lnbits-${CACHE_VERSION}-`
const getApiKey = request => {

View File

@ -152,14 +152,14 @@ new Vue({
field: 'memo'
},
{
name: 'date',
name: 'time',
align: 'left',
label: this.$t('date'),
field: 'date',
sortable: true
},
{
name: 'sat',
name: 'amount',
align: 'right',
label: this.$t('amount') + ' (' + LNBITS_DENOMINATION + ')',
field: 'sat',
@ -173,9 +173,14 @@ new Vue({
}
],
pagination: {
rowsPerPage: 10
rowsPerPage: 10,
page: 1,
sortBy: 'time',
descending: true,
rowsNumber: 10
},
filter: null
filter: null,
loading: false
},
paymentsChart: {
show: false
@ -695,15 +700,34 @@ new Vue({
LNbits.href.deleteWallet(walletId, user)
})
},
fetchPayments: function () {
return LNbits.api.getPayments(this.g.wallet).then(response => {
this.payments = response.data
.map(obj => {
fetchPayments: function (props) {
// Props are passed by qasar when pagination or sorting changes
if (props) {
this.paymentsTable.pagination = props.pagination
}
let pagination = this.paymentsTable.pagination
this.paymentsTable.loading = true
const query = {
limit: pagination.rowsPerPage,
offset: (pagination.page - 1) * pagination.rowsPerPage,
sortby: pagination.sortBy ?? 'time',
direction: pagination.descending ? 'desc' : 'asc'
}
if (this.paymentsTable.filter) {
query.search = this.paymentsTable.filter
}
return LNbits.api
.getPayments(this.g.wallet, query)
.then(response => {
this.paymentsTable.loading = false
this.paymentsTable.pagination.rowsNumber = response.data.total
this.payments = response.data.data.map(obj => {
return LNbits.map.payment(obj)
})
.sort((a, b) => {
return b.time - a.time
})
.catch(err => {
this.paymentsTable.loading = false
LNbits.utils.notifyApiError(err)
})
},
fetchBalance: function () {

View File

@ -125,7 +125,6 @@
</div>
</div>
<q-input
v-if="payments.length > 10"
filled
dense
clearable
@ -138,12 +137,14 @@
<q-table
dense
flat
:data="filteredPayments"
:data="payments"
:row-key="paymentTableRowKey"
:columns="paymentsTable.columns"
:pagination.sync="paymentsTable.pagination"
:no-data-label="$t('no_transactions')"
:filter="paymentsTable.filter"
:loading="paymentsTable.loading"
@request="fetchPayments"
>
{% raw %}
<template v-slot:header="props">
@ -192,14 +193,14 @@
</q-badge>
{{ props.row.memo }}
</q-td>
<q-td auto-width key="date" :props="props">
<q-td auto-width key="time" :props="props">
<q-tooltip>{{ props.row.date }}</q-tooltip>
{{ props.row.dateFrom }}
</q-td>
{% endraw %}
<q-td
auto-width
key="sat"
key="amount"
v-if="'{{LNBITS_DENOMINATION}}' != 'sats'"
:props="props"
>{% raw %} {{
@ -207,7 +208,7 @@
}}
</q-td>
<q-td auto-width key="sat" v-else :props="props">
<q-td auto-width key="amount" v-else :props="props">
{{ props.row.fsat }}
</q-td>
<q-td auto-width key="fee" :props="props">

View File

@ -33,8 +33,8 @@ from lnbits.core.helpers import (
migrate_extension_database,
stop_extension_background_work,
)
from lnbits.core.models import Payment, User, Wallet
from lnbits.db import Filters
from lnbits.core.models import Payment, PaymentFilters, User, Wallet
from lnbits.db import Filters, Page
from lnbits.decorators import (
WalletTypeInfo,
check_admin,
@ -66,6 +66,7 @@ from ..crud import (
delete_tinyurl,
get_dbversions,
get_payments,
get_payments_paginated,
get_standalone_payment,
get_tinyurl,
get_tinyurl_by_url,
@ -122,19 +123,19 @@ async def api_update_wallet(
summary="get list of payments",
response_description="list of payments",
response_model=List[Payment],
openapi_extra=generate_filter_params_openapi(Payment),
openapi_extra=generate_filter_params_openapi(PaymentFilters),
)
async def api_payments(
wallet: WalletTypeInfo = Depends(get_key_type),
filters: Filters = Depends(parse_filters(Payment)),
filters: Filters = Depends(parse_filters(PaymentFilters)),
):
pendingPayments = await get_payments(
pending_payments = await get_payments(
wallet_id=wallet.wallet.id,
pending=True,
exclude_uncheckable=True,
filters=filters,
)
for payment in pendingPayments:
for payment in pending_payments:
await check_transaction_status(
wallet_id=payment.wallet_id, payment_hash=payment.payment_hash
)
@ -146,6 +147,37 @@ async def api_payments(
)
@core_app.get(
"/api/v1/payments/paginated",
name="Payment List",
summary="get paginated list of payments",
response_description="list of payments",
response_model=Page[Payment],
openapi_extra=generate_filter_params_openapi(PaymentFilters),
)
async def api_payments_paginated(
wallet: WalletTypeInfo = Depends(get_key_type),
filters: Filters = Depends(parse_filters(PaymentFilters)),
):
pending = await get_payments_paginated(
wallet_id=wallet.wallet.id,
pending=True,
exclude_uncheckable=True,
filters=filters,
)
for payment in pending.data:
await check_transaction_status(
wallet_id=payment.wallet_id, payment_hash=payment.payment_hash
)
page = await get_payments_paginated(
wallet_id=wallet.wallet.id,
pending=True,
complete=True,
filters=filters,
)
return page
class CreateInvoiceData(BaseModel):
out: Optional[bool] = True
amount: float = Query(None, ge=0)
@ -788,7 +820,6 @@ async def api_install_extension(
@core_app.delete("/api/v1/extension/{ext_id}")
async def api_uninstall_extension(ext_id: str, user: User = Depends(check_admin)):
installable_extensions = await InstallableExtension.get_installable_extensions()
extensions = [e for e in installable_extensions if e.id == ext_id]

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import asyncio
import datetime
import os
@ -5,7 +7,8 @@ import re
import time
from contextlib import asynccontextmanager
from enum import Enum
from typing import Any, Generic, List, Optional, Tuple, Type, TypeVar
from sqlite3 import Row
from typing import Any, Generic, List, Literal, Optional, Type, TypeVar
from loguru import logger
from pydantic import BaseModel, ValidationError
@ -19,6 +22,51 @@ POSTGRES = "POSTGRES"
COCKROACH = "COCKROACH"
SQLITE = "SQLITE"
if settings.lnbits_database_url:
database_uri = settings.lnbits_database_url
if database_uri.startswith("cockroachdb://"):
DB_TYPE = COCKROACH
else:
DB_TYPE = POSTGRES
from psycopg2.extensions import DECIMAL, new_type, register_type
def _parse_timestamp(value, _):
if value is None:
return None
f = "%Y-%m-%d %H:%M:%S.%f"
if "." not in value:
f = "%Y-%m-%d %H:%M:%S"
return time.mktime(datetime.datetime.strptime(value, f).timetuple())
register_type(
new_type(
DECIMAL.values,
"DEC2FLOAT",
lambda value, curs: float(value) if value is not None else None,
)
)
register_type(
new_type(
(1082, 1083, 1266),
"DATE2INT",
lambda value, curs: time.mktime(value.timetuple())
if value is not None
else None,
)
)
register_type(new_type((1184, 1114), "TIMESTAMP2INT", _parse_timestamp))
else:
if os.path.isdir(settings.lnbits_data_folder):
DB_TYPE = SQLITE
else:
raise NotADirectoryError(
f"LNBITS_DATA_FOLDER named {settings.lnbits_data_folder} was not created"
f" - please 'mkdir {settings.lnbits_data_folder}' and try again"
)
class Compat:
type: Optional[str] = "<inherited>"
@ -68,6 +116,16 @@ class Compat:
return "BIGINT"
return "INT"
@classmethod
@property
def timestamp_placeholder(cls):
if DB_TYPE == POSTGRES:
return "to_timestamp(?)"
elif DB_TYPE == COCKROACH:
return "cast(? AS timestamp)"
else:
return "?"
class Connection(Compat):
def __init__(self, conn: AsyncConnection, txn, typ, name, schema):
@ -87,17 +145,21 @@ class Connection(Compat):
# strip html
CLEANR = re.compile("<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});")
def cleanhtml(raw_html):
if isinstance(raw_html, str):
cleantext = re.sub(CLEANR, "", raw_html)
return cleantext
else:
return raw_html
# tuple to list and back to tuple
value_list = [values] if isinstance(values, str) else list(values)
values = tuple([cleanhtml(val) for val in value_list])
return values
raw_values = [values] if isinstance(values, str) else list(values)
values = []
for raw_value in raw_values:
if isinstance(raw_value, str):
values.append(re.sub(CLEANR, "", raw_value))
elif isinstance(raw_value, datetime.datetime):
ts = raw_value.timestamp()
if self.type == SQLITE:
values.append(int(ts))
else:
values.append(ts)
else:
values.append(raw_value)
return tuple(values)
async def fetchall(self, query: str, values: tuple = ()) -> list:
result = await self.conn.execute(
@ -113,6 +175,51 @@ class Connection(Compat):
await result.close()
return row
async def fetch_page(
self,
query: str,
where: Optional[List[str]] = None,
values: Optional[List[str]] = None,
filters: Optional[Filters] = None,
model: Optional[Type[TRowModel]] = None,
) -> Page[TRowModel]:
if not filters:
filters = Filters()
clause = filters.where(where)
parsed_values = filters.values(values)
rows = await self.fetchall(
f"""
{query}
{clause}
{filters.order_by()}
{filters.pagination()}
""",
parsed_values,
)
if rows:
# no need for extra query if no pagination is specified
if filters.offset or filters.limit:
count = await self.fetchone(
f"""
SELECT COUNT(*) FROM (
{query}
{clause}
) as count
""",
parsed_values,
)
count = int(count[0])
else:
count = len(rows)
else:
count = 0
return Page(
data=[model.from_row(row) for row in rows] if model else rows,
total=count,
)
async def execute(self, query: str, values: tuple = ()):
return await self.conn.execute(
self.rewrite_query(query), self.rewrite_values(values)
@ -122,57 +229,17 @@ class Connection(Compat):
class Database(Compat):
def __init__(self, db_name: str):
self.name = db_name
self.schema = self.name
self.type = DB_TYPE
if settings.lnbits_database_url:
database_uri = settings.lnbits_database_url
if database_uri.startswith("cockroachdb://"):
self.type = COCKROACH
else:
self.type = POSTGRES
from psycopg2.extensions import DECIMAL, new_type, register_type
def _parse_timestamp(value, _):
if value is None:
return None
f = "%Y-%m-%d %H:%M:%S.%f"
if "." not in value:
f = "%Y-%m-%d %H:%M:%S"
return time.mktime(datetime.datetime.strptime(value, f).timetuple())
register_type(
new_type(
DECIMAL.values,
"DEC2FLOAT",
lambda value, curs: float(value) if value is not None else None,
)
)
register_type(
new_type(
(1082, 1083, 1266),
"DATE2INT",
lambda value, curs: time.mktime(value.timetuple())
if value is not None
else None,
)
)
register_type(new_type((1184, 1114), "TIMESTAMP2INT", _parse_timestamp))
else:
if os.path.isdir(settings.lnbits_data_folder):
if DB_TYPE == SQLITE:
self.path = os.path.join(
settings.lnbits_data_folder, f"{self.name}.sqlite3"
)
database_uri = f"sqlite:///{self.path}"
self.type = SQLITE
else:
raise NotADirectoryError(
f"LNBITS_DATA_FOLDER named {settings.lnbits_data_folder} was not created"
f" - please 'mkdir {settings.lnbits_data_folder}' and try again"
)
logger.trace(f"database {self.type} added for {self.name}")
self.schema = self.name
database_uri = settings.lnbits_database_url
if self.name.startswith("ext_"):
self.schema = self.name[4:]
else:
@ -181,6 +248,8 @@ class Database(Compat):
self.engine = create_engine(database_uri, strategy=ASYNCIO_STRATEGY)
self.lock = asyncio.Lock()
logger.trace(f"database {self.type} added for {self.name}")
@asynccontextmanager
async def connect(self):
await self.lock.acquire()
@ -215,6 +284,17 @@ class Database(Compat):
await result.close()
return row
async def fetch_page(
self,
query: str,
where: Optional[List[str]] = None,
values: Optional[List[str]] = None,
filters: Optional[Filters] = None,
model: Optional[Type[TRowModel]] = None,
) -> Page[TRowModel]:
async with self.connect() as conn:
return await conn.fetch_page(query, where, values, filters, model)
async def execute(self, query: str, values: tuple = ()):
async with self.connect() as conn:
return await conn.execute(query, values)
@ -229,6 +309,8 @@ class Operator(Enum):
LT = "lt"
EQ = "eq"
NE = "ne"
GE = "ge"
LE = "le"
INCLUDE = "in"
EXCLUDE = "ex"
@ -246,21 +328,45 @@ class Operator(Enum):
return ">"
elif self == Operator.LT:
return "<"
elif self == Operator.GE:
return ">="
elif self == Operator.LE:
return "<="
else:
raise ValueError("Unknown SQL Operator")
class FromRowModel(BaseModel):
@classmethod
def from_row(cls, row: Row):
return cls(**dict(row))
class FilterModel(BaseModel):
__search_fields__: List[str] = []
T = TypeVar("T")
TModel = TypeVar("TModel", bound=BaseModel)
TRowModel = TypeVar("TRowModel", bound=FromRowModel)
TFilterModel = TypeVar("TFilterModel", bound=FilterModel)
class Filter(BaseModel, Generic[TModel]):
class Page(BaseModel, Generic[T]):
data: list[T]
total: int
class Filter(BaseModel, Generic[TFilterModel]):
field: str
nested: Optional[list[str]]
nested: Optional[List[str]]
op: Operator = Operator.EQ
values: list[Any]
model: Optional[Type[TFilterModel]]
@classmethod
def parse_query(cls, key: str, raw_values: list[Any], model: Type[TModel]):
def parse_query(cls, key: str, raw_values: list[Any], model: Type[TFilterModel]):
# Key format:
# key[operator]
# e.g. name[eq]
@ -300,7 +406,7 @@ class Filter(BaseModel, Generic[TModel]):
else:
raise ValueError("Unknown filter field")
return cls(field=field, op=op, nested=nested, values=values)
return cls(field=field, op=op, nested=nested, values=values, model=model)
@property
def statement(self):
@ -308,18 +414,29 @@ class Filter(BaseModel, Generic[TModel]):
if self.nested:
for name in self.nested:
accessor = f"({accessor} ->> '{name}')"
if self.model and self.model.__fields__[self.field].type_ == datetime.datetime:
placeholder = Compat.timestamp_placeholder
else:
placeholder = "?"
if self.op in (Operator.INCLUDE, Operator.EXCLUDE):
placeholders = ", ".join(["?"] * len(self.values))
placeholders = ", ".join([placeholder] * len(self.values))
stmt = [f"{accessor} {self.op.as_sql} ({placeholders})"]
else:
stmt = [f"{accessor} {self.op.as_sql} ?"] * len(self.values)
stmt = [f"{accessor} {self.op.as_sql} {placeholder}"] * len(self.values)
return " OR ".join(stmt)
class Filters(BaseModel, Generic[TModel]):
filters: List[Filter[TModel]] = []
limit: Optional[int]
offset: Optional[int]
class Filters(BaseModel, Generic[TFilterModel]):
filters: List[Filter[TFilterModel]] = []
search: Optional[str] = None
offset: Optional[int] = None
limit: Optional[int] = None
sortby: Optional[str] = None
direction: Optional[Literal["asc", "desc"]] = None
model: Optional[Type[TFilterModel]] = None
def pagination(self) -> str:
stmt = ""
@ -329,16 +446,36 @@ class Filters(BaseModel, Generic[TModel]):
stmt += f"OFFSET {self.offset}"
return stmt
def where(self, where_stmts: List[str]) -> str:
def where(self, where_stmts: Optional[List[str]] = None) -> str:
if not where_stmts:
where_stmts = []
if self.filters:
for filter in self.filters:
where_stmts.append(filter.statement)
if self.search and self.model:
if DB_TYPE == POSTGRES:
where_stmts.append(
f"lower(concat({f', '.join(self.model.__search_fields__)})) LIKE ?"
)
elif DB_TYPE == SQLITE:
where_stmts.append(
f"lower({'||'.join(self.model.__search_fields__)}) LIKE ?"
)
if where_stmts:
return "WHERE " + " AND ".join(where_stmts)
return ""
def values(self, values: List[str]) -> Tuple:
def order_by(self) -> str:
if self.sortby:
return f"ORDER BY {self.sortby} {self.direction or 'asc'}"
return ""
def values(self, values: Optional[List[str]] = None) -> tuple:
if not values:
values = []
if self.filters:
for filter in self.filters:
values.extend(filter.values)
if self.search and self.model:
values.append(f"%{self.search}%")
return tuple(values)

View File

@ -1,16 +1,16 @@
from http import HTTPStatus
from typing import Optional, Type
from typing import Literal, Optional, Type
from fastapi import HTTPException, Request, Security, status
from fastapi import Query, Request, Security, status
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security import APIKeyHeader, APIKeyQuery
from fastapi.security.base import SecurityBase
from pydantic import BaseModel
from pydantic.types import UUID4
from lnbits.core.crud import get_user, get_wallet_for_key
from lnbits.core.models import User, Wallet
from lnbits.db import Filter, Filters
from lnbits.db import Filter, Filters, TFilterModel
from lnbits.requestvars import g
from lnbits.settings import settings
@ -185,7 +185,6 @@ async def require_admin_key(
api_key_header: str = Security(api_key_header),
api_key_query: str = Security(api_key_query),
):
token = api_key_header or api_key_query
if not token:
@ -211,7 +210,6 @@ async def require_invoice_key(
api_key_header: str = Security(api_key_header),
api_key_query: str = Security(api_key_query),
):
token = api_key_header or api_key_query
if not token:
@ -279,14 +277,19 @@ async def check_super_user(usr: UUID4) -> User:
return user
def parse_filters(model: Type[BaseModel]):
def parse_filters(model: Type[TFilterModel]):
"""
Parses the query params as filters.
:param model: model used for validation of filter values
"""
def dependency(
request: Request, limit: Optional[int] = None, offset: Optional[int] = None
request: Request,
limit: Optional[int] = None,
offset: Optional[int] = None,
sortby: Optional[str] = None,
direction: Optional[Literal["asc", "desc"]] = None,
search: Optional[str] = Query(None, description="Text based search"),
):
params = request.query_params
filters = []
@ -300,6 +303,10 @@ def parse_filters(model: Type[BaseModel]):
filters=filters,
limit=limit,
offset=offset,
sortby=sortby,
direction=direction,
search=search,
model=model,
)
return dependency

View File

@ -4,7 +4,6 @@ from typing import Any, List, Optional, Type
import jinja2
import shortuuid
from pydantic import BaseModel
from pydantic.schema import (
field_schema,
get_flat_models_from_fields,
@ -15,6 +14,7 @@ from lnbits.jinja2_templating import Jinja2Templates
from lnbits.requestvars import g
from lnbits.settings import settings
from .db import FilterModel
from .extension_manager import get_valid_extensions
@ -32,7 +32,6 @@ def url_for(endpoint: str, external: Optional[bool] = False, **params: Any) -> s
def template_renderer(additional_folders: Optional[List] = None) -> Jinja2Templates:
folders = ["lnbits/templates", "lnbits/core/templates"]
if additional_folders:
folders.extend(additional_folders)
@ -96,7 +95,7 @@ def get_current_extension_name() -> str:
return ext_name
def generate_filter_params_openapi(model: Type[BaseModel], keep_optional=False):
def generate_filter_params_openapi(model: Type[FilterModel], keep_optional=False):
"""
Generate openapi documentation for Filters. This is intended to be used along parse_filters (see example)
:param model: Filter model
@ -117,6 +116,11 @@ def generate_filter_params_openapi(model: Type[BaseModel], keep_optional=False):
description = "Supports Filtering"
if schema["type"] == "object":
description += f". Nested attributes can be filtered too, e.g. `{field.alias}.[additional].[attributes]`"
if (
hasattr(model, "__search_fields__")
and field.name in model.__search_fields__
):
description += ". Supports Search"
parameter = {
"name": field.alias,

File diff suppressed because one or more lines are too long

View File

@ -67,8 +67,13 @@ window.LNbits = {
getWallet: function (wallet) {
return this.request('get', '/api/v1/wallet', wallet.inkey)
},
getPayments: function (wallet) {
return this.request('get', '/api/v1/payments', wallet.inkey)
getPayments: function (wallet, query) {
const params = new URLSearchParams(query)
return this.request(
'get',
'/api/v1/payments/paginated?' + params,
wallet.inkey
)
},
getPayment: function (wallet, paymentHash) {
return this.request(
@ -185,7 +190,7 @@ window.LNbits = {
},
payment: function (data) {
obj = {
checking_id: data.id,
checking_id: data.checking_id,
pending: data.pending,
amount: data.amount,
fee: data.fee,

View File

@ -1,10 +1,15 @@
import asyncio
import hashlib
from time import time
import pytest
from lnbits import bolt11
from lnbits.core.models import Payment
from lnbits.core.views.api import api_payment
from lnbits.db import DB_TYPE, SQLITE
from lnbits.settings import get_wallet_class
from tests.conftest import CreateInvoiceData, api_payments_create_invoice
from ...helpers import get_random_invoice_data, is_fake
@ -181,6 +186,66 @@ async def test_pay_invoice_adminkey(client, invoice, adminkey_headers_from):
assert response.status_code > 300 # should fail
@pytest.mark.asyncio
async def test_get_payments(client, from_wallet, adminkey_headers_from):
# Because sqlite only stores timestamps with milliseconds we have to wait a second to ensure
# a different timestamp than previous invoices
# due to this limitation both payments (normal and paginated) are tested at the same time as they are almost
# identical anyways
if DB_TYPE == SQLITE:
await asyncio.sleep(1)
ts = time()
fake_data = [
CreateInvoiceData(amount=10, memo="aaaa"),
CreateInvoiceData(amount=100, memo="bbbb"),
CreateInvoiceData(amount=1000, memo="aabb"),
]
for invoice in fake_data:
await api_payments_create_invoice(invoice, from_wallet)
async def get_payments(params: dict):
params["time[ge]"] = ts
response = await client.get(
"/api/v1/payments",
params=params,
headers=adminkey_headers_from,
)
assert response.status_code == 200
return [Payment(**payment) for payment in response.json()]
payments = await get_payments({"sortby": "amount", "direction": "desc", "limit": 2})
assert payments[-1].amount < payments[0].amount
assert len(payments) == 2
payments = await get_payments({"offset": 2, "limit": 2})
assert len(payments) == 1
payments = await get_payments({"sortby": "amount", "direction": "asc"})
assert payments[-1].amount > payments[0].amount
payments = await get_payments({"search": "aaa"})
assert len(payments) == 1
payments = await get_payments({"search": "aa"})
assert len(payments) == 2
# amount is in msat
payments = await get_payments({"amount[gt]": 10000})
assert len(payments) == 2
response = await client.get(
"/api/v1/payments/paginated",
params={"limit": 2, "time[ge]": ts},
headers=adminkey_headers_from,
)
assert response.status_code == 200
paginated = response.json()
assert len(paginated["data"]) == 2
assert paginated["total"] == len(fake_data)
# check POST /api/v1/payments/decode
@pytest.mark.asyncio
async def test_decode_invoice(client, invoice):