lnurlpos: return lnurl error on xor_decrypt failure.

This commit is contained in:
fiatjaf 2021-12-22 21:03:56 -03:00
parent 40aadbfec7
commit 6b5aaa442d

View File

@ -24,24 +24,27 @@ from .crud import (
update_lnurlpospayment,
)
def bech32_decode(bech):
"""tweaked version of bech32_decode that ignores length limitations"""
if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or
(bech.lower() != bech and bech.upper() != bech)):
if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or (
bech.lower() != bech and bech.upper() != bech
):
return
bech = bech.lower()
pos = bech.rfind('1')
pos = bech.rfind("1")
if pos < 1 or pos + 7 > len(bech):
return
if not all(x in bech32.CHARSET for x in bech[pos+1:]):
if not all(x in bech32.CHARSET for x in bech[pos + 1 :]):
return
hrp = bech[:pos]
data = [bech32.CHARSET.find(x) for x in bech[pos+1:]]
data = [bech32.CHARSET.find(x) for x in bech[pos + 1 :]]
encoding = bech32.bech32_verify_checksum(hrp, data)
if encoding is None:
return
return bytes(bech32.convertbits(data[:-6], 5, 8, False))
def xor_decrypt(key, blob):
s = BytesIO(blob)
variant = s.read(1)[0]
@ -62,10 +65,12 @@ def xor_decrypt(key, blob):
if len(payload) != l:
raise RuntimeError("Missing payload bytes")
hmacval = s.read()
expected = hmac.new(key, b"Data:" + blob[:-len(hmacval)], digestmod="sha256").digest()
expected = hmac.new(
key, b"Data:" + blob[: -len(hmacval)], digestmod="sha256"
).digest()
if len(hmacval) < 8:
raise RuntimeError("HMAC is too short")
if hmacval != expected[:len(hmacval)]:
if hmacval != expected[: len(hmacval)]:
raise RuntimeError("HMAC is invalid")
secret = hmac.new(key, b"Round secret:" + nonce, digestmod="sha256").digest()
payload = bytearray(payload)
@ -76,6 +81,7 @@ def xor_decrypt(key, blob):
amount_in_cent = compact.read_from(s)
return pin, amount_in_cent
@lnurlpos_ext.get(
"/api/v1/lnurl/{pos_id}",
status_code=HTTPStatus.OK,
@ -85,12 +91,6 @@ async def lnurl_v1_params(
request: Request,
pos_id: str = Query(None),
p: str = Query(None),
):
return await handle_lnurl_firstrequest(request, pos_id, p)
async def handle_lnurl_firstrequest(
request: Request, pos_id: str, payload: str
):
pos = await get_lnurlpos(pos_id)
if not pos:
@ -100,10 +100,18 @@ async def handle_lnurl_firstrequest(
}
if len(payload) % 4 > 0:
payload += "="*(4-(len(payload)%4))
payload += "=" * (4 - (len(payload) % 4))
data = base64.urlsafe_b64decode(payload)
pin, amount_in_cent = xor_decrypt(pos.key.encode(), data)
pin = 0
amount_in_cent = 0
try:
result = xor_decrypt(pos.key.encode(), data)
pin = result[0]
amount_in_cent = result[1]
except Exception as exc:
return {"status": "ERROR", "reason": str(exc)}
price_msat = (
await fiat_amount_as_satoshis(float(amount_in_cent) / 100, pos.currency)
if pos.currency != "sat"