Added Stepans decode functions

This commit is contained in:
benarc 2021-12-21 15:33:52 +00:00
parent 6b1c1af148
commit c87abef20f

View File

@ -3,6 +3,12 @@ import hashlib
from http import HTTPStatus from http import HTTPStatus
from typing import Optional from typing import Optional
from embit import bech32
from embit import compact
import base64
from io import BytesIO
import hmac
from fastapi import Request from fastapi import Request
from fastapi.param_functions import Query from fastapi.param_functions import Query
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
@ -18,39 +24,79 @@ from .crud import (
update_lnurlpospayment, update_lnurlpospayment,
) )
def bech32_decode(bech):
"""tweaked version of bech32_decode that ignores length limitations"""
if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or
(bech.lower() != bech and bech.upper() != bech)):
return
bech = bech.lower()
pos = bech.rfind('1')
if pos < 1 or pos + 7 > len(bech):
return
if not all(x in bech32.CHARSET for x in bech[pos+1:]):
return
hrp = bech[:pos]
data = [bech32.CHARSET.find(x) for x in bech[pos+1:]]
encoding = bech32.bech32_verify_checksum(hrp, data)
if encoding is None:
return
return bytes(bech32.convertbits(data[:-6], 5, 8, False))
def xor_decrypt(key, blob):
s = BytesIO(blob)
variant = s.read(1)[0]
if variant != 1:
raise RuntimeError("Not implemented")
# reading nonce
l = s.read(1)[0]
nonce = s.read(l)
if len(nonce) != l:
raise RuntimeError("Missing nonce bytes")
if l < 8:
raise RuntimeError("Nonce is too short")
# reading payload
l = s.read(1)[0]
payload = s.read(l)
if len(payload) > 32:
raise RuntimeError("Payload is too long for this encryption method")
if len(payload) != l:
raise RuntimeError("Missing payload bytes")
hmacval = s.read()
expected = hmac.new(key, b"Data:" + blob[:-len(hmacval)], digestmod="sha256").digest()
if len(hmacval) < 8:
raise RuntimeError("HMAC is too short")
if hmacval != expected[:len(hmacval)]:
raise RuntimeError("HMAC is invalid")
secret = hmac.new(key, b"Round secret:" + nonce, digestmod="sha256").digest()
payload = bytearray(payload)
for i in range(len(payload)):
payload[i] = payload[i] ^ secret[i]
s = BytesIO(payload)
pin = compact.read_from(s)
# currency
currency = s.read(1)
if currency != USD_CENTS:
raise RuntimeError("Unsupported currency: %s" % currency)
amount_in_cent = compact.read_from(s)
if s.read():
raise RuntimeError("Unexpected data")
return pin, amount_in_cent
@lnurlpos_ext.get( @lnurlpos_ext.get(
"/api/v1/lnurl/{nonce}/{payload}/{pos_id}", "/api/v1/lnurl/{pos_id}",
status_code=HTTPStatus.OK, status_code=HTTPStatus.OK,
name="lnurlpos.lnurl_response", name="lnurlpos.lnurl_v1_params",
) )
async def lnurl_response( async def lnurl_v1_params(
request: Request,
nonce: str = Query(None),
pos_id: str = Query(None),
payload: str = Query(None),
):
return await handle_lnurl_firstrequest(
request, pos_id, nonce, payload, verify_checksum=False
)
@lnurlpos_ext.get(
"/api/v2/lnurl/{pos_id}",
status_code=HTTPStatus.OK,
name="lnurlpos.lnurl_v2_params",
)
async def lnurl_v2_params(
request: Request, request: Request,
pos_id: str = Query(None), pos_id: str = Query(None),
n: str = Query(None),
p: str = Query(None), p: str = Query(None),
): ):
return await handle_lnurl_firstrequest(request, pos_id, n, p, verify_checksum=True) return await handle_lnurl_firstrequest(request, pos_id, p)
async def handle_lnurl_firstrequest( async def handle_lnurl_firstrequest(
request: Request, pos_id: str, nonce: str, payload: str, verify_checksum: bool request: Request, pos_id: str, payload: str
): ):
pos = await get_lnurlpos(pos_id) pos = await get_lnurlpos(pos_id)
if not pos: if not pos:
@ -59,53 +105,14 @@ async def handle_lnurl_firstrequest(
"reason": f"lnurlpos {pos_id} not found on this server", "reason": f"lnurlpos {pos_id} not found on this server",
} }
try: if len(payload) % 4 > 0:
nonceb = bytes.fromhex(nonce) payload += "="*(4-(len(payload)%4))
except ValueError:
try:
nonce += "=" * ((4 - len(nonce) % 4) % 4)
nonceb = base64.urlsafe_b64decode(nonce)
except:
return {
"status": "ERROR",
"reason": f"Invalid hex or base64 nonce: {nonce}",
}
try: data = base64.urlsafe_b64decode(payload)
payloadb = bytes.fromhex(payload) pin, amount_in_cent = xor_decrypt(key, data)
except ValueError:
try:
payload += "=" * ((4 - len(payload) % 4) % 4)
payloadb = base64.urlsafe_b64decode(payload)
except:
return {
"status": "ERROR",
"reason": f"Invalid hex or base64 payload: {payload}",
}
# check payload and nonce sizes
if len(payloadb) != 8 or len(nonceb) != 8:
return {"status": "ERROR", "reason": "Expected 8 bytes"}
# verify hmac
if verify_checksum:
expected = hmac.new(
pos.key.encode(), payloadb[:-2], digestmod="sha256"
).digest()
if expected[:2] != payloadb[-2:]:
return {"status": "ERROR", "reason": "Invalid HMAC"}
# decrypt
s = hmac.new(pos.key.encode(), nonceb, digestmod="sha256").digest()
res = bytearray(payloadb)
for i in range(len(res)):
res[i] = res[i] ^ s[i]
pin = int.from_bytes(res[0:2], "little")
amount = int.from_bytes(res[2:6], "little")
price_msat = ( price_msat = (
await fiat_amount_as_satoshis(float(amount) / 100, pos.currency) await fiat_amount_as_satoshis(float(amount_in_cent) / 100, pos.currency)
if pos.currency != "sat" if pos.currency != "sat"
else amount else amount
) * 1000 ) * 1000