fix: mypy errors

This commit is contained in:
Eneko Illarramendi 2020-04-26 13:28:19 +02:00 committed by Sebastian Geisler
parent 976a3d4e5c
commit c3e337a319
18 changed files with 91 additions and 94 deletions

View File

@ -1,9 +1,9 @@
import importlib
from flask import Flask
from flask_assets import Environment, Bundle
from flask_compress import Compress
from flask_talisman import Talisman
from flask_assets import Environment, Bundle # type: ignore
from flask_compress import Compress # type: ignore
from flask_talisman import Talisman # type: ignore
from os import getenv
from werkzeug.middleware.proxy_fix import ProxyFix
@ -15,7 +15,7 @@ from .settings import FORCE_HTTPS
disabled_extensions = getenv("LNBITS_DISABLED_EXTENSIONS", "").split(",")
app = Flask(__name__)
app.wsgi_app = ProxyFix(app.wsgi_app, x_proto=1, x_host=1)
app.wsgi_app = ProxyFix(app.wsgi_app, x_proto=1, x_host=1) # type: ignore
valid_extensions = [ext for ext in ExtensionManager(disabled=disabled_extensions).extensions if ext.is_valid]

View File

@ -1,3 +1,5 @@
# type: ignore
import bitstring
import re
from binascii import hexlify

View File

@ -1,7 +1,7 @@
from flask import Blueprint
core_app = Blueprint("core", __name__, template_folder="templates", static_folder="static")
core_app: Blueprint = Blueprint("core", __name__, template_folder="templates", static_folder="static")
from .views.api import * # noqa

View File

@ -16,7 +16,10 @@ def create_account() -> User:
user_id = uuid4().hex
db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,))
return get_account(user_id=user_id)
new_account = get_account(user_id=user_id)
assert new_account, "Newly created account couldn't be retrieved"
return new_account
def get_account(user_id: str) -> Optional[User]:
@ -74,7 +77,10 @@ def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet:
(wallet_id, wallet_name or DEFAULT_WALLET_NAME, user_id, uuid4().hex, uuid4().hex),
)
return get_wallet(wallet_id=wallet_id)
new_wallet = get_wallet(wallet_id=wallet_id)
assert new_wallet, "Newly created wallet couldn't be retrieved"
return new_wallet
def delete_wallet(*, user_id: str, wallet_id: str) -> None:
@ -175,7 +181,7 @@ def delete_wallet_payments_expired(wallet_id: str, *, seconds: int = 86400) -> N
def create_payment(
*, wallet_id: str, checking_id: str, amount: str, memo: str, fee: int = 0, pending: bool = True
*, wallet_id: str, checking_id: str, amount: int, memo: str, fee: int = 0, pending: bool = True
) -> Payment:
with open_db() as db:
db.execute(
@ -186,7 +192,10 @@ def create_payment(
(wallet_id, checking_id, amount, int(pending), memo, fee),
)
return get_wallet_payment(wallet_id, checking_id)
new_payment = get_wallet_payment(wallet_id, checking_id)
assert new_payment, "Newly created payment couldn't be retrieved"
return new_payment
def update_payment_status(checking_id: str, pending: bool) -> None:

View File

@ -4,8 +4,8 @@ from typing import List, NamedTuple, Optional
class User(NamedTuple):
id: str
email: str
extensions: Optional[List[str]] = []
wallets: Optional[List["Wallet"]] = []
extensions: List[str] = []
wallets: List["Wallet"] = []
password: Optional[str] = None
@property
@ -27,9 +27,9 @@ class Wallet(NamedTuple):
@property
def balance(self) -> int:
return int(self.balance / 1000)
return self.balance // 1000
def get_payment(self, checking_id: str) -> "Payment":
def get_payment(self, checking_id: str) -> Optional["Payment"]:
from .crud import get_wallet_payment
return get_wallet_payment(self.id, checking_id)
@ -59,7 +59,7 @@ class Payment(NamedTuple):
@property
def sat(self) -> int:
return self.amount / 1000
return self.amount // 1000
@property
def is_in(self) -> bool:

View File

@ -1,6 +1,6 @@
from typing import Optional, Tuple
from lnbits.bolt11 import decode as bolt11_decode
from lnbits.bolt11 import decode as bolt11_decode # type: ignore
from lnbits.helpers import urlsafe_short_hash
from lnbits.settings import WALLET
@ -24,7 +24,6 @@ def create_invoice(*, wallet_id: str, amount: int, memo: str) -> Tuple[str, str]
def pay_invoice(*, wallet_id: str, bolt11: str, max_sat: Optional[int] = None) -> str:
temp_id = f"temp_{urlsafe_short_hash()}"
try:
invoice = bolt11_decode(bolt11)
@ -34,7 +33,7 @@ def pay_invoice(*, wallet_id: str, bolt11: str, max_sat: Optional[int] = None) -
if max_sat and invoice.amount_msat > max_sat * 1000:
raise ValueError("Amount in invoice is too high.")
fee_reserve = max(1000, invoice.amount_msat * 0.01)
fee_reserve = max(1000, int(invoice.amount_msat * 0.01))
create_payment(
wallet_id=wallet_id,
checking_id=temp_id,
@ -43,7 +42,9 @@ def pay_invoice(*, wallet_id: str, bolt11: str, max_sat: Optional[int] = None) -
memo=temp_id,
)
if get_wallet(wallet_id).balance_msat < 0:
wallet = get_wallet(wallet_id)
assert wallet, "invalid wallet id"
if wallet.balance_msat < 0:
raise PermissionError("Insufficient balance.")
ok, checking_id, fee_msat, error_message = WALLET.pay_invoice(bolt11)

View File

@ -1,8 +1,8 @@
import requests
from flask import abort, redirect, request, url_for
from lnurl import LnurlWithdrawResponse, handle as handle_lnurl
from lnurl.exceptions import LnurlException
from lnurl import LnurlWithdrawResponse, handle as handle_lnurl # type: ignore
from lnurl.exceptions import LnurlException # type: ignore
from time import sleep
from lnbits.core import core_app

View File

@ -1,4 +1,4 @@
from cerberus import Validator
from cerberus import Validator # type: ignore
from flask import g, abort, jsonify, request
from functools import wraps
from typing import List, Union

View File

@ -1,7 +1,7 @@
from flask import Blueprint
amilk_ext = Blueprint("amilk", __name__, static_folder="static", template_folder="templates")
amilk_ext: Blueprint = Blueprint("amilk", __name__, static_folder="static", template_folder="templates")
from .views_api import * # noqa

View File

@ -1,7 +1,7 @@
from flask import Blueprint
diagonalley_ext = Blueprint("diagonalley", __name__, static_folder="static", template_folder="templates")
diagonalley_ext: Blueprint = Blueprint("diagonalley", __name__, static_folder="static", template_folder="templates")
from .views_api import * # noqa

View File

@ -1,7 +1,7 @@
from flask import Blueprint
events_ext = Blueprint("events", __name__, static_folder="static", template_folder="templates")
events_ext: Blueprint = Blueprint("events", __name__, static_folder="static", template_folder="templates")
from .views_api import * # noqa

View File

@ -1,7 +1,7 @@
from flask import Blueprint
example_ext = Blueprint("example", __name__, static_folder="static", template_folder="templates")
example_ext: Blueprint = Blueprint("example", __name__, static_folder="static", template_folder="templates")
from .views_api import * # noqa

View File

@ -1,7 +1,7 @@
from flask import Blueprint
tpos_ext = Blueprint("tpos", __name__, static_folder="static", template_folder="templates")
tpos_ext: Blueprint = Blueprint("tpos", __name__, static_folder="static", template_folder="templates")
from .views_api import * # noqa

View File

@ -1,7 +1,7 @@
from flask import Blueprint
withdraw_ext = Blueprint("withdraw", __name__, static_folder="static", template_folder="templates")
withdraw_ext: Blueprint = Blueprint("withdraw", __name__, static_folder="static", template_folder="templates")
from .views_api import * # noqa

View File

@ -1,6 +1,6 @@
import json
import os
import shortuuid
import shortuuid # type: ignore
from typing import List, NamedTuple, Optional
@ -34,7 +34,14 @@ class ExtensionManager:
config = {}
is_valid = False
output.append(Extension(**{**{"code": extension, "is_valid": is_valid}, **config}))
output.append(Extension(
extension,
is_valid,
config.get('name'),
config.get('short_description'),
config.get('icon'),
config.get('contributors')
))
return output

View File

@ -1,7 +1,6 @@
from requests import get, post
from os import getenv
from .base import InvoiceResponse, PaymentResponse, PaymentStatus, Wallet
from lightning import LightningRpc
from lightning import LightningRpc # type: ignore
import random
class CLightningWallet(Wallet):
@ -17,7 +16,7 @@ class CLightningWallet(Wallet):
def pay_invoice(self, bolt11: str) -> PaymentResponse:
r = self.l1.pay(bolt11)
ok, checking_id, fee_msat, error_message = True, None, None, None
ok, checking_id, fee_msat, error_message = True, None, 0, None
return PaymentResponse(ok, checking_id, fee_msat, error_message)
def get_invoice_status(self, checking_id: str) -> PaymentStatus:
@ -29,8 +28,8 @@ class CLightningWallet(Wallet):
def get_payment_status(self, checking_id: str) -> PaymentStatus:
r = self.l1.listsendpays(checking_id)
if not r.ok:
return PaymentStatus(r, None)
payments = [p for p in r.json()["payments"] if p["payment_hash"] == payment_hash]
return PaymentStatus(None)
payments = [p for p in r.json()["payments"] if p["payment_hash"] == checking_id]
payment = payments[0] if payments else None
statuses = {"UNKNOWN": None, "IN_FLIGHT": None, "SUCCEEDED": True, "FAILED": False}
return PaymentStatus(statuses[payment["status"]] if payment else None)

View File

@ -1,14 +1,13 @@
from os import getenv
import os
import base64
import lnd_grpc # https://github.com/willcl-ark/lnd_grpc
import lnd_grpc # type: ignore
from os import getenv
from .base import InvoiceResponse, PaymentResponse, PaymentStatus, Wallet
class LndWallet(Wallet):
def __init__(self):
endpoint = getenv("LND_GRPC_ENDPOINT")
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
self.port = getenv("LND_GRPC_PORT")
@ -18,31 +17,21 @@ class LndWallet(Wallet):
self.auth_cert = getenv("LND_CERT")
lnd_rpc = lnd_grpc.Client(
lnd_dir = None,
tls_cert_path = self.auth_cert,
network = 'mainnet',
grpc_host = self.endpoint,
grpc_port = self.port
lnd_dir=None, tls_cert_path=self.auth_cert, network="mainnet", grpc_host=self.endpoint, grpc_port=self.port
)
def create_invoice(self, amount: int, memo: str = "") -> InvoiceResponse:
lnd_rpc = lnd_grpc.Client(
lnd_dir=None,
macaroon_path=self.auth_invoice,
tls_cert_path=self.auth_cert,
network = 'mainnet',
network="mainnet",
grpc_host=self.endpoint,
grpc_port = self.port
grpc_port=self.port,
)
lndResponse = lnd_rpc.add_invoice(
memo = memo,
value = amount,
expiry = 600,
private = True
)
decoded_hash = base64.b64encode(lndResponse.r_hash).decode('utf-8').replace("/","_")
lndResponse = lnd_rpc.add_invoice(memo=memo, value=amount, expiry=600, private=True)
decoded_hash = base64.b64encode(lndResponse.r_hash).decode("utf-8").replace("/", "_")
print(lndResponse.r_hash)
ok, checking_id, payment_request, error_message = True, decoded_hash, str(lndResponse.payment_request), None
return InvoiceResponse(ok, checking_id, payment_request, error_message)
@ -53,21 +42,19 @@ class LndWallet(Wallet):
lnd_dir=None,
macaroon_path=self.auth_admin,
tls_cert_path=self.auth_cert,
network = 'mainnet',
network="mainnet",
grpc_host=self.endpoint,
grpc_port = self.port
grpc_port=self.port,
)
payinvoice = lnd_rpc.pay_invoice(
payment_request = bolt11,
)
payinvoice = lnd_rpc.pay_invoice(payment_request=bolt11,)
ok, checking_id, fee_msat, error_message = True, None, 0, None
if payinvoice.payment_error:
ok, error_message = False, payinvoice.payment_error
else:
checking_id = base64.b64encode(payinvoice.payment_hash).decode('utf-8').replace("/","_")
checking_id = base64.b64encode(payinvoice.payment_hash).decode("utf-8").replace("/", "_")
return PaymentResponse(ok, checking_id, fee_msat, error_message)
@ -79,23 +66,16 @@ class LndWallet(Wallet):
lnd_dir=None,
macaroon_path=self.auth_invoice,
tls_cert_path=self.auth_cert,
network = 'mainnet',
network="mainnet",
grpc_host=self.endpoint,
grpc_port = self.port
grpc_port=self.port,
)
for _response in lnd_rpc.subscribe_single_invoice(check_id):
if _response.state == 1:
return PaymentStatus(True)
invoiceThread = threading.Thread(
target=detectPayment,
args=[lndResponse.check_id, ],
daemon=True
)
invoiceThread.start()
return PaymentStatus(None)
def get_payment_status(self, checking_id: str) -> PaymentStatus:

View File

@ -71,11 +71,10 @@ class LndRestWallet(Wallet):
return PaymentStatus(r.json()["settled"])
def get_payment_status(self, checking_id: str) -> PaymentStatus:
r = get(url=f"{self.endpoint}/v1/payments", headers=self.auth_admin, verify=self.auth_cert, params={"include_incomplete": True, "max_payments": "20"})
r = get(url=f"{self.endpoint}/v1/payments", headers=self.auth_admin, verify=self.auth_cert, params={"include_incomplete": "True", "max_payments": "20"})
if not r.ok:
return PaymentStatus(r, None)
return PaymentStatus(None)
payments = [p for p in r.json()["payments"] if p["payment_hash"] == checking_id]
print(checking_id)