Initial import from Docker volume
This commit is contained in:
269
bus.py
Executable file
269
bus.py
Executable file
@@ -0,0 +1,269 @@
|
||||
# This file is part of Tryton. The COPYRIGHT file at the top level of
|
||||
# this repository contains the full copyright notices and license terms.
|
||||
|
||||
import collections
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import selectors
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from trytond import backend
|
||||
from trytond.config import config
|
||||
from trytond.protocols.jsonrpc import JSONDecoder, JSONEncoder
|
||||
from trytond.protocols.wrappers import (
|
||||
HTTPStatus, Response, exceptions, redirect)
|
||||
from trytond.tools import resolve
|
||||
from trytond.transaction import Transaction
|
||||
from trytond.wsgi import app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_db_timeout = config.getint('database', 'timeout')
|
||||
_cache_timeout = config.getint('bus', 'cache_timeout')
|
||||
_select_timeout = config.getint('bus', 'select_timeout')
|
||||
_long_polling_timeout = config.getint('bus', 'long_polling_timeout')
|
||||
_allow_subscribe = config.getboolean('bus', 'allow_subscribe')
|
||||
_url_host = config.get('bus', 'url_host')
|
||||
_web_cache_timeout = config.getint('web', 'cache_timeout')
|
||||
|
||||
|
||||
class _MessageQueue:
|
||||
|
||||
Message = collections.namedtuple('Message', 'channel content timestamp')
|
||||
|
||||
def __init__(self, timeout):
|
||||
super().__init__()
|
||||
self._lock = collections.defaultdict(threading.Lock)
|
||||
self._timeout = timeout
|
||||
self._messages = []
|
||||
|
||||
def append(self, channel, element):
|
||||
self._messages.append(
|
||||
self.Message(channel, element, time.time()))
|
||||
|
||||
def get_next(self, channels, from_id=None):
|
||||
oldest = time.time() - self._timeout
|
||||
to_delete_index = 0
|
||||
found = False
|
||||
first_message = None
|
||||
message = self.Message(None, None, None)
|
||||
for idx, item in enumerate(self._messages):
|
||||
if item.timestamp < oldest:
|
||||
to_delete_index = idx
|
||||
continue
|
||||
if item.channel not in channels:
|
||||
continue
|
||||
if not first_message:
|
||||
first_message = item
|
||||
if from_id is None or found:
|
||||
message = item
|
||||
break
|
||||
found = item.content['message_id'] == from_id
|
||||
else:
|
||||
if first_message and not found:
|
||||
message = first_message
|
||||
|
||||
with self._lock[os.getpid()]:
|
||||
del self._messages[:to_delete_index]
|
||||
|
||||
return message.channel, message.content
|
||||
|
||||
|
||||
class LongPollingBus:
|
||||
|
||||
_channel = 'bus'
|
||||
_queues_lock = collections.defaultdict(threading.Lock)
|
||||
_queues = collections.defaultdict(
|
||||
lambda: {'timeout': None, 'events': collections.defaultdict(list)})
|
||||
_messages = {}
|
||||
|
||||
@classmethod
|
||||
def subscribe(cls, database, channels, last_message=None):
|
||||
pid = os.getpid()
|
||||
with cls._queues_lock[pid]:
|
||||
start_listener = (pid, database) not in cls._queues
|
||||
cls._queues[pid, database]['timeout'] = time.time() + _db_timeout
|
||||
if start_listener:
|
||||
listener = threading.Thread(
|
||||
target=cls._listen, args=(database,), daemon=True)
|
||||
cls._queues[pid, database]['listener'] = listener
|
||||
listener.start()
|
||||
|
||||
messages = cls._messages.get(database)
|
||||
if messages:
|
||||
channel, content = messages.get_next(channels, last_message)
|
||||
if content:
|
||||
return cls.create_response(channel, content)
|
||||
|
||||
event = threading.Event()
|
||||
for channel in channels:
|
||||
if channel in cls._queues[pid, database]['events']:
|
||||
event_channel = cls._queues[pid, database]['events'][channel]
|
||||
else:
|
||||
with cls._queues_lock[pid]:
|
||||
event_channel = cls._queues[pid, database][
|
||||
'events'][channel]
|
||||
event_channel.append(event)
|
||||
|
||||
triggered = event.wait(_long_polling_timeout)
|
||||
if not triggered:
|
||||
response = cls.create_response(None, None)
|
||||
else:
|
||||
response = cls.create_response(
|
||||
*cls._messages[database].get_next(channels, last_message))
|
||||
|
||||
with cls._queues_lock[pid]:
|
||||
for channel in channels:
|
||||
events = cls._queues[pid, database]['events'][channel]
|
||||
for e in events[:]:
|
||||
if e.is_set():
|
||||
events.remove(e)
|
||||
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def create_response(cls, channel, message):
|
||||
response_data = {
|
||||
'message': message,
|
||||
'channel': channel,
|
||||
}
|
||||
logger.debug('Bus: %s', response_data)
|
||||
return response_data
|
||||
|
||||
@classmethod
|
||||
def _listen(cls, database):
|
||||
db = backend.Database(database)
|
||||
if not db.has_channel():
|
||||
raise exceptions.NotImplemented
|
||||
|
||||
logger.info("listening on channel '%s'", cls._channel)
|
||||
conn = db.get_connection(autocommit=True)
|
||||
pid = os.getpid()
|
||||
selector = selectors.DefaultSelector()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('LISTEN "%s"' % cls._channel)
|
||||
|
||||
cls._messages[database] = messages = _MessageQueue(_cache_timeout)
|
||||
|
||||
now = time.time()
|
||||
selector.register(conn, selectors.EVENT_READ)
|
||||
while cls._queues[pid, database]['timeout'] > now:
|
||||
selector.select(timeout=_select_timeout)
|
||||
conn.poll()
|
||||
while conn.notifies:
|
||||
notification = conn.notifies.pop()
|
||||
payload = json.loads(
|
||||
notification.payload,
|
||||
object_hook=JSONDecoder())
|
||||
channel = payload['channel']
|
||||
message = payload['message']
|
||||
messages.append(channel, message)
|
||||
|
||||
with cls._queues_lock[pid]:
|
||||
events = cls._queues[pid, database][
|
||||
'events'][channel].copy()
|
||||
cls._queues[pid, database]['events'][channel].clear()
|
||||
for event in events:
|
||||
event.set()
|
||||
now = time.time()
|
||||
except Exception:
|
||||
logger.error('bus listener on "%s" crashed', database,
|
||||
exc_info=True)
|
||||
|
||||
with cls._queues_lock[pid]:
|
||||
del cls._queues[pid, database]
|
||||
raise
|
||||
finally:
|
||||
selector.close()
|
||||
db.put_connection(conn)
|
||||
|
||||
with cls._queues_lock[pid]:
|
||||
if cls._queues[pid, database]['timeout'] <= now:
|
||||
del cls._queues[pid, database]
|
||||
else:
|
||||
# A query arrived between the end of the while and here
|
||||
listener = threading.Thread(
|
||||
target=cls._listen, args=(database,), daemon=True)
|
||||
cls._queues[pid, database]['listener'] = listener
|
||||
listener.start()
|
||||
|
||||
@classmethod
|
||||
def publish(cls, channel, message):
|
||||
transaction = Transaction()
|
||||
if not transaction.database.has_channel():
|
||||
logger.debug('Database backend do not support channels')
|
||||
return
|
||||
|
||||
cursor = transaction.connection.cursor()
|
||||
message['message_id'] = str(uuid.uuid4())
|
||||
payload = json.dumps({
|
||||
'channel': channel,
|
||||
'message': message,
|
||||
}, cls=JSONEncoder, separators=(',', ':'))
|
||||
cursor.execute('NOTIFY "%s", %%s' % cls._channel, (payload,))
|
||||
|
||||
|
||||
if config.get('bus', 'class'):
|
||||
Bus = resolve(config.get('bus', 'class'))
|
||||
else:
|
||||
Bus = LongPollingBus
|
||||
|
||||
|
||||
@app.route('/<string:database_name>/bus', methods=['POST'])
|
||||
@app.auth_required
|
||||
def subscribe(request, database_name):
|
||||
if not _allow_subscribe:
|
||||
raise exceptions.NotImplemented
|
||||
if _url_host and _url_host != request.host_url:
|
||||
response = redirect(
|
||||
urljoin(_url_host, request.path), HTTPStatus.PERMANENT_REDIRECT)
|
||||
# Allow to change the redirection after some time
|
||||
response.headers['Cache-Control'] = (
|
||||
'private, max-age=%s' % _web_cache_timeout)
|
||||
return response
|
||||
user = request.authorization.get('userid')
|
||||
channels = request.parsed_data.get('channels', [])
|
||||
if user is None:
|
||||
raise exceptions.BadRequest
|
||||
|
||||
channels = set(filter(lambda c: not c.startswith('user:'), channels))
|
||||
channels.add('user:%s' % user)
|
||||
|
||||
last_message = request.parsed_data.get('last_message')
|
||||
|
||||
logger.debug(
|
||||
"getting bus messages from %s@%s%s for %s since %s",
|
||||
request.authorization.username, request.remote_addr, request.path,
|
||||
channels, last_message)
|
||||
bus_response = Bus.subscribe(database_name, channels, last_message)
|
||||
return Response(
|
||||
json.dumps(bus_response, cls=JSONEncoder, separators=(',', ':')),
|
||||
content_type='application/json')
|
||||
|
||||
|
||||
def notify(title, body=None, priority=1, user=None, client=None):
|
||||
if user is None:
|
||||
if client is None:
|
||||
context_client = Transaction().context.get('client')
|
||||
if context_client:
|
||||
channel = 'client:%s' % context_client
|
||||
else:
|
||||
return
|
||||
else:
|
||||
channel = 'client:%s' % client
|
||||
elif client is None:
|
||||
channel = 'user:%s' % user
|
||||
else:
|
||||
channel = 'client:%s' % client
|
||||
|
||||
return Bus.publish(channel, {
|
||||
'type': 'notification',
|
||||
'title': title,
|
||||
'body': body,
|
||||
'priority': priority,
|
||||
})
|
||||
Reference in New Issue
Block a user