Files
tradon/wsgi.py
2025-12-26 13:11:43 +00:00

257 lines
9.5 KiB
Python
Executable File

# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import base64
import http.client
import logging
import os
import posixpath
import sys
import traceback
import urllib.parse
from functools import wraps

from werkzeug.routing import BaseConverter, Map, Rule, ValidationError
try:
    from werkzeug.middleware.proxy_fix import ProxyFix

    def NumProxyFix(app, num_proxies):
        """Wrap *app* in ``ProxyFix`` trusting *num_proxies* values for each
        of the ``x_for``, ``x_proto``, ``x_host``, ``x_port`` and
        ``x_prefix`` forwarded headers."""
        trusted = dict.fromkeys(
            ('x_for', 'x_proto', 'x_host', 'x_port', 'x_prefix'),
            num_proxies)
        return ProxyFix(app, **trusted)
except ImportError:
    # Older werkzeug only provides the fixer under ``contrib``; it is used
    # with the same ``(app, num_proxies)`` call as the wrapper above.
    from werkzeug.contrib.fixers import ProxyFix as NumProxyFix
try:
from werkzeug.middleware.shared_data import SharedDataMiddleware
except ImportError:
from werkzeug.wsgi import SharedDataMiddleware
from trytond import backend
from trytond.config import config
from trytond.protocols.jsonrpc import JSONProtocol
from trytond.protocols.wrappers import (
HTTPStatus, Request, Response, abort, exceptions)
from trytond.protocols.xmlrpc import XMLProtocol
from trytond.status import processing
from trytond.tools import resolve, safe_join
# Public names of this module.
__all__ = [
    'TrytondWSGI',
    'app',
    ]
# Logger named after this module.
logger = logging.getLogger(__name__)
class Base64Converter(BaseConverter):
    """URL converter for path segments carrying URL-safe base64 text.

    ``to_url`` encodes a ``str`` as UTF-8 and then URL-safe base64;
    ``to_python`` reverses the transformation.
    """

    def to_python(self, value):
        """Decode the URL segment *value* back to a ``str``.

        Raises ``ValidationError`` when *value* is not valid URL-safe base64
        or does not decode to UTF-8, so werkzeug treats the rule as not
        matching instead of propagating an unexpected error.
        """
        try:
            return base64.urlsafe_b64decode(value).decode('utf-8')
        except ValueError:
            # binascii.Error (bad base64 / non-ASCII input) and
            # UnicodeDecodeError (non-UTF-8 payload) both subclass ValueError.
            raise ValidationError()

    def to_url(self, value):
        """Encode the ``str`` *value* as a URL-safe base64 segment."""
        return base64.urlsafe_b64encode(value.encode('utf-8')).decode('ascii')
class TrytondWSGI(object):
    """WSGI application routing Tryton HTTP requests to endpoints.

    Endpoints are registered with :meth:`route`. A matched endpoint is
    called with the request and its URL arguments. A result that is not
    already a response is serialized by one of ``self.protocols``.
    Cross-origin requests are checked against the ``[web] cors`` option
    and answered with CORS headers.

    ``__call__`` delegates to ``self.wsgi_app``, so WSGI middlewares can be
    installed by replacing that attribute.
    """

    def __init__(self):
        # URL rules added by route(); rule parts declared with the 'base64'
        # converter are decoded by Base64Converter.
        self.url_map = Map([], converters={
                'base64': Base64Converter,
                })
        # Candidate serialization protocols, in order of preference.
        self.protocols = [JSONProtocol, XMLProtocol]
        # Callables registered with error_handler().
        self.error_handlers = []

    def route(self, string, methods=None, defaults=None):
        """Return a decorator registering a function as the endpoint of the
        URL rule *string*.

        *methods* and *defaults* are passed to the werkzeug ``Rule``. The
        decorated function is returned unchanged.
        """
        def decorator(func):
            self.url_map.add(Rule(
                    string, endpoint=func, methods=methods, defaults=defaults))
            return func
        return decorator

    def error_handler(self, handler):
        """Register *handler* for unexpected dispatch errors.

        The handler is called as ``handler(app, request, exception)`` for any
        exception other than an HTTP exception or a database operational
        error. If it returns a ``Response``, that response is sent instead of
        the exception.

        Returns *handler*, so the method can be used as a decorator.
        """
        self.error_handlers.append(handler)
        return handler

    def auth_required(self, func):
        """Decorate an endpoint so it only runs for authenticated requests.

        Requests without a truthy ``user_id`` are aborted with
        401 Unauthorized. The ``WWW-Authenticate`` challenge is omitted for
        requests sent with ``X-Requested-With: XMLHttpRequest``. This is
        presumably so scripted clients do not trigger the browser's
        Basic-auth prompt.
        """
        @wraps(func)
        def wrapper(request, *args, **kwargs):
            if request.user_id:
                return func(request, *args, **kwargs)
            else:
                headers = {}
                if request.headers.get('X-Requested-With') != 'XMLHttpRequest':
                    headers['WWW-Authenticate'] = 'Basic realm="Tryton"'
                response = Response(None, http.client.UNAUTHORIZED, headers)
                abort(http.client.UNAUTHORIZED, response=response)
        return wrapper

    def check_request_size(self, request, size=None):
        """Abort when the body of a POST/PUT/PATCH *request* is too large.

        *size* overrides the configured limit. Otherwise the limit is
        ``[request] max_size_authenticated`` for authenticated requests and
        ``[request] max_size`` for anonymous ones. A falsy limit disables the
        check.

        Aborts with 411 Length Required when no Content-Length is sent, and
        with 413 Request Entity Too Large when the limit is exceeded.
        """
        if request.method not in {'POST', 'PUT', 'PATCH'}:
            return
        if size is None:
            if request.user_id:
                max_size = config.getint(
                    'request', 'max_size_authenticated')
            else:
                max_size = config.getint(
                    'request', 'max_size')
        else:
            max_size = size
        if max_size:
            content_length = request.content_length
            if content_length is None:
                abort(http.client.LENGTH_REQUIRED)
            elif content_length > max_size:
                abort(http.client.REQUEST_ENTITY_TOO_LARGE)

    def dispatch_request(self, request):
        """Call the endpoint matching *request* and return its result.

        Errors are returned rather than raised:

        - HTTP exceptions are returned as is.
        - Database operational errors become 503 Service Unavailable.
        - Any other exception is returned, unless a registered error handler
          supplies a ``Response``. The exception is first annotated with a
          ``__format_traceback__`` string from which ``sys.path`` entries
          are stripped.
        """
        adapter = self.url_map.bind_to_environ(request.environ)
        try:
            endpoint, request.view_args = adapter.match()
            max_request_size = getattr(endpoint, 'max_request_size', None)
            self.check_request_size(request, max_request_size)
            return endpoint(request, **request.view_args)
        except exceptions.HTTPException as e:
            logger.debug(
                "Exception when processing %s", request, exc_info=True)
            return e
        except backend.DatabaseOperationalError as e:
            logger.debug(
                "Exception when processing %s", request, exc_info=True)
            return exceptions.ServiceUnavailable(description=str(e))
        except Exception as e:
            logger.debug(
                "Exception when processing %s", request, exc_info=True)
            tb_s = ''.join(traceback.format_exception(*sys.exc_info()))
            # Presumably to avoid leaking installation paths to clients.
            for path in sys.path:
                tb_s = tb_s.replace(path, '')
            e.__format_traceback__ = tb_s
            response = e
            # Every handler is called; the last one returning a Response wins.
            for error_handler in self.error_handlers:
                rv = error_handler(self, request, e)
                if isinstance(rv, Response):
                    response = rv
            return response

    def make_response(self, request, data):
        """Serialize *data* into a response for *request*.

        The protocol is chosen as follows:

        - The first protocol (in ``self.protocols`` order) whose content
          type appears in the Accept header.
        - Otherwise, the first one whose content type appears in the request
          Content-Type.
        - Otherwise, an exception becomes a 500 Internal Server Error and any
          other value is wrapped in a plain ``Response``.
        """
        for cls in self.protocols:
            for mimetype, _ in request.accept_mimetypes:
                if cls.content_type in mimetype:
                    response = cls.response(data, request)
                    break
            else:
                continue
            break
        else:
            for cls in self.protocols:
                if cls.content_type in request.environ.get('CONTENT_TYPE', ''):
                    response = cls.response(data, request)
                    break
            else:
                if isinstance(data, Exception):
                    # Fallback presumably for werkzeug versions whose
                    # InternalServerError lacks ``original_exception``.
                    try:
                        response = exceptions.InternalServerError(
                            original_exception=data)
                    except TypeError:
                        response = exceptions.InternalServerError(data)
                else:
                    response = Response(data)
        return response

    def _build_request(self, environ):
        """Return the request for *environ*.

        Uses the request class of the first protocol whose content type
        appears in the Content-Type header, else a plain ``Request``.
        """
        content_type = environ.get('CONTENT_TYPE', '')
        for cls in self.protocols:
            if cls.content_type in content_type:
                return cls.request(environ)
        return Request(environ)

    def _check_origin(self, request):
        """Return the Origin of *request* to echo in CORS headers.

        Origin handling:

        - Same-host origins and those listed in ``[web] cors`` are kept.
        - Unlisted ``moz-extension://`` and ``chrome-extension://`` origins
          are downgraded to ``'null'``.
        - Any other cross-origin request is aborted with 403 Forbidden.

        A ``'null'`` origin is only accepted when the matched endpoint has a
        truthy ``allow_null_origin`` attribute.

        Raises the HTTP exceptions produced by ``abort`` or by the URL match.
        """
        origin = request.headers.get('Origin')
        origin_host = urllib.parse.urlparse(origin).netloc if origin else ''
        host = request.headers.get('Host')
        if origin and origin != 'null' and origin_host != host:
            cors = filter(
                None, config.get('web', 'cors', default='').splitlines())
            if origin not in cors:
                if (origin.startswith('moz-extension://')
                        or origin.startswith('chrome-extension://')):
                    origin = 'null'
                else:
                    abort(HTTPStatus.FORBIDDEN)
        if origin == 'null':
            adapter = self.url_map.bind_to_environ(request.environ)
            endpoint = adapter.match()[0]
            if not getattr(endpoint, 'allow_null_origin', False):
                abort(HTTPStatus.FORBIDDEN)
        return origin

    def _add_cors_headers(self, request, response, origin):
        """Add the CORS headers allowing *origin* to *response*.

        Requested methods and headers are echoed back.
        """
        response.headers['Access-Control-Allow-Origin'] = origin
        # Add to (instead of overwriting) any Vary value set by the endpoint.
        response.vary.add('Origin')
        method = request.headers.get('Access-Control-Request-Method')
        if method:
            response.headers['Access-Control-Allow-Methods'] = method
        headers = request.headers.get('Access-Control-Request-Headers')
        if headers:
            response.headers['Access-Control-Allow-Headers'] = headers
        response.headers['Access-Control-Max-Age'] = config.getint(
            'web', 'cache_timeout')

    def wsgi_app(self, environ, start_response):
        """WSGI entry point: check the origin, dispatch and respond."""
        request = self._build_request(environ)
        logger.info('REQUEST%s', request)

        try:
            origin = self._check_origin(request)
        except exceptions.HTTPException as e:
            # Answer with the HTTP error (403, 404, ...) instead of letting
            # the exception escape the WSGI callable, where servers typically
            # turn it into a 500 Internal Server Error.
            return e(environ, start_response)

        with processing(request):
            data = self.dispatch_request(request)
            if not isinstance(data, (Response, exceptions.HTTPException)):
                response = self.make_response(request, data)
            else:
                response = data

        if origin:
            if isinstance(response, exceptions.HTTPException):
                # Render HTTP errors so they can carry CORS headers too;
                # otherwise the browser hides the error from the caller.
                response = response.get_response(environ)
            if isinstance(response, Response):
                self._add_cors_headers(request, response, origin)
        return response(environ, start_response)

    def __call__(self, environ, start_response):
        return self.wsgi_app(environ, start_response)
class SharedDataMiddlewareIndex(SharedDataMiddleware):
    """Static-file middleware that serves ``index.html`` for directories.

    Only GET and HEAD requests are looked up in the exported directories.
    Any other method is passed straight to the wrapped application.
    """

    def __call__(self, environ, start_response):
        if environ['REQUEST_METHOD'] in {'GET', 'HEAD'}:
            return super().__call__(environ, start_response)
        return self.app(environ, start_response)

    def get_directory_loader(self, directory):
        """Return a loader resolving request paths inside *directory*.

        The loader returns ``(basename, self._opener(file))`` for an existing
        file; a directory resolves to its ``index.html``. It returns
        ``(None, None)`` when nothing is found or ``safe_join`` rejects the
        path.
        """
        def load(path):
            target = directory if path is None else safe_join(directory, path)
            if target is None:
                # safe_join gave no path (presumably an unsafe request path).
                return None, None
            if os.path.isdir(target):
                target = posixpath.join(target, 'index.html')
            if not os.path.isfile(target):
                return None, None
            return os.path.basename(target), self._opener(target)
        return load
# Module-level WSGI application (exported in __all__). The assignments below
# wrap app.wsgi_app in order, so later middlewares run first: static files
# innermost, then the proxy fix, then the configured middlewares outermost.
app = TrytondWSGI()
# Serve the files under [web] root when configured. The middleware only
# handles GET/HEAD requests and forwards every other method to the app.
if config.get('web', 'root'):
    static_files = {
        '/': config.get('web', 'root'),
        }
    app.wsgi_app = SharedDataMiddlewareIndex(
        app.wsgi_app, static_files,
        cache_timeout=config.getint('web', 'cache_timeout'))
# Behind reverse proxies, trust that many forwarded-header values.
num_proxies = config.getint('web', 'num_proxies')
if num_proxies:
    app.wsgi_app = NumProxyFix(app.wsgi_app, num_proxies)
# Extra middlewares from the [wsgi middleware] section:
# - Each option value is passed to resolve() (presumably a dotted import
#   path of a middleware factory).
# - The factory is optionally configured by args/kwargs in a matching
#   [wsgi <name>] section.
# - Middlewares are applied in the order returned by config.options().
if config.has_section('wsgi middleware'):
    for middleware in config.options('wsgi middleware'):
        Middleware = resolve(config.get('wsgi middleware', middleware))
        args, kwargs = (), {}
        section = 'wsgi %s' % middleware
        if config.has_section(section):
            # NOTE(review): eval() executes arbitrary code taken from the
            # configuration file. This is only acceptable while that file is
            # fully trusted; ast.literal_eval would be safer if literal
            # tuples/dicts are all that is needed.
            if config.has_option(section, 'args'):
                args = eval(config.get(section, 'args'))
            if config.has_option(section, 'kwargs'):
                kwargs = eval(config.get(section, 'kwargs'))
        app.wsgi_app = Middleware(app.wsgi_app, *args, **kwargs)
import trytond.bus # noqa: E402,F401
import trytond.protocols.dispatcher # noqa: E402,F401