# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
from sql.operators import Equal

from trytond.transaction import Transaction, without_check_access

from .modelsql import Exclude, ModelSQL
from .modelstorage import ModelStorage


class ModelSingleton(ModelStorage):
    """
    Define a singleton model in Tryton.

    At most one record is meant to be stored. Reads, writes, copies and
    deletions are redirected to that record. While it does not exist yet,
    reads answer from the default values and an unfiltered search returns a
    virtual record with id 1.
    """

    @classmethod
    def __setup__(cls):
        super().__setup__()
        # Disable the RPC cache of default_get because the client uses it
        # as a read of the singleton record (see default_get below).
        cls.__rpc__['default_get'].cache = None
        if issubclass(cls, ModelSQL):
            table = cls.__table__()
            # "id * 0" evaluates to 0 for every row, so an exclusion
            # constraint on equality of that expression forbids a second row.
            cls._sql_constraints.append(
                ('singleton', Exclude(table, (table.id * 0, Equal)),
                    'ir.msg_singleton'))

    @classmethod
    def get_singleton(cls):
        '''
        Return the instance of the unique record if there is one.

        Return None when no record is stored.
        '''
        # Call the parent search directly: ModelSingleton.search would
        # fabricate a virtual record when nothing is stored.
        singletons = super(ModelSingleton, cls).search([], limit=1)
        if singletons:
            return singletons[0]
        return None

    @classmethod
    def create(cls, vlist):
        '''
        Create the singleton from the only values dictionary in vlist.

        If the record already exists, it is updated with those values and
        returned instead of creating a second one.
        '''
        # A singleton can only be created from a single values dictionary
        assert len(vlist) == 1
        singleton = cls.get_singleton()
        if not singleton:
            if issubclass(cls, ModelSQL):
                # Presumably serializes concurrent creations; lock() is
                # provided by ModelSQL (defined elsewhere).
                cls.lock()
            return super(ModelSingleton, cls).create(vlist)
        cls.write([singleton], vlist[0])
        return [singleton]

    @classmethod
    def read(cls, ids, fields_names):
        '''
        Read fields_names from the singleton.

        A single row is returned and its 'id' is set to the first requested
        id. When no record is stored, the row is built from default_get,
        missing fields are None, and '_write'/'_delete' are True. An empty
        ids list returns an empty list.
        '''
        if not ids:
            # Nothing requested; previously this failed on ids[0]
            return []
        singleton = cls.get_singleton()
        if not singleton:
            # Dotted names and names starting with '_' are excluded from the
            # default_get call; if any were excluded, rec_name values are
            # requested as well.
            fname_no_rec_name = [
                f for f in fields_names
                if '.' not in f and not f.startswith('_')]
            res = cls.default_get(fname_no_rec_name,
                with_rec_name=len(fname_no_rec_name) != len(fields_names))
            for field_name in fields_names:
                res.setdefault(field_name, None)
            res['id'] = ids[0]
            res['_write'] = True
            res['_delete'] = True
            return [res]
        res = super(ModelSingleton, cls).read([singleton.id], fields_names)
        # Report the id the caller asked for rather than the stored one
        res[0]['id'] = ids[0]
        return res

    @classmethod
    def write(cls, records, values, *args):
        '''
        Apply every (records, values) pair of the arguments to the singleton.

        When no record is stored yet, it is created without access check
        from the first values, which are then not written a second time.
        The local cache of the given records and this model's transaction
        caches are cleared afterwards.
        '''
        singleton = cls.get_singleton()
        if not singleton:
            with without_check_access():
                singleton, = cls.create([values])
            # First values were consumed by create
            actions = (records, {}) + args
        else:
            actions = (records, values) + args

        # Whatever records were given, every write targets the singleton
        write_args = []
        for action_values in actions[1::2]:
            write_args.extend(([singleton], action_values))
        super(ModelSingleton, cls).write(*write_args)

        # Clean local cache of original records. Iterate each group instead
        # of sum()-concatenating them: this accepts tuples and is linear.
        for action_records in actions[::2]:
            for record in action_records:
                record._local_cache.pop(record.id, None)

        # Clean transaction cache of all ids
        for cache in Transaction().cache.values():
            if cls.__name__ in cache:
                cache[cls.__name__].clear()

    @classmethod
    def delete(cls, records):
        '''
        Delete the stored singleton, if any, whatever records are given.
        '''
        singleton = cls.get_singleton()
        if singleton:
            super(ModelSingleton, cls).delete([singleton])
        # Clean transaction cache of all ids
        for cache in Transaction().cache.values():
            if cls.__name__ in cache:
                cache[cls.__name__].clear()

    @classmethod
    def copy(cls, records, default=None):
        '''
        Return records unchanged instead of duplicating the singleton.

        A non-empty default is written onto the singleton.
        '''
        if default:
            cls.write(records, default)
        return records

    @classmethod
    def search(cls, domain, offset=0, limit=None, order=None, count=False):
        '''
        Delegate to the parent search.

        For an empty domain that matches nothing (no stored record), return
        a virtual record with id 1, or 1 when count is set.
        '''
        res = super(ModelSingleton, cls).search(domain, offset=offset,
            limit=limit, order=order, count=count)
        if not res and not domain:
            if count:
                return 1
            return [cls(1)]
        return res

    @classmethod
    def default_get(cls, fields_names, with_rec_name=True):
        '''
        Return the default values, or the stored singleton's values if any.

        '_timestamp' is ignored. When the singleton exists and with_rec_name
        is set, the rec_name of its many2one fields is read as well.
        '''
        if '_timestamp' in fields_names:
            fields_names = list(fields_names)
            fields_names.remove('_timestamp')
        default = super(ModelSingleton, cls).default_get(fields_names,
            with_rec_name=with_rec_name)
        singleton = cls.get_singleton()
        if singleton:
            # Build a real list: slicing a tuple would give a tuple, which
            # has no append/extend.
            read_names = list(fields_names)
            if with_rec_name:
                read_names.extend(
                    field + '.rec_name' for field in fields_names
                    if cls._fields[field]._type == 'many2one')
            default, = cls.read([singleton.id], fields_names=read_names)
            del default['id']
        return default