FEAT: Filters for GET requests, add it to GET /payments (#1557)

* feat filters, add them to GET payments

* add limit and offset to filters (#1563)

* add limit and offset to filters
* move filters example to parse_filters doc string

* black

* add openapi docs

* remove example commentC

* improve typing and make nested filter possible in openapi

* typo in fn name

* readd Type

---------

Co-authored-by: jackstar12 <62219658+jackstar12@users.noreply.github.com>
Co-authored-by: calle <93376500+callebtc@users.noreply.github.com>
This commit is contained in:
dni ⚡ 2023-04-03 14:55:49 +02:00 committed by GitHub
parent fe9e821af5
commit 8ce84ce592
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 218 additions and 29 deletions

View File

@ -7,7 +7,7 @@ from uuid import uuid4
import shortuuid
from lnbits import bolt11
from lnbits.db import COCKROACH, POSTGRES, Connection
from lnbits.db import COCKROACH, POSTGRES, Connection, Filters
from lnbits.extension_manager import InstallableExtension
from lnbits.settings import AdminSettings, EditableSettings, SuperSettings, settings
@ -347,8 +347,7 @@ async def get_payments(
incoming: bool = False,
since: Optional[int] = None,
exclude_uncheckable: bool = False,
limit: Optional[int] = None,
offset: Optional[int] = None,
filters: Optional[Filters[Payment]] = None,
conn: Optional[Connection] = None,
) -> List[Payment]:
"""
@ -393,29 +392,20 @@ async def get_payments(
clause.append("checking_id NOT LIKE 'temp_%'")
clause.append("checking_id NOT LIKE 'internal_%'")
limit_clause = f"LIMIT {limit}" if type(limit) == int and limit > 0 else ""
offset_clause = f"OFFSET {offset}" if type(offset) == int and offset > 0 else ""
# combine limit and offset clauses
limit_offset_clause = (
f"{limit_clause} {offset_clause}"
if limit_clause and offset_clause
else limit_clause or offset_clause
)
where = ""
if clause:
where = f"WHERE {' AND '.join(clause)}"
if not filters:
filters = Filters()
rows = await (conn or db).fetchall(
f"""
SELECT *
FROM apipayments
{where}
{filters.where(clause)}
ORDER BY time DESC
{limit_offset_clause}
{filters.pagination()}
""",
tuple(args),
filters.values(args),
)
return [Payment.from_row(row) for row in rows]

View File

@ -34,10 +34,12 @@ from lnbits.core.helpers import (
stop_extension_background_work,
)
from lnbits.core.models import Payment, User, Wallet
from lnbits.db import Filters
from lnbits.decorators import (
WalletTypeInfo,
check_admin,
get_key_type,
parse_filters,
require_admin_key,
require_invoice_key,
)
@ -48,7 +50,7 @@ from lnbits.extension_manager import (
InstallableExtension,
get_valid_extensions,
)
from lnbits.helpers import url_for
from lnbits.helpers import generate_filter_params_openapi, url_for
from lnbits.settings import get_wallet_class, settings
from lnbits.utils.exchange_rates import (
currencies,
@ -114,18 +116,23 @@ async def api_update_wallet(
}
@core_app.get("/api/v1/payments")
@core_app.get(
"/api/v1/payments",
name="Payment List",
summary="get list of payments",
response_description="list of payments",
response_model=List[Payment],
openapi_extra=generate_filter_params_openapi(Payment),
)
async def api_payments(
limit: Optional[int] = None,
offset: Optional[int] = None,
wallet: WalletTypeInfo = Depends(get_key_type),
filters: Filters = Depends(parse_filters(Payment)),
):
pendingPayments = await get_payments(
wallet_id=wallet.wallet.id,
pending=True,
exclude_uncheckable=True,
limit=limit,
offset=offset,
filters=filters,
)
for payment in pendingPayments:
await check_transaction_status(
@ -135,8 +142,7 @@ async def api_payments(
wallet_id=wallet.wallet.id,
pending=True,
complete=True,
limit=limit,
offset=offset,
filters=filters,
)

View File

@ -4,9 +4,11 @@ import os
import re
import time
from contextlib import asynccontextmanager
from typing import Optional
from enum import Enum
from typing import Any, Generic, List, Optional, Tuple, Type, TypeVar
from loguru import logger
from pydantic import BaseModel, ValidationError
from sqlalchemy import create_engine
from sqlalchemy_aio.base import AsyncConnection
from sqlalchemy_aio.strategy import ASYNCIO_STRATEGY
@ -224,3 +226,123 @@ class Database(Compat):
@asynccontextmanager
async def reuse_conn(self, conn: Connection):
yield conn
class Operator(Enum):
GT = "gt"
LT = "lt"
EQ = "eq"
NE = "ne"
INCLUDE = "in"
EXCLUDE = "ex"
@property
def as_sql(self):
if self == Operator.EQ:
return "="
elif self == Operator.NE:
return "!="
elif self == Operator.INCLUDE:
return "IN"
elif self == Operator.EXCLUDE:
return "NOT IN"
elif self == Operator.GT:
return ">"
elif self == Operator.LT:
return "<"
else:
raise ValueError("Unknown SQL Operator")
TModel = TypeVar("TModel", bound=BaseModel)
class Filter(BaseModel, Generic[TModel]):
field: str
nested: Optional[list[str]]
op: Operator = Operator.EQ
values: list[Any]
@classmethod
def parse_query(cls, key: str, raw_values: list[Any], model: Type[TModel]):
# Key format:
# key[operator]
# e.g. name[eq]
if key.endswith("]"):
split = key[:-1].split("[")
if len(split) != 2:
raise ValueError("Invalid key")
field_names = split[0].split(".")
op = Operator(split[1])
else:
field_names = key.split(".")
op = Operator("eq")
field = field_names[0]
nested = field_names[1:]
if field in model.__fields__:
compare_field = model.__fields__[field]
values = []
for raw_value in raw_values:
# If there is a nested field, pydantic expects a dict, so the raw value is turned into a dict before
# and the converted value is extracted afterwards
for name in reversed(nested):
raw_value = {name: raw_value}
validated, errors = compare_field.validate(raw_value, {}, loc="none")
if errors:
raise ValidationError(errors=[errors], model=model)
for name in nested:
if isinstance(validated, dict):
validated = validated[name]
else:
validated = getattr(validated, name)
values.append(validated)
else:
raise ValueError("Unknown filter field")
return cls(field=field, op=op, nested=nested, values=values)
@property
def statement(self):
accessor = self.field
if self.nested:
for name in self.nested:
accessor = f"({accessor} ->> '{name}')"
if self.op in (Operator.INCLUDE, Operator.EXCLUDE):
placeholders = ", ".join(["?"] * len(self.values))
stmt = [f"{accessor} {self.op.as_sql} ({placeholders})"]
else:
stmt = [f"{accessor} {self.op.as_sql} ?"] * len(self.values)
return " OR ".join(stmt)
class Filters(BaseModel, Generic[TModel]):
filters: List[Filter[TModel]] = []
limit: Optional[int]
offset: Optional[int]
def pagination(self) -> str:
stmt = ""
if self.limit:
stmt += f"LIMIT {self.limit} "
if self.offset:
stmt += f"OFFSET {self.offset}"
return stmt
def where(self, where_stmts: List[str]) -> str:
if self.filters:
for filter in self.filters:
where_stmts.append(filter.statement)
if where_stmts:
return "WHERE " + " AND ".join(where_stmts)
return ""
def values(self, values: List[str]) -> Tuple:
if self.filters:
for filter in self.filters:
values.extend(filter.values)
return tuple(values)

View File

@ -1,15 +1,18 @@
from http import HTTPStatus
from typing import Optional, Type
from fastapi import Security, status
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.api_key import APIKeyHeader, APIKeyQuery
from fastapi.security.base import SecurityBase
from pydantic import BaseModel
from pydantic.types import UUID4
from starlette.requests import Request
from lnbits.core.crud import get_user, get_wallet_for_key
from lnbits.core.models import User, Wallet
from lnbits.db import Filter, Filters
from lnbits.requestvars import g
from lnbits.settings import settings
@ -266,3 +269,29 @@ async def check_super_user(usr: UUID4) -> User:
detail="User not authorized. No super user privileges.",
)
return user
def parse_filters(model: Type[BaseModel]):
"""
Parses the query params as filters.
:param model: model used for validation of filter values
"""
def dependency(
request: Request, limit: Optional[int] = None, offset: Optional[int] = None
):
params = request.query_params
filters = []
for key in params.keys():
try:
filters.append(Filter.parse_query(key, params.getlist(key), model))
except ValueError:
continue
return Filters(
filters=filters,
limit=limit,
offset=offset,
)
return dependency

View File

@ -1,7 +1,13 @@
from typing import Any, List, Optional
from typing import Any, List, Optional, Type
import jinja2
import shortuuid # type: ignore
import shortuuid
from pydantic import BaseModel
from pydantic.schema import (
field_schema,
get_flat_models_from_fields,
get_model_name_map,
)
from lnbits.jinja2_templating import Jinja2Templates
from lnbits.requestvars import g
@ -102,3 +108,39 @@ def get_current_extension_name() -> str:
except:
ext_name = extension_director_name
return ext_name
def generate_filter_params_openapi(model: Type[BaseModel], keep_optional=False):
"""
Generate openapi documentation for Filters. This is intended to be used along parse_filters (see example)
:param model: Filter model
:param keep_optional: If false, all parameters will be optional, otherwise inferred from model
"""
fields = list(model.__fields__.values())
models = get_flat_models_from_fields(fields, set())
namemap = get_model_name_map(models)
params = []
for field in fields:
schema, definitions, _ = field_schema(field, model_name_map=namemap)
# Support nested definition
if "$ref" in schema:
name = schema["$ref"].split("/")[-1]
schema = definitions[name]
description = "Supports Filtering"
if schema["type"] == "object":
description += f". Nested attributes can be filtered too, e.g. `{field.alias}.[additional].[attributes]`"
parameter = {
"name": field.alias,
"in": "query",
"required": field.required if keep_optional else False,
"schema": schema,
"description": description,
}
params.append(parameter)
return {
"parameters": params,
}