Initial import from Docker volume
This commit is contained in:
492
cache.py
Executable file
492
cache.py
Executable file
@@ -0,0 +1,492 @@
|
||||
# This file is part of Tryton. The COPYRIGHT file at the top level of
|
||||
# this repository contains the full copyright notices and license terms.
|
||||
import datetime as dt
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import selectors
|
||||
import threading
|
||||
from collections import OrderedDict, defaultdict
|
||||
from copy import deepcopy
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
from sql import Table
|
||||
from sql.aggregate import Max
|
||||
from sql.functions import CurrentTimestamp, Function
|
||||
|
||||
from trytond import backend
|
||||
from trytond.config import config
|
||||
from trytond.pool import Pool
|
||||
from trytond.tools import grouped_slice, resolve
|
||||
from trytond.transaction import Transaction
|
||||
|
||||
__all__ = ['BaseCache', 'Cache', 'LRUDict', 'LRUDictTransaction']
|
||||
_clear_timeout = config.getint('cache', 'clean_timeout', default=5 * 60)
|
||||
_default_size_limit = config.getint('cache', 'default')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cast(column):
|
||||
class SQLite_DateTime(Function):
|
||||
__slots__ = ()
|
||||
_function = 'DATETIME'
|
||||
|
||||
if backend.name == 'sqlite':
|
||||
column = SQLite_DateTime(column)
|
||||
return column
|
||||
|
||||
|
||||
def freeze(o):
|
||||
if isinstance(o, (set, tuple, list)):
|
||||
return tuple(freeze(x) for x in o)
|
||||
elif isinstance(o, dict):
|
||||
return frozenset((x, freeze(y)) for x, y in o.items())
|
||||
else:
|
||||
return o
|
||||
|
||||
|
||||
def unfreeze(o):
|
||||
if isinstance(o, tuple):
|
||||
return [unfreeze(x) for x in o]
|
||||
elif isinstance(o, frozenset):
|
||||
return dict((x, unfreeze(y)) for x, y in o)
|
||||
else:
|
||||
return o
|
||||
|
||||
|
||||
def _get_modules(cursor):
|
||||
ir_module = Table('ir_module')
|
||||
cursor.execute(*ir_module.select(
|
||||
ir_module.name,
|
||||
where=ir_module.state.in_(
|
||||
['activated', 'to upgrade', 'to remove'])))
|
||||
return {m for m, in cursor}
|
||||
|
||||
|
||||
class BaseCache(object):
|
||||
_instances = {}
|
||||
context_ignored_keys = {
|
||||
'client', '_request', '_check_access', '_skip_warnings',
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, name, duration=None, context=True,
|
||||
context_ignored_keys=None):
|
||||
assert ((context_ignored_keys is not None and context)
|
||||
or (context_ignored_keys is None)), (
|
||||
f"context_ignored_keys ({context_ignored_keys}) is not valid"
|
||||
f" in regards to context ({context}).")
|
||||
self._name = name
|
||||
self.size_limit = config.getint(
|
||||
'cache', name, default=_default_size_limit)
|
||||
self.context = context
|
||||
self.context_ignored_keys = set()
|
||||
if context and context_ignored_keys:
|
||||
self.context_ignored_keys.update(context_ignored_keys)
|
||||
self.hit = self.miss = 0
|
||||
if isinstance(duration, dt.timedelta):
|
||||
self.duration = duration
|
||||
elif isinstance(duration, (int, float)):
|
||||
self.duration = dt.timedelta(seconds=duration)
|
||||
elif duration:
|
||||
self.duration = dt.timedelta(**duration)
|
||||
else:
|
||||
self.duration = None
|
||||
assert self._name not in self._instances
|
||||
self._instances[self._name] = self
|
||||
|
||||
@classmethod
|
||||
def stats(cls):
|
||||
for name, inst in cls._instances.items():
|
||||
yield {
|
||||
'name': name,
|
||||
'hit': inst.hit,
|
||||
'miss': inst.miss,
|
||||
}
|
||||
|
||||
def _key(self, key):
|
||||
if self.context:
|
||||
context = Transaction().context.copy()
|
||||
for k in (self.__class__.context_ignored_keys
|
||||
| self.context_ignored_keys):
|
||||
context.pop(k, None)
|
||||
return (key, freeze(context))
|
||||
return key
|
||||
|
||||
def get(self, key, default=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def set(self, key, value):
|
||||
raise NotImplementedError
|
||||
|
||||
def clear(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def clear_all(cls):
|
||||
for inst in cls._instances.values():
|
||||
inst.clear()
|
||||
|
||||
@classmethod
|
||||
def sync(cls, transaction):
|
||||
raise NotImplementedError
|
||||
|
||||
def sync_since(self, value):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def commit(cls, transaction):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def rollback(cls, transaction):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def drop(cls, dbname):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MemoryCache(BaseCache):
|
||||
"""
|
||||
A key value LRU cache with size limit.
|
||||
"""
|
||||
_reset = WeakKeyDictionary()
|
||||
_clean_last = None
|
||||
_default_lower = Transaction.monotonic_time()
|
||||
_listener = {}
|
||||
_listener_lock = defaultdict(threading.Lock)
|
||||
_table = 'ir_cache'
|
||||
_channel = _table
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MemoryCache, self).__init__(*args, **kwargs)
|
||||
self._database_cache = defaultdict(lambda: LRUDict(self.size_limit))
|
||||
self._transaction_cache = WeakKeyDictionary()
|
||||
self._transaction_lower = {}
|
||||
self._timestamp = {}
|
||||
|
||||
def _get_cache(self):
|
||||
transaction = Transaction()
|
||||
dbname = transaction.database.name
|
||||
lower = self._transaction_lower.get(dbname, self._default_lower)
|
||||
if (self._name in self._reset.get(transaction, set())
|
||||
or transaction.started_at < lower):
|
||||
try:
|
||||
return self._transaction_cache[transaction]
|
||||
except KeyError:
|
||||
cache = self._database_cache.default_factory()
|
||||
self._transaction_cache[transaction] = cache
|
||||
return cache
|
||||
else:
|
||||
return self._database_cache[dbname]
|
||||
|
||||
def get(self, key, default=None):
|
||||
key = self._key(key)
|
||||
cache = self._get_cache()
|
||||
try:
|
||||
(expire, result) = cache.pop(key)
|
||||
if expire and expire < dt.datetime.now():
|
||||
self.miss += 1
|
||||
return default
|
||||
cache[key] = (expire, result)
|
||||
self.hit += 1
|
||||
return deepcopy(result)
|
||||
except (KeyError, TypeError):
|
||||
self.miss += 1
|
||||
return default
|
||||
|
||||
def set(self, key, value):
|
||||
key = self._key(key)
|
||||
cache = self._get_cache()
|
||||
if self.duration:
|
||||
expire = dt.datetime.now() + self.duration
|
||||
else:
|
||||
expire = None
|
||||
try:
|
||||
cache[key] = (expire, deepcopy(value))
|
||||
except TypeError:
|
||||
pass
|
||||
return value
|
||||
|
||||
def clear(self):
|
||||
transaction = Transaction()
|
||||
self._reset.setdefault(transaction, set()).add(self._name)
|
||||
self._transaction_cache.pop(transaction, None)
|
||||
|
||||
def _clear(self, dbname, timestamp=None):
|
||||
logger.debug("clearing cache '%s' of '%s'", self._name, dbname)
|
||||
self._timestamp[dbname] = timestamp
|
||||
self._database_cache[dbname] = self._database_cache.default_factory()
|
||||
self._transaction_lower[dbname] = max(
|
||||
Transaction.monotonic_time(),
|
||||
self._transaction_lower.get(dbname, self._default_lower))
|
||||
|
||||
@classmethod
|
||||
def _clear_all(cls, dbname):
|
||||
for inst in cls._instances.values():
|
||||
inst._clear(dbname)
|
||||
|
||||
@classmethod
|
||||
def sync(cls, transaction):
|
||||
if cls._clean_last is None:
|
||||
cls._clean_last = dt.datetime.now()
|
||||
return
|
||||
|
||||
database = transaction.database
|
||||
dbname = database.name
|
||||
if not _clear_timeout and database.has_channel():
|
||||
pid = os.getpid()
|
||||
with cls._listener_lock[pid]:
|
||||
if (pid, dbname) not in cls._listener:
|
||||
cls._listener[pid, dbname] = listener = threading.Thread(
|
||||
target=cls._listen, args=(dbname,), daemon=True)
|
||||
listener.start()
|
||||
return
|
||||
last_clean = (dt.datetime.now() - cls._clean_last).total_seconds()
|
||||
if last_clean < _clear_timeout:
|
||||
return
|
||||
connection = database.get_connection(readonly=True, autocommit=True)
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
table = Table(cls._table)
|
||||
cursor.execute(*table.select(
|
||||
_cast(table.timestamp), table.name))
|
||||
timestamps = {}
|
||||
for timestamp, name in cursor:
|
||||
timestamps[name] = timestamp
|
||||
modules = _get_modules(cursor)
|
||||
finally:
|
||||
database.put_connection(connection)
|
||||
for name, timestamp in timestamps.items():
|
||||
try:
|
||||
inst = cls._instances[name]
|
||||
except KeyError:
|
||||
continue
|
||||
inst_timestamp = inst._timestamp.get(dbname)
|
||||
if not inst_timestamp or timestamp > inst_timestamp:
|
||||
inst._clear(dbname, timestamp)
|
||||
Pool.refresh(dbname, modules)
|
||||
cls._clean_last = dt.datetime.now()
|
||||
|
||||
def sync_since(self, value):
|
||||
return self._clean_last > value
|
||||
|
||||
@classmethod
|
||||
def commit(cls, transaction):
|
||||
table = Table(cls._table)
|
||||
reset = cls._reset.pop(transaction, None)
|
||||
if not reset:
|
||||
return
|
||||
database = transaction.database
|
||||
dbname = database.name
|
||||
if not _clear_timeout and transaction.database.has_channel():
|
||||
with transaction.connection.cursor() as cursor:
|
||||
# The count computed as
|
||||
# 8000 (max notify size) / 64 (max name data len)
|
||||
for sub_reset in grouped_slice(reset, 125):
|
||||
cursor.execute(
|
||||
'NOTIFY "%s", %%s' % cls._channel,
|
||||
(json.dumps(list(sub_reset), separators=(',', ':')),))
|
||||
else:
|
||||
connection = database.get_connection(
|
||||
readonly=False, autocommit=True)
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
for name in reset:
|
||||
cursor.execute(*table.select(table.name, table.id,
|
||||
table.timestamp,
|
||||
where=table.name == name,
|
||||
limit=1))
|
||||
if cursor.fetchone():
|
||||
# It would be better to insert only
|
||||
cursor.execute(*table.update([table.timestamp],
|
||||
[CurrentTimestamp()],
|
||||
where=table.name == name))
|
||||
else:
|
||||
cursor.execute(*table.insert(
|
||||
[table.timestamp, table.name],
|
||||
[[CurrentTimestamp(), name]]))
|
||||
|
||||
cursor.execute(*table.select(
|
||||
Max(table.timestamp),
|
||||
where=table.name == name))
|
||||
timestamp, = cursor.fetchone()
|
||||
|
||||
cursor.execute(*table.select(
|
||||
_cast(Max(table.timestamp)),
|
||||
where=table.name == name))
|
||||
timestamp, = cursor.fetchone()
|
||||
|
||||
inst = cls._instances[name]
|
||||
inst._clear(dbname, timestamp)
|
||||
connection.commit()
|
||||
finally:
|
||||
database.put_connection(connection)
|
||||
cls._clean_last = dt.datetime.now()
|
||||
reset.clear()
|
||||
|
||||
@classmethod
|
||||
def rollback(cls, transaction):
|
||||
cls._reset.pop(transaction, None)
|
||||
|
||||
@classmethod
|
||||
def drop(cls, dbname):
|
||||
pid = os.getpid()
|
||||
with cls._listener_lock[pid]:
|
||||
listener = cls._listener.pop((pid, dbname), None)
|
||||
if listener:
|
||||
database = backend.Database(dbname)
|
||||
conn = database.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('NOTIFY "%s"' % cls._channel)
|
||||
conn.commit()
|
||||
finally:
|
||||
database.put_connection(conn)
|
||||
listener.join()
|
||||
for inst in cls._instances.values():
|
||||
inst._timestamp.pop(dbname, None)
|
||||
inst._database_cache.pop(dbname, None)
|
||||
inst._transaction_lower.pop(dbname, None)
|
||||
|
||||
@classmethod
|
||||
def refresh_pool(cls, transaction):
|
||||
database = transaction.database
|
||||
dbname = database.name
|
||||
if not _clear_timeout and database.has_channel():
|
||||
database = backend.Database(dbname)
|
||||
conn = database.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
'NOTIFY "%s", %%s' % cls._channel, ('refresh pool',))
|
||||
conn.commit()
|
||||
finally:
|
||||
database.put_connection(conn)
|
||||
|
||||
@classmethod
|
||||
def _listen(cls, dbname):
|
||||
current_thread = threading.current_thread()
|
||||
pid = os.getpid()
|
||||
|
||||
conn, selector = None, None
|
||||
try:
|
||||
database = backend.Database(dbname)
|
||||
if not database.has_channel():
|
||||
raise NotImplementedError
|
||||
|
||||
logger.info(
|
||||
"listening on channel '%s' of '%s'", cls._channel, dbname)
|
||||
conn = database.get_connection(autocommit=True)
|
||||
selector = selectors.DefaultSelector()
|
||||
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('LISTEN "%s"' % cls._channel)
|
||||
# Clear everything in case we missed a payload
|
||||
Pool.refresh(dbname, _get_modules(cursor))
|
||||
cls._clear_all(dbname)
|
||||
current_thread.listening = True
|
||||
|
||||
selector.register(conn, selectors.EVENT_READ)
|
||||
while cls._listener.get((pid, dbname)) == current_thread:
|
||||
selector.select(timeout=60)
|
||||
conn.poll()
|
||||
while conn.notifies:
|
||||
notification = conn.notifies.pop()
|
||||
if notification.payload == 'refresh pool':
|
||||
Pool.refresh(dbname, _get_modules(cursor))
|
||||
elif notification.payload:
|
||||
reset = json.loads(notification.payload)
|
||||
for name in reset:
|
||||
inst = cls._instances[name]
|
||||
inst._clear(dbname)
|
||||
cls._clean_last = dt.datetime.now()
|
||||
except Exception:
|
||||
logger.error(
|
||||
"cache listener on '%s' crashed", dbname, exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
if selector:
|
||||
selector.close()
|
||||
if conn:
|
||||
database.put_connection(conn)
|
||||
with cls._listener_lock[pid]:
|
||||
if cls._listener.get((pid, dbname)) == current_thread:
|
||||
del cls._listener[pid, dbname]
|
||||
|
||||
|
||||
if config.get('cache', 'class'):
|
||||
Cache = resolve(config.get('cache', 'class'))
|
||||
else:
|
||||
Cache = MemoryCache
|
||||
|
||||
|
||||
class LRUDict(OrderedDict):
|
||||
"""
|
||||
Dictionary with a size limit.
|
||||
If size limit is reached, it will remove the first added items.
|
||||
The default_factory provides the same behavior as in standard
|
||||
collections.defaultdict.
|
||||
If default_factory_with_key is set, the default_factory is called with the
|
||||
missing key.
|
||||
"""
|
||||
__slots__ = ('size_limit',)
|
||||
|
||||
def __init__(self, size_limit,
|
||||
default_factory=None, default_factory_with_key=False,
|
||||
*args, **kwargs):
|
||||
assert size_limit > 0
|
||||
self.size_limit = size_limit
|
||||
super(LRUDict, self).__init__(*args, **kwargs)
|
||||
self.default_factory = default_factory
|
||||
self.default_factory_with_key = default_factory_with_key
|
||||
self._check_size_limit()
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
super(LRUDict, self).__setitem__(key, value)
|
||||
self._check_size_limit()
|
||||
|
||||
def __missing__(self, key):
|
||||
if self.default_factory is None:
|
||||
raise KeyError(key)
|
||||
if self.default_factory_with_key:
|
||||
value = self.default_factory(key)
|
||||
else:
|
||||
value = self.default_factory()
|
||||
self[key] = value
|
||||
return value
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
super(LRUDict, self).update(*args, **kwargs)
|
||||
self._check_size_limit()
|
||||
|
||||
def setdefault(self, key, default=None):
|
||||
default = super(LRUDict, self).setdefault(key, default=default)
|
||||
self._check_size_limit()
|
||||
return default
|
||||
|
||||
def _check_size_limit(self):
|
||||
while len(self) > self.size_limit:
|
||||
self.popitem(last=False)
|
||||
|
||||
|
||||
class LRUDictTransaction(LRUDict):
|
||||
"""
|
||||
Dictionary with a size limit and default_factory. (see LRUDict)
|
||||
It is refreshed when transaction counter is changed.
|
||||
"""
|
||||
__slots__ = ('transaction', 'counter')
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(LRUDictTransaction, self).__init__(*args, **kwargs)
|
||||
self.transaction = Transaction()
|
||||
self.counter = self.transaction.counter
|
||||
|
||||
def clear(self):
|
||||
super(LRUDictTransaction, self).clear()
|
||||
self.counter = self.transaction.counter
|
||||
|
||||
def refresh(self):
|
||||
if self.counter != self.transaction.counter:
|
||||
self.clear()
|
||||
Reference in New Issue
Block a user