Files
trytond/cache.py
2025-12-26 13:11:43 +00:00

493 lines
16 KiB
Python
Executable File

# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import datetime as dt
import json
import logging
import os
import selectors
import threading
from collections import OrderedDict, defaultdict
from copy import deepcopy
from weakref import WeakKeyDictionary
from sql import Table
from sql.aggregate import Max
from sql.functions import CurrentTimestamp, Function
from trytond import backend
from trytond.config import config
from trytond.pool import Pool
from trytond.tools import grouped_slice, resolve
from trytond.transaction import Transaction
__all__ = ['BaseCache', 'Cache', 'LRUDict', 'LRUDictTransaction']
# Seconds between polls of the ir_cache table when no notification
# channel is available (0 disables polling in favor of the channel)
_clear_timeout = config.getint('cache', 'clean_timeout', default=5 * 60)
# Global default for the maximum number of entries per cache
_default_size_limit = config.getint('cache', 'default')
logger = logging.getLogger(__name__)
def _cast(column):
    """On the SQLite backend, wrap *column* in the DATETIME() SQL
    function; every other backend gets the column back unchanged."""
    class SQLite_DateTime(Function):
        __slots__ = ()
        _function = 'DATETIME'

    if backend.name == 'sqlite':
        return SQLite_DateTime(column)
    return column
def freeze(o):
    """Return a hashable equivalent of *o*.

    Dicts become frozensets of (key, frozen value) pairs; sets, tuples
    and lists become tuples of frozen items; anything else is returned
    as-is.
    """
    if isinstance(o, dict):
        return frozenset((k, freeze(v)) for k, v in o.items())
    if isinstance(o, (set, tuple, list)):
        return tuple(freeze(item) for item in o)
    return o
def unfreeze(o):
    """Invert freeze(): tuples become lists, frozensets become dicts,
    everything else is returned unchanged."""
    if isinstance(o, frozenset):
        return {k: unfreeze(v) for k, v in o}
    if isinstance(o, tuple):
        return [unfreeze(item) for item in o]
    return o
def _get_modules(cursor):
    """Return the set of module names that are active or pending a
    state change ('activated', 'to upgrade' or 'to remove')."""
    ir_module = Table('ir_module')
    query = ir_module.select(
        ir_module.name,
        where=ir_module.state.in_(
            ['activated', 'to upgrade', 'to remove']))
    cursor.execute(*query)
    return {name for name, in cursor}
class BaseCache(object):
    """Abstract base class for named caches.

    Every instance registers itself under a unique name in the shared
    ``_instances`` registry; concrete subclasses provide the storage and
    invalidation behavior.
    """
    # name -> cache instance, shared by all subclasses
    _instances = {}
    # Context keys that never take part in key computation
    context_ignored_keys = {
        'client', '_request', '_check_access', '_skip_warnings',
        }

    def __init__(
            self, name, duration=None, context=True,
            context_ignored_keys=None):
        # Ignoring context keys only makes sense when keys use context
        assert ((context_ignored_keys is not None and context)
            or (context_ignored_keys is None)), (
            f"context_ignored_keys ({context_ignored_keys}) is not valid"
            f" in regards to context ({context}).")
        self._name = name
        # Per-cache size limit, falling back to the global default
        self.size_limit = config.getint(
            'cache', name, default=_default_size_limit)
        self.context = context
        self.context_ignored_keys = set()
        if context and context_ignored_keys:
            self.context_ignored_keys.update(context_ignored_keys)
        self.hit = self.miss = 0
        # Normalize duration into a timedelta, or None for "no expiry"
        if isinstance(duration, dt.timedelta):
            expiry = duration
        elif isinstance(duration, (int, float)):
            expiry = dt.timedelta(seconds=duration)
        elif duration:
            expiry = dt.timedelta(**duration)
        else:
            expiry = None
        self.duration = expiry
        assert self._name not in self._instances
        self._instances[self._name] = self

    @classmethod
    def stats(cls):
        """Yield hit/miss counters for every registered cache."""
        for name, instance in cls._instances.items():
            yield {
                'name': name,
                'hit': instance.hit,
                'miss': instance.miss,
                }

    def _key(self, key):
        """Return *key*, augmented with the frozen transaction context
        when this cache is context dependent."""
        if not self.context:
            return key
        context = Transaction().context.copy()
        ignored = (
            self.__class__.context_ignored_keys | self.context_ignored_keys)
        for ignored_key in ignored:
            context.pop(ignored_key, None)
        return (key, freeze(context))

    def get(self, key, default=None):
        raise NotImplementedError

    def set(self, key, value):
        raise NotImplementedError

    def clear(self):
        raise NotImplementedError

    @classmethod
    def clear_all(cls):
        """Clear every registered cache instance."""
        for instance in cls._instances.values():
            instance.clear()

    @classmethod
    def sync(cls, transaction):
        raise NotImplementedError

    def sync_since(self, value):
        raise NotImplementedError

    @classmethod
    def commit(cls, transaction):
        raise NotImplementedError

    @classmethod
    def rollback(cls, transaction):
        raise NotImplementedError

    @classmethod
    def drop(cls, dbname):
        raise NotImplementedError
class MemoryCache(BaseCache):
    """
    A key value LRU cache with size limit.

    One shared LRU mapping exists per database name.  A transaction that
    has cleared a cache (or that started before the last clearing) uses
    a private, transaction-local mapping instead until its commit
    publishes the invalidation to the other processes — through the
    database notification channel when available, otherwise via the
    ir_cache table which is polled by sync().
    """
    # Transaction -> set of cache names cleared during the transaction
    _reset = WeakKeyDictionary()
    # datetime of the last synchronisation with the database
    _clean_last = None
    # Fallback lower bound used before any per-database clearing happened
    _default_lower = Transaction.monotonic_time()
    # (pid, dbname) -> channel listener thread
    _listener = {}
    # pid -> lock protecting _listener bookkeeping
    _listener_lock = defaultdict(threading.Lock)
    _table = 'ir_cache'
    _channel = _table

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # dbname -> LRU mapping shared by all transactions of a database
        self._database_cache = defaultdict(lambda: LRUDict(self.size_limit))
        # Transaction -> private mapping (see _get_cache)
        self._transaction_cache = WeakKeyDictionary()
        # dbname -> monotonic time; transactions started before it must
        # not read the shared mapping (it was cleared after they started)
        self._transaction_lower = {}
        # dbname -> timestamp of the last database-wide clearing
        self._timestamp = {}

    def _get_cache(self):
        """Return the mapping to use for the current transaction."""
        transaction = Transaction()
        dbname = transaction.database.name
        lower = self._transaction_lower.get(dbname, self._default_lower)
        if (self._name in self._reset.get(transaction, set())
                or transaction.started_at < lower):
            # The shared cache is stale for this transaction: use (and
            # lazily create) a private transaction-local mapping
            try:
                return self._transaction_cache[transaction]
            except KeyError:
                cache = self._database_cache.default_factory()
                self._transaction_cache[transaction] = cache
                return cache
        else:
            return self._database_cache[dbname]

    def get(self, key, default=None):
        """Return a deep copy of the cached value, or *default* on a
        miss (absent key, expired entry or unhashable key)."""
        key = self._key(key)
        cache = self._get_cache()
        try:
            (expire, result) = cache.pop(key)
            if expire and expire < dt.datetime.now():
                self.miss += 1
                return default
            # Re-insert so the entry becomes the most recently used
            cache[key] = (expire, result)
            self.hit += 1
            return deepcopy(result)
        except (KeyError, TypeError):
            # TypeError: unhashable key, treated as a plain miss
            self.miss += 1
            return default

    def set(self, key, value):
        """Store a deep copy of *value* and return *value* unchanged.

        Unhashable keys are silently not cached.
        """
        key = self._key(key)
        cache = self._get_cache()
        if self.duration:
            expire = dt.datetime.now() + self.duration
        else:
            expire = None
        try:
            cache[key] = (expire, deepcopy(value))
        except TypeError:
            pass
        return value

    def clear(self):
        """Clear the cache for the current transaction only.

        The shared mapping is untouched until commit() propagates the
        reset to every process.
        """
        transaction = Transaction()
        self._reset.setdefault(transaction, set()).add(self._name)
        self._transaction_cache.pop(transaction, None)

    def _clear(self, dbname, timestamp=None):
        """Drop the shared mapping of *dbname* and fence transactions
        started before now out of the fresh shared mapping."""
        logger.debug("clearing cache '%s' of '%s'", self._name, dbname)
        self._timestamp[dbname] = timestamp
        self._database_cache[dbname] = self._database_cache.default_factory()
        self._transaction_lower[dbname] = max(
            Transaction.monotonic_time(),
            self._transaction_lower.get(dbname, self._default_lower))

    @classmethod
    def _clear_all(cls, dbname):
        for inst in cls._instances.values():
            inst._clear(dbname)

    @classmethod
    def sync(cls, transaction):
        """Refresh caches invalidated by other processes.

        Starts a listener thread when the database supports a
        notification channel; otherwise polls the ir_cache table at most
        every _clear_timeout seconds.
        """
        if cls._clean_last is None:
            cls._clean_last = dt.datetime.now()
            return
        database = transaction.database
        dbname = database.name
        if not _clear_timeout and database.has_channel():
            # Ensure a single listener thread per (process, database)
            pid = os.getpid()
            with cls._listener_lock[pid]:
                if (pid, dbname) not in cls._listener:
                    cls._listener[pid, dbname] = listener = threading.Thread(
                        target=cls._listen, args=(dbname,), daemon=True)
                    listener.start()
            return
        last_clean = (dt.datetime.now() - cls._clean_last).total_seconds()
        if last_clean < _clear_timeout:
            return
        connection = database.get_connection(readonly=True, autocommit=True)
        try:
            with connection.cursor() as cursor:
                table = Table(cls._table)
                cursor.execute(*table.select(
                        _cast(table.timestamp), table.name))
                timestamps = {}
                for timestamp, name in cursor:
                    timestamps[name] = timestamp
                modules = _get_modules(cursor)
        finally:
            database.put_connection(connection)
        for name, timestamp in timestamps.items():
            try:
                inst = cls._instances[name]
            except KeyError:
                continue
            inst_timestamp = inst._timestamp.get(dbname)
            if not inst_timestamp or timestamp > inst_timestamp:
                inst._clear(dbname, timestamp)
        Pool.refresh(dbname, modules)
        cls._clean_last = dt.datetime.now()

    def sync_since(self, value):
        """Return True if a synchronisation happened after *value*."""
        return self._clean_last > value

    @classmethod
    def commit(cls, transaction):
        """Publish the cache resets of *transaction* to other processes."""
        table = Table(cls._table)
        reset = cls._reset.pop(transaction, None)
        if not reset:
            return
        database = transaction.database
        dbname = database.name
        if not _clear_timeout and transaction.database.has_channel():
            with transaction.connection.cursor() as cursor:
                # The count computed as
                # 8000 (max notify size) / 64 (max name data len)
                for sub_reset in grouped_slice(reset, 125):
                    cursor.execute(
                        'NOTIFY "%s", %%s' % cls._channel,
                        (json.dumps(list(sub_reset), separators=(',', ':')),))
        else:
            connection = database.get_connection(
                readonly=False, autocommit=True)
            try:
                with connection.cursor() as cursor:
                    for name in reset:
                        cursor.execute(*table.select(table.name, table.id,
                                table.timestamp,
                                where=table.name == name,
                                limit=1))
                        if cursor.fetchone():
                            # It would be better to insert only
                            cursor.execute(*table.update([table.timestamp],
                                    [CurrentTimestamp()],
                                    where=table.name == name))
                        else:
                            cursor.execute(*table.insert(
                                    [table.timestamp, table.name],
                                    [[CurrentTimestamp(), name]]))
                        # Read the fresh timestamp back once, cast for
                        # SQLite (a redundant un-cast duplicate of this
                        # query was removed here)
                        cursor.execute(*table.select(
                                _cast(Max(table.timestamp)),
                                where=table.name == name))
                        timestamp, = cursor.fetchone()
                        inst = cls._instances[name]
                        inst._clear(dbname, timestamp)
                connection.commit()
            finally:
                database.put_connection(connection)
            cls._clean_last = dt.datetime.now()
        reset.clear()

    @classmethod
    def rollback(cls, transaction):
        """Forget the resets of an aborted transaction."""
        cls._reset.pop(transaction, None)

    @classmethod
    def drop(cls, dbname):
        """Forget everything cached for *dbname* and stop its listener."""
        pid = os.getpid()
        with cls._listener_lock[pid]:
            listener = cls._listener.pop((pid, dbname), None)
        # Join outside the lock: the listener re-acquires it on exit
        if listener:
            database = backend.Database(dbname)
            conn = database.get_connection()
            try:
                cursor = conn.cursor()
                # Wake the listener so it notices it was unregistered
                cursor.execute('NOTIFY "%s"' % cls._channel)
                conn.commit()
            finally:
                database.put_connection(conn)
            listener.join()
        for inst in cls._instances.values():
            inst._timestamp.pop(dbname, None)
            inst._database_cache.pop(dbname, None)
            inst._transaction_lower.pop(dbname, None)

    @classmethod
    def refresh_pool(cls, transaction):
        """Ask every listening process to refresh its pool for this
        database (no-op without a notification channel)."""
        database = transaction.database
        dbname = database.name
        if not _clear_timeout and database.has_channel():
            database = backend.Database(dbname)
            conn = database.get_connection()
            try:
                cursor = conn.cursor()
                cursor.execute(
                    'NOTIFY "%s", %%s' % cls._channel, ('refresh pool',))
                conn.commit()
            finally:
                database.put_connection(conn)

    @classmethod
    def _listen(cls, dbname):
        """Listener thread body: react to channel notifications until
        unregistered from cls._listener."""
        current_thread = threading.current_thread()
        pid = os.getpid()
        conn, selector = None, None
        try:
            database = backend.Database(dbname)
            if not database.has_channel():
                raise NotImplementedError
            logger.info(
                "listening on channel '%s' of '%s'", cls._channel, dbname)
            conn = database.get_connection(autocommit=True)
            selector = selectors.DefaultSelector()
            cursor = conn.cursor()
            cursor.execute('LISTEN "%s"' % cls._channel)
            # Clear everything in case we missed a payload
            Pool.refresh(dbname, _get_modules(cursor))
            cls._clear_all(dbname)
            current_thread.listening = True
            selector.register(conn, selectors.EVENT_READ)
            while cls._listener.get((pid, dbname)) == current_thread:
                selector.select(timeout=60)
                conn.poll()
                while conn.notifies:
                    notification = conn.notifies.pop()
                    if notification.payload == 'refresh pool':
                        Pool.refresh(dbname, _get_modules(cursor))
                    elif notification.payload:
                        reset = json.loads(notification.payload)
                        for name in reset:
                            inst = cls._instances[name]
                            inst._clear(dbname)
                cls._clean_last = dt.datetime.now()
        except Exception:
            logger.error(
                "cache listener on '%s' crashed", dbname, exc_info=True)
            raise
        finally:
            if selector:
                selector.close()
            if conn:
                database.put_connection(conn)
            with cls._listener_lock[pid]:
                if cls._listener.get((pid, dbname)) == current_thread:
                    del cls._listener[pid, dbname]
# Allow the cache implementation to be replaced through configuration;
# fall back to the in-memory implementation by default.
if config.get('cache', 'class'):
    Cache = resolve(config.get('cache', 'class'))
else:
    Cache = MemoryCache
class LRUDict(OrderedDict):
    """
    Dictionary with a size limit.

    Once the limit is exceeded, the earliest inserted entries are
    evicted.  ``default_factory`` mirrors the behavior of
    collections.defaultdict; when ``default_factory_with_key`` is true,
    the factory is called with the missing key as its argument.
    """
    __slots__ = ('size_limit',)

    def __init__(self, size_limit,
            default_factory=None, default_factory_with_key=False,
            *args, **kwargs):
        assert size_limit > 0
        self.size_limit = size_limit
        super().__init__(*args, **kwargs)
        self.default_factory = default_factory
        self.default_factory_with_key = default_factory_with_key
        self._check_size_limit()

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self._check_size_limit()

    def __missing__(self, key):
        factory = self.default_factory
        if factory is None:
            raise KeyError(key)
        value = factory(key) if self.default_factory_with_key else factory()
        self[key] = value
        return value

    def update(self, *args, **kwargs):
        super().update(*args, **kwargs)
        self._check_size_limit()

    def setdefault(self, key, default=None):
        value = super().setdefault(key, default=default)
        self._check_size_limit()
        return value

    def _check_size_limit(self):
        # Evict in insertion order until we are back within the limit
        while len(self) > self.size_limit:
            self.popitem(last=False)
class LRUDictTransaction(LRUDict):
    """
    Size-limited dictionary bound to a transaction (see LRUDict).

    Its content is discarded whenever the transaction counter has
    changed since the last clear.
    """
    __slots__ = ('transaction', 'counter')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.transaction = Transaction()
        self.counter = self.transaction.counter

    def clear(self):
        super().clear()
        self.counter = self.transaction.counter

    def refresh(self):
        # Drop stale entries once the transaction has moved on
        if self.counter != self.transaction.counter:
            self.clear()