fix proxyfix.

This commit is contained in:
fiatjaf 2020-09-28 00:21:33 -03:00
parent 098089af75
commit 49baa07141
2 changed files with 37 additions and 42 deletions

View File

@ -9,7 +9,7 @@ from .commands import db_migrate
from .core import core_app
from .db import open_db
from .helpers import get_valid_extensions, get_js_vendored, get_css_vendored, url_for_vendored
from .proxy_fix import ProxyFix
from .proxy_fix import ASGIProxyFix
secure_headers = SecureHeaders(hsts=False)
@ -20,10 +20,10 @@ def create_app(config_object="lnbits.settings") -> Quart:
"""
app = Quart(__name__, static_folder="static")
app.config.from_object(config_object)
app.asgi_http_class = ASGIProxyFix
cors(app)
Compress(app)
ProxyFix(app, x_proto=1, x_host=1)
register_assets(app)
register_blueprints(app)

View File

@ -1,48 +1,46 @@
from typing import Optional, List
from typing import Optional, List, Callable
from functools import partial
from urllib.request import parse_http_list as _parse_list_header
from urllib.parse import urlparse
from werkzeug.datastructures import Headers
from quart import request
from quart import Request
from quart.asgi import ASGIHTTPConnection
class ProxyFix:
def __init__(self, app=None, x_for: int = 1, x_proto: int = 1, x_host: int = 0, x_port: int = 0, x_prefix: int = 0):
self.app = app
self.x_for = x_for
self.x_proto = x_proto
self.x_host = x_host
self.x_port = x_port
self.x_prefix = x_prefix
class ASGIProxyFix(ASGIHTTPConnection):
def _create_request_from_scope(self, send: Callable) -> Request:
headers = Headers()
headers["Remote-Addr"] = (self.scope.get("client") or ["<local>"])[0]
for name, value in self.scope["headers"]:
headers.add(name.decode("latin1").title(), value.decode("latin1"))
if self.scope["http_version"] < "1.1":
headers.setdefault("Host", self.app.config["SERVER_NAME"] or "")
if app:
self.init_app(app)
path = self.scope["path"]
path = path if path[0] == "/" else urlparse(path).path
def init_app(self, app):
@app.before_request
async def before_request():
x_for = self._get_real_value(self.x_for, request.headers.get("X-Forwarded-For"))
if x_for:
request.headers["Remote-Addr"] = x_for
x_proto = self._get_real_value(1, headers.get("X-Forwarded-Proto"))
if x_proto:
self.scope["scheme"] = x_proto
x_proto = self._get_real_value(self.x_proto, request.headers.get("X-Forwarded-Proto"))
if x_proto:
request.scheme = x_proto
x_host = self._get_real_value(1, headers.get("X-Forwarded-Host"))
if x_host:
headers["host"] = x_host.lower()
x_host = self._get_real_value(self.x_host, request.headers.get("X-Forwarded-Host"))
if x_host:
request.headers["host"] = x_host.lower()
parts = x_host.split(":", 1)
# environ["SERVER_NAME"] = parts[0]
# if len(parts) == 2:
# environ["SERVER_PORT"] = parts[1]
x_port = self._get_real_value(self.x_port, request.headers.get("X-Forwarded-Port"))
if x_port:
host = request.host
if host:
parts = host.split(":", 1)
host = parts[0] if len(parts) == 2 else host
request.headers["host"] = f"{host}:{x_port}"
# environ["SERVER_PORT"] = x_port
return self.app.request_class(
self.scope["method"],
self.scope["scheme"],
path,
self.scope["query_string"],
headers,
self.scope.get("root_path", ""),
self.scope["http_version"],
max_content_length=self.app.config["MAX_CONTENT_LENGTH"],
body_timeout=self.app.config["BODY_TIMEOUT"],
send_push_promise=partial(self._send_push_promise, send),
scope=self.scope,
)
def _get_real_value(self, trusted: int, value: Optional[str]) -> Optional[str]:
"""Get the real value from a list header based on the configured
@ -95,6 +93,3 @@ class ProxyFix:
if not is_filename or value[:2] != "\\\\":
return value.replace("\\\\", "\\").replace('\\"', '"')
return value
# host, request.root_path, subdomain, request.scheme, request.method, request.path, request.query_string.decode(),