lnurlpos: return lnurl error on xor_decrypt failure.
This commit is contained in:
parent
40aadbfec7
commit
6b5aaa442d
|
@ -24,24 +24,27 @@ from .crud import (
|
||||||
update_lnurlpospayment,
|
update_lnurlpospayment,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def bech32_decode(bech):
|
def bech32_decode(bech):
|
||||||
"""tweaked version of bech32_decode that ignores length limitations"""
|
"""tweaked version of bech32_decode that ignores length limitations"""
|
||||||
if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or
|
if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or (
|
||||||
(bech.lower() != bech and bech.upper() != bech)):
|
bech.lower() != bech and bech.upper() != bech
|
||||||
|
):
|
||||||
return
|
return
|
||||||
bech = bech.lower()
|
bech = bech.lower()
|
||||||
pos = bech.rfind('1')
|
pos = bech.rfind("1")
|
||||||
if pos < 1 or pos + 7 > len(bech):
|
if pos < 1 or pos + 7 > len(bech):
|
||||||
return
|
return
|
||||||
if not all(x in bech32.CHARSET for x in bech[pos+1:]):
|
if not all(x in bech32.CHARSET for x in bech[pos + 1 :]):
|
||||||
return
|
return
|
||||||
hrp = bech[:pos]
|
hrp = bech[:pos]
|
||||||
data = [bech32.CHARSET.find(x) for x in bech[pos+1:]]
|
data = [bech32.CHARSET.find(x) for x in bech[pos + 1 :]]
|
||||||
encoding = bech32.bech32_verify_checksum(hrp, data)
|
encoding = bech32.bech32_verify_checksum(hrp, data)
|
||||||
if encoding is None:
|
if encoding is None:
|
||||||
return
|
return
|
||||||
return bytes(bech32.convertbits(data[:-6], 5, 8, False))
|
return bytes(bech32.convertbits(data[:-6], 5, 8, False))
|
||||||
|
|
||||||
|
|
||||||
def xor_decrypt(key, blob):
|
def xor_decrypt(key, blob):
|
||||||
s = BytesIO(blob)
|
s = BytesIO(blob)
|
||||||
variant = s.read(1)[0]
|
variant = s.read(1)[0]
|
||||||
|
@ -62,10 +65,12 @@ def xor_decrypt(key, blob):
|
||||||
if len(payload) != l:
|
if len(payload) != l:
|
||||||
raise RuntimeError("Missing payload bytes")
|
raise RuntimeError("Missing payload bytes")
|
||||||
hmacval = s.read()
|
hmacval = s.read()
|
||||||
expected = hmac.new(key, b"Data:" + blob[:-len(hmacval)], digestmod="sha256").digest()
|
expected = hmac.new(
|
||||||
|
key, b"Data:" + blob[: -len(hmacval)], digestmod="sha256"
|
||||||
|
).digest()
|
||||||
if len(hmacval) < 8:
|
if len(hmacval) < 8:
|
||||||
raise RuntimeError("HMAC is too short")
|
raise RuntimeError("HMAC is too short")
|
||||||
if hmacval != expected[:len(hmacval)]:
|
if hmacval != expected[: len(hmacval)]:
|
||||||
raise RuntimeError("HMAC is invalid")
|
raise RuntimeError("HMAC is invalid")
|
||||||
secret = hmac.new(key, b"Round secret:" + nonce, digestmod="sha256").digest()
|
secret = hmac.new(key, b"Round secret:" + nonce, digestmod="sha256").digest()
|
||||||
payload = bytearray(payload)
|
payload = bytearray(payload)
|
||||||
|
@ -76,6 +81,7 @@ def xor_decrypt(key, blob):
|
||||||
amount_in_cent = compact.read_from(s)
|
amount_in_cent = compact.read_from(s)
|
||||||
return pin, amount_in_cent
|
return pin, amount_in_cent
|
||||||
|
|
||||||
|
|
||||||
@lnurlpos_ext.get(
|
@lnurlpos_ext.get(
|
||||||
"/api/v1/lnurl/{pos_id}",
|
"/api/v1/lnurl/{pos_id}",
|
||||||
status_code=HTTPStatus.OK,
|
status_code=HTTPStatus.OK,
|
||||||
|
@ -85,12 +91,6 @@ async def lnurl_v1_params(
|
||||||
request: Request,
|
request: Request,
|
||||||
pos_id: str = Query(None),
|
pos_id: str = Query(None),
|
||||||
p: str = Query(None),
|
p: str = Query(None),
|
||||||
):
|
|
||||||
return await handle_lnurl_firstrequest(request, pos_id, p)
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_lnurl_firstrequest(
|
|
||||||
request: Request, pos_id: str, payload: str
|
|
||||||
):
|
):
|
||||||
pos = await get_lnurlpos(pos_id)
|
pos = await get_lnurlpos(pos_id)
|
||||||
if not pos:
|
if not pos:
|
||||||
|
@ -100,10 +100,18 @@ async def handle_lnurl_firstrequest(
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(payload) % 4 > 0:
|
if len(payload) % 4 > 0:
|
||||||
payload += "="*(4-(len(payload)%4))
|
payload += "=" * (4 - (len(payload) % 4))
|
||||||
|
|
||||||
data = base64.urlsafe_b64decode(payload)
|
data = base64.urlsafe_b64decode(payload)
|
||||||
pin, amount_in_cent = xor_decrypt(pos.key.encode(), data)
|
pin = 0
|
||||||
|
amount_in_cent = 0
|
||||||
|
try:
|
||||||
|
result = xor_decrypt(pos.key.encode(), data)
|
||||||
|
pin = result[0]
|
||||||
|
amount_in_cent = result[1]
|
||||||
|
except Exception as exc:
|
||||||
|
return {"status": "ERROR", "reason": str(exc)}
|
||||||
|
|
||||||
price_msat = (
|
price_msat = (
|
||||||
await fiat_amount_as_satoshis(float(amount_in_cent) / 100, pos.currency)
|
await fiat_amount_as_satoshis(float(amount_in_cent) / 100, pos.currency)
|
||||||
if pos.currency != "sat"
|
if pos.currency != "sat"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user