Merge pull request #469 from lnbits/StepansFix

Makes lnurlpos work with latest lnurlpos.ino
This commit is contained in:
Arc 2021-12-21 17:15:19 +00:00 committed by GitHub
commit 2d41f0c4ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3,6 +3,12 @@ import hashlib
from http import HTTPStatus
from typing import Optional
from embit import bech32
from embit import compact
import base64
from io import BytesIO
import hmac
from fastapi import Request
from fastapi.param_functions import Query
from starlette.exceptions import HTTPException
@ -18,39 +24,73 @@ from .crud import (
update_lnurlpospayment,
)
def bech32_decode(bech):
"""tweaked version of bech32_decode that ignores length limitations"""
if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or
(bech.lower() != bech and bech.upper() != bech)):
return
bech = bech.lower()
pos = bech.rfind('1')
if pos < 1 or pos + 7 > len(bech):
return
if not all(x in bech32.CHARSET for x in bech[pos+1:]):
return
hrp = bech[:pos]
data = [bech32.CHARSET.find(x) for x in bech[pos+1:]]
encoding = bech32.bech32_verify_checksum(hrp, data)
if encoding is None:
return
return bytes(bech32.convertbits(data[:-6], 5, 8, False))
def xor_decrypt(key, blob):
s = BytesIO(blob)
variant = s.read(1)[0]
if variant != 1:
raise RuntimeError("Not implemented")
# reading nonce
l = s.read(1)[0]
nonce = s.read(l)
if len(nonce) != l:
raise RuntimeError("Missing nonce bytes")
if l < 8:
raise RuntimeError("Nonce is too short")
# reading payload
l = s.read(1)[0]
payload = s.read(l)
if len(payload) > 32:
raise RuntimeError("Payload is too long for this encryption method")
if len(payload) != l:
raise RuntimeError("Missing payload bytes")
hmacval = s.read()
expected = hmac.new(key, b"Data:" + blob[:-len(hmacval)], digestmod="sha256").digest()
if len(hmacval) < 8:
raise RuntimeError("HMAC is too short")
if hmacval != expected[:len(hmacval)]:
raise RuntimeError("HMAC is invalid")
secret = hmac.new(key, b"Round secret:" + nonce, digestmod="sha256").digest()
payload = bytearray(payload)
for i in range(len(payload)):
payload[i] = payload[i] ^ secret[i]
s = BytesIO(payload)
pin = compact.read_from(s)
amount_in_cent = compact.read_from(s)
return pin, amount_in_cent
@lnurlpos_ext.get(
"/api/v1/lnurl/{nonce}/{payload}/{pos_id}",
"/api/v1/lnurl/{pos_id}",
status_code=HTTPStatus.OK,
name="lnurlpos.lnurl_response",
name="lnurlpos.lnurl_v1_params",
)
async def lnurl_response(
request: Request,
nonce: str = Query(None),
pos_id: str = Query(None),
payload: str = Query(None),
):
return await handle_lnurl_firstrequest(
request, pos_id, nonce, payload, verify_checksum=False
)
@lnurlpos_ext.get(
"/api/v2/lnurl/{pos_id}",
status_code=HTTPStatus.OK,
name="lnurlpos.lnurl_v2_params",
)
async def lnurl_v2_params(
async def lnurl_v1_params(
request: Request,
pos_id: str = Query(None),
n: str = Query(None),
p: str = Query(None),
):
return await handle_lnurl_firstrequest(request, pos_id, n, p, verify_checksum=True)
return await handle_lnurl_firstrequest(request, pos_id, p)
async def handle_lnurl_firstrequest(
request: Request, pos_id: str, nonce: str, payload: str, verify_checksum: bool
request: Request, pos_id: str, payload: str
):
pos = await get_lnurlpos(pos_id)
if not pos:
@ -59,53 +99,13 @@ async def handle_lnurl_firstrequest(
"reason": f"lnurlpos {pos_id} not found on this server",
}
try:
nonceb = bytes.fromhex(nonce)
except ValueError:
try:
nonce += "=" * ((4 - len(nonce) % 4) % 4)
nonceb = base64.urlsafe_b64decode(nonce)
except:
return {
"status": "ERROR",
"reason": f"Invalid hex or base64 nonce: {nonce}",
}
try:
payloadb = bytes.fromhex(payload)
except ValueError:
try:
payload += "=" * ((4 - len(payload) % 4) % 4)
payloadb = base64.urlsafe_b64decode(payload)
except:
return {
"status": "ERROR",
"reason": f"Invalid hex or base64 payload: {payload}",
}
# check payload and nonce sizes
if len(payloadb) != 8 or len(nonceb) != 8:
return {"status": "ERROR", "reason": "Expected 8 bytes"}
# verify hmac
if verify_checksum:
expected = hmac.new(
pos.key.encode(), payloadb[:-2], digestmod="sha256"
).digest()
if expected[:2] != payloadb[-2:]:
return {"status": "ERROR", "reason": "Invalid HMAC"}
# decrypt
s = hmac.new(pos.key.encode(), nonceb, digestmod="sha256").digest()
res = bytearray(payloadb)
for i in range(len(res)):
res[i] = res[i] ^ s[i]
pin = int.from_bytes(res[0:2], "little")
amount = int.from_bytes(res[2:6], "little")
if len(payload) % 4 > 0:
payload += "="*(4-(len(payload)%4))
data = base64.urlsafe_b64decode(payload)
pin, amount_in_cent = xor_decrypt(pos.key.encode(), data)
price_msat = (
await fiat_amount_as_satoshis(float(amount) / 100, pos.currency)
await fiat_amount_as_satoshis(float(amount_in_cent) / 100, pos.currency)
if pos.currency != "sat"
else amount
) * 1000
@ -161,7 +161,7 @@ async def lnurl_callback(request: Request, paymentid: str = Query(None)):
"successAction": {
"tag": "url",
"description": "Check the attached link",
"url": req.url_for("lnurlpos.displaypin", paymentid=paymentid),
"url": request.url_for("lnurlpos.displaypin", paymentid=paymentid),
},
"routes": [],
}