Added Stepans decode functions
This commit is contained in:
parent
6b1c1af148
commit
c87abef20f
|
@ -3,6 +3,12 @@ import hashlib
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from embit import bech32
|
||||||
|
from embit import compact
|
||||||
|
import base64
|
||||||
|
from io import BytesIO
|
||||||
|
import hmac
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.param_functions import Query
|
from fastapi.param_functions import Query
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
|
@ -18,39 +24,79 @@ from .crud import (
|
||||||
update_lnurlpospayment,
|
update_lnurlpospayment,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def bech32_decode(bech):
|
||||||
|
"""tweaked version of bech32_decode that ignores length limitations"""
|
||||||
|
if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or
|
||||||
|
(bech.lower() != bech and bech.upper() != bech)):
|
||||||
|
return
|
||||||
|
bech = bech.lower()
|
||||||
|
pos = bech.rfind('1')
|
||||||
|
if pos < 1 or pos + 7 > len(bech):
|
||||||
|
return
|
||||||
|
if not all(x in bech32.CHARSET for x in bech[pos+1:]):
|
||||||
|
return
|
||||||
|
hrp = bech[:pos]
|
||||||
|
data = [bech32.CHARSET.find(x) for x in bech[pos+1:]]
|
||||||
|
encoding = bech32.bech32_verify_checksum(hrp, data)
|
||||||
|
if encoding is None:
|
||||||
|
return
|
||||||
|
return bytes(bech32.convertbits(data[:-6], 5, 8, False))
|
||||||
|
|
||||||
|
def xor_decrypt(key, blob):
|
||||||
|
s = BytesIO(blob)
|
||||||
|
variant = s.read(1)[0]
|
||||||
|
if variant != 1:
|
||||||
|
raise RuntimeError("Not implemented")
|
||||||
|
# reading nonce
|
||||||
|
l = s.read(1)[0]
|
||||||
|
nonce = s.read(l)
|
||||||
|
if len(nonce) != l:
|
||||||
|
raise RuntimeError("Missing nonce bytes")
|
||||||
|
if l < 8:
|
||||||
|
raise RuntimeError("Nonce is too short")
|
||||||
|
# reading payload
|
||||||
|
l = s.read(1)[0]
|
||||||
|
payload = s.read(l)
|
||||||
|
if len(payload) > 32:
|
||||||
|
raise RuntimeError("Payload is too long for this encryption method")
|
||||||
|
if len(payload) != l:
|
||||||
|
raise RuntimeError("Missing payload bytes")
|
||||||
|
hmacval = s.read()
|
||||||
|
expected = hmac.new(key, b"Data:" + blob[:-len(hmacval)], digestmod="sha256").digest()
|
||||||
|
if len(hmacval) < 8:
|
||||||
|
raise RuntimeError("HMAC is too short")
|
||||||
|
if hmacval != expected[:len(hmacval)]:
|
||||||
|
raise RuntimeError("HMAC is invalid")
|
||||||
|
secret = hmac.new(key, b"Round secret:" + nonce, digestmod="sha256").digest()
|
||||||
|
payload = bytearray(payload)
|
||||||
|
for i in range(len(payload)):
|
||||||
|
payload[i] = payload[i] ^ secret[i]
|
||||||
|
s = BytesIO(payload)
|
||||||
|
pin = compact.read_from(s)
|
||||||
|
# currency
|
||||||
|
currency = s.read(1)
|
||||||
|
if currency != USD_CENTS:
|
||||||
|
raise RuntimeError("Unsupported currency: %s" % currency)
|
||||||
|
amount_in_cent = compact.read_from(s)
|
||||||
|
if s.read():
|
||||||
|
raise RuntimeError("Unexpected data")
|
||||||
|
return pin, amount_in_cent
|
||||||
|
|
||||||
@lnurlpos_ext.get(
|
@lnurlpos_ext.get(
|
||||||
"/api/v1/lnurl/{nonce}/{payload}/{pos_id}",
|
"/api/v1/lnurl/{pos_id}",
|
||||||
status_code=HTTPStatus.OK,
|
status_code=HTTPStatus.OK,
|
||||||
name="lnurlpos.lnurl_response",
|
name="lnurlpos.lnurl_v1_params",
|
||||||
)
|
)
|
||||||
async def lnurl_response(
|
async def lnurl_v1_params(
|
||||||
request: Request,
|
|
||||||
nonce: str = Query(None),
|
|
||||||
pos_id: str = Query(None),
|
|
||||||
payload: str = Query(None),
|
|
||||||
):
|
|
||||||
return await handle_lnurl_firstrequest(
|
|
||||||
request, pos_id, nonce, payload, verify_checksum=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@lnurlpos_ext.get(
|
|
||||||
"/api/v2/lnurl/{pos_id}",
|
|
||||||
status_code=HTTPStatus.OK,
|
|
||||||
name="lnurlpos.lnurl_v2_params",
|
|
||||||
)
|
|
||||||
async def lnurl_v2_params(
|
|
||||||
request: Request,
|
request: Request,
|
||||||
pos_id: str = Query(None),
|
pos_id: str = Query(None),
|
||||||
n: str = Query(None),
|
|
||||||
p: str = Query(None),
|
p: str = Query(None),
|
||||||
):
|
):
|
||||||
return await handle_lnurl_firstrequest(request, pos_id, n, p, verify_checksum=True)
|
return await handle_lnurl_firstrequest(request, pos_id, p)
|
||||||
|
|
||||||
|
|
||||||
async def handle_lnurl_firstrequest(
|
async def handle_lnurl_firstrequest(
|
||||||
request: Request, pos_id: str, nonce: str, payload: str, verify_checksum: bool
|
request: Request, pos_id: str, payload: str
|
||||||
):
|
):
|
||||||
pos = await get_lnurlpos(pos_id)
|
pos = await get_lnurlpos(pos_id)
|
||||||
if not pos:
|
if not pos:
|
||||||
|
@ -59,53 +105,14 @@ async def handle_lnurl_firstrequest(
|
||||||
"reason": f"lnurlpos {pos_id} not found on this server",
|
"reason": f"lnurlpos {pos_id} not found on this server",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
if len(payload) % 4 > 0:
|
||||||
nonceb = bytes.fromhex(nonce)
|
payload += "="*(4-(len(payload)%4))
|
||||||
except ValueError:
|
|
||||||
try:
|
|
||||||
nonce += "=" * ((4 - len(nonce) % 4) % 4)
|
|
||||||
nonceb = base64.urlsafe_b64decode(nonce)
|
|
||||||
except:
|
|
||||||
return {
|
|
||||||
"status": "ERROR",
|
|
||||||
"reason": f"Invalid hex or base64 nonce: {nonce}",
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
data = base64.urlsafe_b64decode(payload)
|
||||||
payloadb = bytes.fromhex(payload)
|
pin, amount_in_cent = xor_decrypt(key, data)
|
||||||
except ValueError:
|
|
||||||
try:
|
|
||||||
payload += "=" * ((4 - len(payload) % 4) % 4)
|
|
||||||
payloadb = base64.urlsafe_b64decode(payload)
|
|
||||||
except:
|
|
||||||
return {
|
|
||||||
"status": "ERROR",
|
|
||||||
"reason": f"Invalid hex or base64 payload: {payload}",
|
|
||||||
}
|
|
||||||
|
|
||||||
# check payload and nonce sizes
|
|
||||||
if len(payloadb) != 8 or len(nonceb) != 8:
|
|
||||||
return {"status": "ERROR", "reason": "Expected 8 bytes"}
|
|
||||||
|
|
||||||
# verify hmac
|
|
||||||
if verify_checksum:
|
|
||||||
expected = hmac.new(
|
|
||||||
pos.key.encode(), payloadb[:-2], digestmod="sha256"
|
|
||||||
).digest()
|
|
||||||
if expected[:2] != payloadb[-2:]:
|
|
||||||
return {"status": "ERROR", "reason": "Invalid HMAC"}
|
|
||||||
|
|
||||||
# decrypt
|
|
||||||
s = hmac.new(pos.key.encode(), nonceb, digestmod="sha256").digest()
|
|
||||||
res = bytearray(payloadb)
|
|
||||||
for i in range(len(res)):
|
|
||||||
res[i] = res[i] ^ s[i]
|
|
||||||
|
|
||||||
pin = int.from_bytes(res[0:2], "little")
|
|
||||||
amount = int.from_bytes(res[2:6], "little")
|
|
||||||
|
|
||||||
price_msat = (
|
price_msat = (
|
||||||
await fiat_amount_as_satoshis(float(amount) / 100, pos.currency)
|
await fiat_amount_as_satoshis(float(amount_in_cent) / 100, pos.currency)
|
||||||
if pos.currency != "sat"
|
if pos.currency != "sat"
|
||||||
else amount
|
else amount
|
||||||
) * 1000
|
) * 1000
|
||||||
|
|
Loading…
Reference in New Issue
Block a user