register channel listeners instead of callbacks.

makes for a little less black magic and more reasonable use of nurseries
and less unnecessary pseudo-requests.
This commit is contained in:
fiatjaf 2020-10-06 01:50:55 -03:00
parent 95e8573ff8
commit c5352c0309
8 changed files with 65 additions and 30 deletions

View File

@ -12,7 +12,7 @@ from .core import core_app
from .db import open_db, open_ext_db
from .helpers import get_valid_extensions, get_js_vendored, get_css_vendored, url_for_vendored
from .proxy_fix import ASGIProxyFix
from .tasks import invoice_listener, webhook_handler, grab_app_for_later
from .tasks import run_deferred_async, invoice_listener, webhook_handler, grab_app_for_later
secure_headers = SecureHeaders(hsts=False)
@ -111,6 +111,8 @@ def register_async_tasks(app):
@app.before_serving
async def listeners():
run_deferred_async(app.nursery)
app.nursery.start_soon(invoice_listener)
print("started global invoice_listener.")

View File

@ -8,8 +8,8 @@ core_app: Blueprint = Blueprint(
from .views.api import * # noqa
from .views.generic import * # noqa
from .tasks import on_invoice_paid
from .tasks import register_listeners
from lnbits.tasks import register_invoice_listener
from lnbits.tasks import record_async
register_invoice_listener("core", on_invoice_paid)
core_app.record(record_async(register_listeners))

View File

@ -70,6 +70,7 @@ class Payment(NamedTuple):
preimage: str
payment_hash: str
extra: Dict
wallet_id: str
@classmethod
def from_row(cls, row: Row):
@ -84,6 +85,7 @@ class Payment(NamedTuple):
fee=row["fee"],
memo=row["memo"],
time=row["time"],
wallet_id=row["wallet"],
)
@property

View File

@ -1,15 +1,22 @@
import trio # type: ignore
from typing import List
from .models import Payment
from lnbits.tasks import register_invoice_listener
sse_listeners: List[trio.MemorySendChannel] = []
async def on_invoice_paid(payment: Payment):
for send_channel in sse_listeners:
try:
send_channel.send_nowait(payment)
except trio.WouldBlock:
print("removing sse listener", send_channel)
sse_listeners.remove(send_channel)
async def register_listeners():
invoice_paid_chan_send, invoice_paid_chan_recv = trio.open_memory_channel(5)
register_invoice_listener(invoice_paid_chan_send)
await wait_for_paid_invoices(invoice_paid_chan_recv)
async def wait_for_paid_invoices(invoice_paid_chan: trio.MemoryReceiveChannel):
async for payment in invoice_paid_chan:
for send_channel in sse_listeners:
try:
send_channel.send_nowait(payment)
except trio.WouldBlock:
print("removing sse listener", send_channel)
sse_listeners.remove(send_channel)

View File

@ -128,6 +128,7 @@ async def api_payment(payment_hash):
@api_check_wallet_key("invoice")
async def api_payments_sse():
g.db.close()
this_wallet_id = g.wallet.id
send_payment, receive_payment = trio.open_memory_channel(0)
@ -138,7 +139,8 @@ async def api_payments_sse():
async def payment_received() -> None:
async for payment in receive_payment:
await send_event.send(("payment", payment))
if payment.wallet_id == this_wallet_id:
await send_event.send(("payment", payment))
async def repeat_keepalive():
await trio.sleep(1)
@ -160,7 +162,6 @@ async def api_payments_sse():
yield b"\n".join(message) + b"\r\n\r\n"
except trio.Cancelled:
print("canceled!")
return
response = await make_response(

View File

@ -7,8 +7,8 @@ lnurlp_ext: Blueprint = Blueprint("lnurlp", __name__, static_folder="static", te
from .views_api import * # noqa
from .views import * # noqa
from .lnurl import * # noqa
from .tasks import on_invoice_paid
from .tasks import register_listeners
from lnbits.tasks import register_invoice_listener
from lnbits.tasks import record_async
register_invoice_listener("lnurlp", on_invoice_paid)
lnurlp_ext.record(record_async(register_listeners))

View File

@ -1,10 +1,23 @@
import trio # type: ignore
import httpx
from lnbits.core.models import Payment
from lnbits.tasks import run_on_pseudo_request, register_invoice_listener
from .crud import get_pay_link_by_invoice, mark_webhook_sent
async def register_listeners():
invoice_paid_chan_send, invoice_paid_chan_recv = trio.open_memory_channel(2)
register_invoice_listener(invoice_paid_chan_send)
await wait_for_paid_invoices(invoice_paid_chan_recv)
async def wait_for_paid_invoices(invoice_paid_chan: trio.MemoryReceiveChannel):
async for payment in invoice_paid_chan:
await run_on_pseudo_request(on_invoice_paid, payment)
async def on_invoice_paid(payment: Payment) -> None:
islnurlp = "lnurlp" == payment.extra.get("tag")
if islnurlp:

View File

@ -1,14 +1,13 @@
import trio # type: ignore
from http import HTTPStatus
from typing import Optional, Tuple, List, Callable, Awaitable
from typing import Optional, List, Callable
from quart import Request, g
from quart_trio import QuartTrio
from werkzeug.datastructures import Headers
from lnbits.db import open_db, open_ext_db
from lnbits.db import open_db
from lnbits.settings import WALLET
from lnbits.core.models import Payment
from lnbits.core.crud import get_standalone_payment
main_app: Optional[QuartTrio] = None
@ -19,6 +18,21 @@ def grab_app_for_later(app: QuartTrio):
main_app = app
deferred_async: List[Callable] = []
def record_async(func: Callable) -> Callable:
def recorder(state):
deferred_async.append(func)
return recorder
def run_deferred_async(nursery):
for func in deferred_async:
nursery.start_soon(func)
async def send_push_promise(a, b) -> None:
pass
@ -45,16 +59,16 @@ async def run_on_pseudo_request(func: Callable, *args):
nursery.start_soon(run)
invoice_listeners: List[Tuple[str, Callable[[Payment], Awaitable[None]]]] = []
invoice_listeners: List[trio.MemorySendChannel] = []
def register_invoice_listener(ext_name: str, cb: Callable[[Payment], Awaitable[None]]):
def register_invoice_listener(send_chan: trio.MemorySendChannel):
"""
A method intended for extensions to call when they want to be notified about
new invoice payments incoming.
"""
print(f"registering {ext_name} invoice_listener callback: {cb}")
invoice_listeners.append((ext_name, cb))
print(f"registering invoice_listener: {send_chan}")
invoice_listeners.append(send_chan)
async def webhook_handler():
@ -73,9 +87,5 @@ async def invoice_callback_dispatcher(checking_id: str):
payment = get_standalone_payment(checking_id)
if payment and payment.is_in:
payment.set_pending(False)
for ext_name, cb in invoice_listeners:
if ext_name == "core":
await cb(payment)
else:
with open_ext_db(ext_name) as g.ext_db: # type: ignore
await cb(payment)
for send_chan in invoice_listeners:
await send_chan.send(payment)