FEAT: Filters for GET requests, add it to GET /payments (#1557)
* feat filters, add them to GET payments * add limit and offset to filters (#1563) * add limit and offset to filters * move filters example to parse_filters doc string * black * add openapi docs * remove example commentC * improve typing and make nested filter possible in openapi * typo in fn name * readd Type --------- Co-authored-by: jackstar12 <62219658+jackstar12@users.noreply.github.com> Co-authored-by: calle <93376500+callebtc@users.noreply.github.com>
This commit is contained in:
parent
fe9e821af5
commit
8ce84ce592
|
@ -7,7 +7,7 @@ from uuid import uuid4
|
|||
import shortuuid
|
||||
|
||||
from lnbits import bolt11
|
||||
from lnbits.db import COCKROACH, POSTGRES, Connection
|
||||
from lnbits.db import COCKROACH, POSTGRES, Connection, Filters
|
||||
from lnbits.extension_manager import InstallableExtension
|
||||
from lnbits.settings import AdminSettings, EditableSettings, SuperSettings, settings
|
||||
|
||||
|
@ -347,8 +347,7 @@ async def get_payments(
|
|||
incoming: bool = False,
|
||||
since: Optional[int] = None,
|
||||
exclude_uncheckable: bool = False,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
filters: Optional[Filters[Payment]] = None,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> List[Payment]:
|
||||
"""
|
||||
|
@ -393,29 +392,20 @@ async def get_payments(
|
|||
clause.append("checking_id NOT LIKE 'temp_%'")
|
||||
clause.append("checking_id NOT LIKE 'internal_%'")
|
||||
|
||||
limit_clause = f"LIMIT {limit}" if type(limit) == int and limit > 0 else ""
|
||||
offset_clause = f"OFFSET {offset}" if type(offset) == int and offset > 0 else ""
|
||||
# combine limit and offset clauses
|
||||
limit_offset_clause = (
|
||||
f"{limit_clause} {offset_clause}"
|
||||
if limit_clause and offset_clause
|
||||
else limit_clause or offset_clause
|
||||
)
|
||||
|
||||
where = ""
|
||||
if clause:
|
||||
where = f"WHERE {' AND '.join(clause)}"
|
||||
if not filters:
|
||||
filters = Filters()
|
||||
|
||||
rows = await (conn or db).fetchall(
|
||||
f"""
|
||||
SELECT *
|
||||
FROM apipayments
|
||||
{where}
|
||||
{filters.where(clause)}
|
||||
ORDER BY time DESC
|
||||
{limit_offset_clause}
|
||||
{filters.pagination()}
|
||||
""",
|
||||
tuple(args),
|
||||
filters.values(args),
|
||||
)
|
||||
|
||||
return [Payment.from_row(row) for row in rows]
|
||||
|
||||
|
||||
|
|
|
@ -34,10 +34,12 @@ from lnbits.core.helpers import (
|
|||
stop_extension_background_work,
|
||||
)
|
||||
from lnbits.core.models import Payment, User, Wallet
|
||||
from lnbits.db import Filters
|
||||
from lnbits.decorators import (
|
||||
WalletTypeInfo,
|
||||
check_admin,
|
||||
get_key_type,
|
||||
parse_filters,
|
||||
require_admin_key,
|
||||
require_invoice_key,
|
||||
)
|
||||
|
@ -48,7 +50,7 @@ from lnbits.extension_manager import (
|
|||
InstallableExtension,
|
||||
get_valid_extensions,
|
||||
)
|
||||
from lnbits.helpers import url_for
|
||||
from lnbits.helpers import generate_filter_params_openapi, url_for
|
||||
from lnbits.settings import get_wallet_class, settings
|
||||
from lnbits.utils.exchange_rates import (
|
||||
currencies,
|
||||
|
@ -114,18 +116,23 @@ async def api_update_wallet(
|
|||
}
|
||||
|
||||
|
||||
@core_app.get("/api/v1/payments")
|
||||
@core_app.get(
|
||||
"/api/v1/payments",
|
||||
name="Payment List",
|
||||
summary="get list of payments",
|
||||
response_description="list of payments",
|
||||
response_model=List[Payment],
|
||||
openapi_extra=generate_filter_params_openapi(Payment),
|
||||
)
|
||||
async def api_payments(
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
wallet: WalletTypeInfo = Depends(get_key_type),
|
||||
filters: Filters = Depends(parse_filters(Payment)),
|
||||
):
|
||||
pendingPayments = await get_payments(
|
||||
wallet_id=wallet.wallet.id,
|
||||
pending=True,
|
||||
exclude_uncheckable=True,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
filters=filters,
|
||||
)
|
||||
for payment in pendingPayments:
|
||||
await check_transaction_status(
|
||||
|
@ -135,8 +142,7 @@ async def api_payments(
|
|||
wallet_id=wallet.wallet.id,
|
||||
pending=True,
|
||||
complete=True,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
|
||||
|
|
124
lnbits/db.py
124
lnbits/db.py
|
@ -4,9 +4,11 @@ import os
|
|||
import re
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
from typing import Any, Generic, List, Optional, Tuple, Type, TypeVar
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy_aio.base import AsyncConnection
|
||||
from sqlalchemy_aio.strategy import ASYNCIO_STRATEGY
|
||||
|
@ -224,3 +226,123 @@ class Database(Compat):
|
|||
@asynccontextmanager
|
||||
async def reuse_conn(self, conn: Connection):
|
||||
yield conn
|
||||
|
||||
|
||||
class Operator(Enum):
|
||||
GT = "gt"
|
||||
LT = "lt"
|
||||
EQ = "eq"
|
||||
NE = "ne"
|
||||
INCLUDE = "in"
|
||||
EXCLUDE = "ex"
|
||||
|
||||
@property
|
||||
def as_sql(self):
|
||||
if self == Operator.EQ:
|
||||
return "="
|
||||
elif self == Operator.NE:
|
||||
return "!="
|
||||
elif self == Operator.INCLUDE:
|
||||
return "IN"
|
||||
elif self == Operator.EXCLUDE:
|
||||
return "NOT IN"
|
||||
elif self == Operator.GT:
|
||||
return ">"
|
||||
elif self == Operator.LT:
|
||||
return "<"
|
||||
else:
|
||||
raise ValueError("Unknown SQL Operator")
|
||||
|
||||
|
||||
TModel = TypeVar("TModel", bound=BaseModel)
|
||||
|
||||
|
||||
class Filter(BaseModel, Generic[TModel]):
|
||||
field: str
|
||||
nested: Optional[list[str]]
|
||||
op: Operator = Operator.EQ
|
||||
values: list[Any]
|
||||
|
||||
@classmethod
|
||||
def parse_query(cls, key: str, raw_values: list[Any], model: Type[TModel]):
|
||||
# Key format:
|
||||
# key[operator]
|
||||
# e.g. name[eq]
|
||||
if key.endswith("]"):
|
||||
split = key[:-1].split("[")
|
||||
if len(split) != 2:
|
||||
raise ValueError("Invalid key")
|
||||
field_names = split[0].split(".")
|
||||
op = Operator(split[1])
|
||||
else:
|
||||
field_names = key.split(".")
|
||||
op = Operator("eq")
|
||||
|
||||
field = field_names[0]
|
||||
nested = field_names[1:]
|
||||
|
||||
if field in model.__fields__:
|
||||
compare_field = model.__fields__[field]
|
||||
values = []
|
||||
for raw_value in raw_values:
|
||||
# If there is a nested field, pydantic expects a dict, so the raw value is turned into a dict before
|
||||
# and the converted value is extracted afterwards
|
||||
for name in reversed(nested):
|
||||
raw_value = {name: raw_value}
|
||||
|
||||
validated, errors = compare_field.validate(raw_value, {}, loc="none")
|
||||
if errors:
|
||||
raise ValidationError(errors=[errors], model=model)
|
||||
|
||||
for name in nested:
|
||||
if isinstance(validated, dict):
|
||||
validated = validated[name]
|
||||
else:
|
||||
validated = getattr(validated, name)
|
||||
|
||||
values.append(validated)
|
||||
else:
|
||||
raise ValueError("Unknown filter field")
|
||||
|
||||
return cls(field=field, op=op, nested=nested, values=values)
|
||||
|
||||
@property
|
||||
def statement(self):
|
||||
accessor = self.field
|
||||
if self.nested:
|
||||
for name in self.nested:
|
||||
accessor = f"({accessor} ->> '{name}')"
|
||||
if self.op in (Operator.INCLUDE, Operator.EXCLUDE):
|
||||
placeholders = ", ".join(["?"] * len(self.values))
|
||||
stmt = [f"{accessor} {self.op.as_sql} ({placeholders})"]
|
||||
else:
|
||||
stmt = [f"{accessor} {self.op.as_sql} ?"] * len(self.values)
|
||||
return " OR ".join(stmt)
|
||||
|
||||
|
||||
class Filters(BaseModel, Generic[TModel]):
|
||||
filters: List[Filter[TModel]] = []
|
||||
limit: Optional[int]
|
||||
offset: Optional[int]
|
||||
|
||||
def pagination(self) -> str:
|
||||
stmt = ""
|
||||
if self.limit:
|
||||
stmt += f"LIMIT {self.limit} "
|
||||
if self.offset:
|
||||
stmt += f"OFFSET {self.offset}"
|
||||
return stmt
|
||||
|
||||
def where(self, where_stmts: List[str]) -> str:
|
||||
if self.filters:
|
||||
for filter in self.filters:
|
||||
where_stmts.append(filter.statement)
|
||||
if where_stmts:
|
||||
return "WHERE " + " AND ".join(where_stmts)
|
||||
return ""
|
||||
|
||||
def values(self, values: List[str]) -> Tuple:
|
||||
if self.filters:
|
||||
for filter in self.filters:
|
||||
values.extend(filter.values)
|
||||
return tuple(values)
|
||||
|
|
|
@ -1,15 +1,18 @@
|
|||
from http import HTTPStatus
|
||||
from typing import Optional, Type
|
||||
|
||||
from fastapi import Security, status
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.openapi.models import APIKey, APIKeyIn
|
||||
from fastapi.security.api_key import APIKeyHeader, APIKeyQuery
|
||||
from fastapi.security.base import SecurityBase
|
||||
from pydantic import BaseModel
|
||||
from pydantic.types import UUID4
|
||||
from starlette.requests import Request
|
||||
|
||||
from lnbits.core.crud import get_user, get_wallet_for_key
|
||||
from lnbits.core.models import User, Wallet
|
||||
from lnbits.db import Filter, Filters
|
||||
from lnbits.requestvars import g
|
||||
from lnbits.settings import settings
|
||||
|
||||
|
@ -266,3 +269,29 @@ async def check_super_user(usr: UUID4) -> User:
|
|||
detail="User not authorized. No super user privileges.",
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
def parse_filters(model: Type[BaseModel]):
|
||||
"""
|
||||
Parses the query params as filters.
|
||||
:param model: model used for validation of filter values
|
||||
"""
|
||||
|
||||
def dependency(
|
||||
request: Request, limit: Optional[int] = None, offset: Optional[int] = None
|
||||
):
|
||||
params = request.query_params
|
||||
filters = []
|
||||
for key in params.keys():
|
||||
try:
|
||||
filters.append(Filter.parse_query(key, params.getlist(key), model))
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return Filters(
|
||||
filters=filters,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return dependency
|
||||
|
|
|
@ -1,7 +1,13 @@
|
|||
from typing import Any, List, Optional
|
||||
from typing import Any, List, Optional, Type
|
||||
|
||||
import jinja2
|
||||
import shortuuid # type: ignore
|
||||
import shortuuid
|
||||
from pydantic import BaseModel
|
||||
from pydantic.schema import (
|
||||
field_schema,
|
||||
get_flat_models_from_fields,
|
||||
get_model_name_map,
|
||||
)
|
||||
|
||||
from lnbits.jinja2_templating import Jinja2Templates
|
||||
from lnbits.requestvars import g
|
||||
|
@ -102,3 +108,39 @@ def get_current_extension_name() -> str:
|
|||
except:
|
||||
ext_name = extension_director_name
|
||||
return ext_name
|
||||
|
||||
|
||||
def generate_filter_params_openapi(model: Type[BaseModel], keep_optional=False):
|
||||
"""
|
||||
Generate openapi documentation for Filters. This is intended to be used along parse_filters (see example)
|
||||
:param model: Filter model
|
||||
:param keep_optional: If false, all parameters will be optional, otherwise inferred from model
|
||||
"""
|
||||
fields = list(model.__fields__.values())
|
||||
models = get_flat_models_from_fields(fields, set())
|
||||
namemap = get_model_name_map(models)
|
||||
params = []
|
||||
for field in fields:
|
||||
schema, definitions, _ = field_schema(field, model_name_map=namemap)
|
||||
|
||||
# Support nested definition
|
||||
if "$ref" in schema:
|
||||
name = schema["$ref"].split("/")[-1]
|
||||
schema = definitions[name]
|
||||
|
||||
description = "Supports Filtering"
|
||||
if schema["type"] == "object":
|
||||
description += f". Nested attributes can be filtered too, e.g. `{field.alias}.[additional].[attributes]`"
|
||||
|
||||
parameter = {
|
||||
"name": field.alias,
|
||||
"in": "query",
|
||||
"required": field.required if keep_optional else False,
|
||||
"schema": schema,
|
||||
"description": description,
|
||||
}
|
||||
params.append(parameter)
|
||||
|
||||
return {
|
||||
"parameters": params,
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user