493 lines
16 KiB
Python
Executable File
493 lines
16 KiB
Python
Executable File
# This file is part of Tryton. The COPYRIGHT file at the top level of
|
|
# this repository contains the full copyright notices and license terms.
|
|
import datetime as dt
|
|
import json
|
|
import logging
|
|
import os
|
|
import selectors
|
|
import threading
|
|
from collections import OrderedDict, defaultdict
|
|
from copy import deepcopy
|
|
from weakref import WeakKeyDictionary
|
|
|
|
from sql import Table
|
|
from sql.aggregate import Max
|
|
from sql.functions import CurrentTimestamp, Function
|
|
|
|
from trytond import backend
|
|
from trytond.config import config
|
|
from trytond.pool import Pool
|
|
from trytond.tools import grouped_slice, resolve
|
|
from trytond.transaction import Transaction
|
|
|
|
# Names exported by a star import of this module.
__all__ = [
    'BaseCache',
    'Cache',
    'LRUDict',
    'LRUDictTransaction',
    ]
# Value of the [cache] clean_timeout option (defaults to 5 minutes).
# MemoryCache compares it with the seconds elapsed since the last
# synchronisation; a value of 0 lets it use the database channel instead.
_clear_timeout = config.getint('cache', 'clean_timeout', default=5 * 60)
# Size limit used for every cache without its own entry in [cache].
_default_size_limit = config.getint('cache', 'default')
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _cast(column):
    """Wrap ``column`` in the SQL ``DATETIME`` function on SQLite.

    With any other backend ``column`` is returned untouched.
    """
    class SQLite_DateTime(Function):
        __slots__ = ()
        _function = 'DATETIME'

    return SQLite_DateTime(column) if backend.name == 'sqlite' else column
|
|
|
|
|
|
def freeze(o):
    """Return an immutable equivalent of ``o``.

    Sets, tuples and lists become tuples of frozen items, dictionaries
    become frozensets of ``(key, frozen value)`` pairs and any other object
    is returned as is.
    """
    if isinstance(o, (set, tuple, list)):
        return tuple(map(freeze, o))
    if isinstance(o, dict):
        return frozenset(
            (key, freeze(value)) for key, value in o.items())
    return o
|
|
|
|
|
|
def unfreeze(o):
    """Return a mutable equivalent of a frozen object.

    Tuples become lists of unfrozen items, frozensets of pairs become
    dictionaries with unfrozen values and any other object is returned as is.
    """
    if isinstance(o, tuple):
        return list(map(unfreeze, o))
    if isinstance(o, frozenset):
        return {key: unfreeze(value) for key, value in o}
    return o
|
|
|
|
|
|
def _get_modules(cursor):
    """Return the set of module names read from ``ir_module`` with ``cursor``.

    Only the modules whose state is 'activated', 'to upgrade' or 'to remove'
    are selected.
    """
    module = Table('ir_module')
    states = ['activated', 'to upgrade', 'to remove']
    query = module.select(module.name, where=module.state.in_(states))
    cursor.execute(*query)
    return {name for (name,) in cursor}
|
|
|
|
|
|
class BaseCache(object):
    """
    Common part of the cache implementations.

    Each instance registers itself by name in ``_instances`` and counts its
    hits and misses.  Storage and synchronisation are left to the subclasses.
    """
    # Registry of the instances, by cache name
    _instances = {}
    # Context keys always removed before the context is added to a key
    context_ignored_keys = {
        'client', '_request', '_check_access', '_skip_warnings',
        }

    def __init__(
            self, name, duration=None, context=True,
            context_ignored_keys=None):
        """
        Create the cache ``name`` and register it.

        ``duration`` is a timedelta, a number of seconds or a dictionary of
        timedelta keyword arguments; any other falsy value means no duration.
        ``context`` tells whether the transaction context is part of the
        keys and ``context_ignored_keys`` lists more context keys to ignore
        (it can only be set when ``context`` is true).
        """
        assert context_ignored_keys is None or context, (
            f"context_ignored_keys ({context_ignored_keys}) is not valid"
            f" in regards to context ({context}).")
        self._name = name
        self.size_limit = config.getint(
            'cache', name, default=_default_size_limit)
        self.context = context
        if context and context_ignored_keys:
            self.context_ignored_keys = set(context_ignored_keys)
        else:
            self.context_ignored_keys = set()
        self.hit = 0
        self.miss = 0

        if isinstance(duration, dt.timedelta):
            lifetime = duration
        elif isinstance(duration, (int, float)):
            lifetime = dt.timedelta(seconds=duration)
        elif duration:
            lifetime = dt.timedelta(**duration)
        else:
            lifetime = None
        self.duration = lifetime

        # A cache name can be registered only once
        assert self._name not in self._instances
        self._instances[self._name] = self

    @classmethod
    def stats(cls):
        "Yield a dictionary with name, hit and miss of each instance"
        for name, instance in cls._instances.items():
            yield dict(name=name, hit=instance.hit, miss=instance.miss)

    def _key(self, key):
        """
        Return the key under which ``key`` is stored.

        When the cache depends on the context, it is the pair made of ``key``
        and the frozen transaction context without the ignored keys.
        """
        if not self.context:
            return key
        context = Transaction().context.copy()
        ignored = (
            self.__class__.context_ignored_keys | self.context_ignored_keys)
        for name in ignored:
            context.pop(name, None)
        return (key, freeze(context))

    def get(self, key, default=None):
        "To be implemented by the subclass"
        raise NotImplementedError

    def set(self, key, value):
        "To be implemented by the subclass"
        raise NotImplementedError

    def clear(self):
        "To be implemented by the subclass"
        raise NotImplementedError

    @classmethod
    def clear_all(cls):
        "Call clear on every registered instance"
        for instance in cls._instances.values():
            instance.clear()

    @classmethod
    def sync(cls, transaction):
        "To be implemented by the subclass"
        raise NotImplementedError

    def sync_since(self, value):
        "To be implemented by the subclass"
        raise NotImplementedError

    @classmethod
    def commit(cls, transaction):
        "To be implemented by the subclass"
        raise NotImplementedError

    @classmethod
    def rollback(cls, transaction):
        "To be implemented by the subclass"
        raise NotImplementedError

    @classmethod
    def drop(cls, dbname):
        "To be implemented by the subclass"
        raise NotImplementedError
|
|
|
|
|
|
class MemoryCache(BaseCache):
    """
    A key value LRU cache with size limit.

    The values are stored in memory per database name.  A transaction that
    cleared the cache, or that started before the last clearing of its
    database, works on a private store until it ends.
    The clearings are propagated either through the rows of the ``_table``
    table or, when the clean timeout is 0 and the database has channels,
    through notifications on ``_channel``.
    """
    # Names of the caches cleared by each running transaction
    _reset = WeakKeyDictionary()
    # Date and time of the last synchronisation (None until the first sync)
    _clean_last = None
    # Lower bound of transaction start used for databases never cleared
    _default_lower = Transaction.monotonic_time()
    # Listener thread per (pid, database name)
    _listener = {}
    # Lock per pid protecting the changes of _listener
    _listener_lock = defaultdict(threading.Lock)
    _table = 'ir_cache'
    _channel = _table

    def __init__(self, *args, **kwargs):
        "Take the same arguments as BaseCache"
        super(MemoryCache, self).__init__(*args, **kwargs)
        # Shared store per database name
        self._database_cache = defaultdict(lambda: LRUDict(self.size_limit))
        # Private store per transaction
        self._transaction_cache = WeakKeyDictionary()
        # Per database name, monotonic time of the last clearing
        self._transaction_lower = {}
        # Per database name, timestamp of the last clearing seen
        self._timestamp = {}

    def _get_cache(self):
        """
        Return the store to use for the current transaction.

        It is a private store when the transaction cleared this cache or
        started before the last clearing for its database, otherwise it is
        the store shared for the database.
        """
        transaction = Transaction()
        dbname = transaction.database.name
        lower = self._transaction_lower.get(dbname, self._default_lower)
        if (self._name in self._reset.get(transaction, set())
                or transaction.started_at < lower):
            try:
                return self._transaction_cache[transaction]
            except KeyError:
                cache = self._database_cache.default_factory()
                self._transaction_cache[transaction] = cache
                return cache
        else:
            return self._database_cache[dbname]

    def get(self, key, default=None):
        """
        Return a deep copy of the value stored for ``key`` or ``default``.

        ``default`` is returned when there is no value, when the value is
        expired or when a TypeError is raised (presumably by an unhashable
        key).  The hit or miss counter is incremented accordingly.
        """
        key = self._key(key)
        cache = self._get_cache()
        try:
            # Pop and set again to move the entry at the end of the store,
            # an expired entry is just not stored again.
            (expire, result) = cache.pop(key)
            if expire and expire < dt.datetime.now():
                self.miss += 1
                return default
            cache[key] = (expire, result)
            self.hit += 1
            return deepcopy(result)
        except (KeyError, TypeError):
            self.miss += 1
            return default

    def set(self, key, value):
        """
        Store a deep copy of ``value`` for ``key`` and return ``value``.

        The entry expires after the duration of the cache if there is one.
        A TypeError raised while storing is ignored and nothing is stored.
        """
        key = self._key(key)
        cache = self._get_cache()
        if self.duration:
            expire = dt.datetime.now() + self.duration
        else:
            expire = None
        try:
            cache[key] = (expire, deepcopy(value))
        except TypeError:
            # Best effort: presumably the key can not be hashed
            pass
        return value

    def clear(self):
        """
        Register the clearing of the cache for the current transaction and
        drop the private store of the transaction.
        """
        transaction = Transaction()
        self._reset.setdefault(transaction, set()).add(self._name)
        self._transaction_cache.pop(transaction, None)

    def _clear(self, dbname, timestamp=None):
        """
        Replace the shared store of ``dbname`` by an empty one.

        ``timestamp`` is remembered for the database and the lower bound of
        the transactions allowed to use the shared store is raised to now.
        """
        logger.debug("clearing cache '%s' of '%s'", self._name, dbname)
        self._timestamp[dbname] = timestamp
        self._database_cache[dbname] = self._database_cache.default_factory()
        self._transaction_lower[dbname] = max(
            Transaction.monotonic_time(),
            self._transaction_lower.get(dbname, self._default_lower))

    @classmethod
    def _clear_all(cls, dbname):
        "Replace the shared store of ``dbname`` for every instance"
        for inst in cls._instances.values():
            inst._clear(dbname)

    @classmethod
    def sync(cls, transaction):
        """
        Synchronise the caches with the database of ``transaction``.

        The first call only records the time.  When the clean timeout is 0
        and the database has channels, it ensures that a listener thread runs
        for the database.  Otherwise, once the timeout is elapsed, it reads
        the timestamps stored in the table, clears the instances with an
        older timestamp and refreshes the pool.
        """
        if cls._clean_last is None:
            cls._clean_last = dt.datetime.now()
            return

        database = transaction.database
        dbname = database.name
        if not _clear_timeout and database.has_channel():
            pid = os.getpid()
            with cls._listener_lock[pid]:
                if (pid, dbname) not in cls._listener:
                    cls._listener[pid, dbname] = listener = threading.Thread(
                        target=cls._listen, args=(dbname,), daemon=True)
                    listener.start()
            return
        last_clean = (dt.datetime.now() - cls._clean_last).total_seconds()
        if last_clean < _clear_timeout:
            return
        connection = database.get_connection(readonly=True, autocommit=True)
        try:
            with connection.cursor() as cursor:
                table = Table(cls._table)
                cursor.execute(*table.select(
                        _cast(table.timestamp), table.name))
                timestamps = {}
                for timestamp, name in cursor:
                    timestamps[name] = timestamp
                modules = _get_modules(cursor)
        finally:
            database.put_connection(connection)
        for name, timestamp in timestamps.items():
            try:
                inst = cls._instances[name]
            except KeyError:
                # The table contains a cache not registered in this process
                continue
            inst_timestamp = inst._timestamp.get(dbname)
            if not inst_timestamp or timestamp > inst_timestamp:
                inst._clear(dbname, timestamp)
        Pool.refresh(dbname, modules)
        cls._clean_last = dt.datetime.now()

    def sync_since(self, value):
        "Return whether the last synchronisation happened after ``value``"
        # NOTE(review): the comparison raises TypeError as long as
        # _clean_last is None, i.e. before the first sync.
        return self._clean_last > value

    @classmethod
    def commit(cls, transaction):
        """
        Publish the clearings registered by ``transaction``.

        When the clean timeout is 0 and the database has channels, the names
        of the cleared caches are notified on the channel using the
        connection of the transaction.  Otherwise the timestamp of each
        cleared cache is updated in the table using another connection and
        the instance is cleared for the database.
        """
        table = Table(cls._table)
        reset = cls._reset.pop(transaction, None)
        if not reset:
            return
        database = transaction.database
        dbname = database.name
        if not _clear_timeout and database.has_channel():
            with transaction.connection.cursor() as cursor:
                # The count computed as
                # 8000 (max notify size) / 64 (max name data len)
                for sub_reset in grouped_slice(reset, 125):
                    cursor.execute(
                        'NOTIFY "%s", %%s' % cls._channel,
                        (json.dumps(list(sub_reset), separators=(',', ':')),))
        else:
            connection = database.get_connection(
                readonly=False, autocommit=True)
            try:
                with connection.cursor() as cursor:
                    for name in reset:
                        cursor.execute(*table.select(
                                table.name, table.id, table.timestamp,
                                where=table.name == name,
                                limit=1))
                        if cursor.fetchone():
                            # It would be better to insert only
                            cursor.execute(*table.update(
                                    [table.timestamp],
                                    [CurrentTimestamp()],
                                    where=table.name == name))
                        else:
                            cursor.execute(*table.insert(
                                    [table.timestamp, table.name],
                                    [[CurrentTimestamp(), name]]))

                        # Read the stored timestamp once, with the same cast
                        # as sync so that both values can be compared.
                        cursor.execute(*table.select(
                                _cast(Max(table.timestamp)),
                                where=table.name == name))
                        timestamp, = cursor.fetchone()

                        inst = cls._instances[name]
                        inst._clear(dbname, timestamp)
                connection.commit()
            finally:
                database.put_connection(connection)
            cls._clean_last = dt.datetime.now()
        reset.clear()

    @classmethod
    def rollback(cls, transaction):
        "Forget the clearings registered by ``transaction``"
        cls._reset.pop(transaction, None)

    @classmethod
    def drop(cls, dbname):
        """
        Stop the listener of ``dbname`` for this process, if any, and remove
        the data kept for the database by every instance.
        """
        pid = os.getpid()
        with cls._listener_lock[pid]:
            listener = cls._listener.pop((pid, dbname), None)
        if listener:
            database = backend.Database(dbname)
            conn = database.get_connection()
            try:
                cursor = conn.cursor()
                # The listener is no more registered, a notification without
                # payload wakes it up so that it leaves its loop.
                cursor.execute('NOTIFY "%s"' % cls._channel)
                conn.commit()
            finally:
                database.put_connection(conn)
            listener.join()
        for inst in cls._instances.values():
            inst._timestamp.pop(dbname, None)
            inst._database_cache.pop(dbname, None)
            inst._transaction_lower.pop(dbname, None)

    @classmethod
    def refresh_pool(cls, transaction):
        """
        Notify 'refresh pool' on the channel of the database of
        ``transaction`` when the clean timeout is 0 and the database has
        channels, otherwise do nothing.
        """
        database = transaction.database
        dbname = database.name
        if not _clear_timeout and database.has_channel():
            database = backend.Database(dbname)
            conn = database.get_connection()
            try:
                cursor = conn.cursor()
                cursor.execute(
                    'NOTIFY "%s", %%s' % cls._channel, ('refresh pool',))
                conn.commit()
            finally:
                database.put_connection(conn)

    @classmethod
    def _listen(cls, dbname):
        """
        Listen on the channel of ``dbname`` until the current thread is no
        more the listener registered for the database.

        A 'refresh pool' payload refreshes the pool, any other non empty
        payload is the JSON list of the names of the caches to clear.
        An exception is logged and raised again; in any case the connection
        is released and the thread is unregistered.
        """
        current_thread = threading.current_thread()
        pid = os.getpid()

        conn, selector = None, None
        try:
            database = backend.Database(dbname)
            if not database.has_channel():
                raise NotImplementedError

            logger.info(
                "listening on channel '%s' of '%s'", cls._channel, dbname)
            conn = database.get_connection(autocommit=True)
            selector = selectors.DefaultSelector()

            cursor = conn.cursor()
            cursor.execute('LISTEN "%s"' % cls._channel)
            # Clear everything in case we missed a payload
            Pool.refresh(dbname, _get_modules(cursor))
            cls._clear_all(dbname)
            current_thread.listening = True

            selector.register(conn, selectors.EVENT_READ)
            while cls._listener.get((pid, dbname)) == current_thread:
                # Wait at most 60 seconds for data on the connection
                selector.select(timeout=60)
                conn.poll()
                while conn.notifies:
                    notification = conn.notifies.pop()
                    if notification.payload == 'refresh pool':
                        Pool.refresh(dbname, _get_modules(cursor))
                    elif notification.payload:
                        reset = json.loads(notification.payload)
                        for name in reset:
                            inst = cls._instances[name]
                            inst._clear(dbname)
                cls._clean_last = dt.datetime.now()
        except Exception:
            logger.error(
                "cache listener on '%s' crashed", dbname, exc_info=True)
            raise
        finally:
            if selector:
                selector.close()
            if conn:
                database.put_connection(conn)
            with cls._listener_lock[pid]:
                if cls._listener.get((pid, dbname)) == current_thread:
                    del cls._listener[pid, dbname]
|
|
|
|
|
|
# Use the implementation named by the [cache] class option when it is set,
# otherwise the in-memory implementation.
Cache = (
    resolve(config.get('cache', 'class'))
    if config.get('cache', 'class')
    else MemoryCache)
|
|
|
|
|
|
class LRUDict(OrderedDict):
    """
    Ordered dictionary holding at most ``size_limit`` items.

    Each time the limit is exceeded, the items at the beginning of the order
    (the first added) are removed.
    A ``default_factory`` can be set, it behaves like the one of the standard
    collections.defaultdict except that it is called with the missing key
    when ``default_factory_with_key`` is set.
    """
    __slots__ = ('size_limit',)

    def __init__(self, size_limit,
            default_factory=None, default_factory_with_key=False,
            *args, **kwargs):
        "The extra arguments are the initial content, like for OrderedDict"
        assert size_limit > 0
        # The limit must be set first as filling the content may check it
        self.size_limit = size_limit
        super().__init__(*args, **kwargs)
        self.default_factory = default_factory
        self.default_factory_with_key = default_factory_with_key
        self._check_size_limit()

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self._check_size_limit()

    def __missing__(self, key):
        "Store and return the value of the factory, raise KeyError without"
        factory = self.default_factory
        if factory is None:
            raise KeyError(key)
        value = factory(key) if self.default_factory_with_key else factory()
        self[key] = value
        return value

    def update(self, *args, **kwargs):
        super().update(*args, **kwargs)
        self._check_size_limit()

    def setdefault(self, key, default=None):
        value = super().setdefault(key, default=default)
        self._check_size_limit()
        return value

    def _check_size_limit(self):
        "Remove the first items until the size limit is respected"
        while self.size_limit < len(self):
            self.popitem(last=False)
|
|
|
|
|
|
class LRUDictTransaction(LRUDict):
    """
    LRUDict bound to the transaction in which it is created.

    It remembers the counter of the transaction and ``refresh`` empties it
    once the counter has changed.
    """
    __slots__ = ('transaction', 'counter')

    def __init__(self, *args, **kwargs):
        "Take the same arguments as LRUDict"
        super().__init__(*args, **kwargs)
        self.transaction = Transaction()
        self.counter = self.transaction.counter

    def clear(self):
        "Empty the dictionary and remember the current counter"
        super().clear()
        self.counter = self.transaction.counter

    def refresh(self):
        "Clear the dictionary if the counter of the transaction has changed"
        if self.transaction.counter != self.counter:
            self.clear()
|