# This file is part of Tryton. The COPYRIGHT file at the top level of # this repository contains the full copyright notices and license terms. import datetime from collections import OrderedDict, defaultdict from functools import wraps from itertools import chain, groupby, islice, product, repeat from sql import ( Asc, Column, Desc, Expression, For, Literal, Null, NullsFirst, NullsLast, Table, Union, Window, With) from sql.aggregate import Count, Max from sql.conditionals import Coalesce from sql.functions import CurrentTimestamp, Extract, RowNumber, Substring from sql.operators import And, Concat, Equal, Exists, Operator, Or from trytond import backend from trytond.cache import freeze from trytond.config import config from trytond.exceptions import ConcurrencyException from trytond.i18n import gettext from trytond.pool import Pool from trytond.pyson import PYSONDecoder, PYSONEncoder from trytond.rpc import RPC from trytond.tools import cursor_dict, grouped_slice, reduce_ids from trytond.transaction import ( Transaction, inactive_records, record_cache_size, without_check_access) from . 
import fields from .descriptors import dualmethod from .modelstorage import ( AccessError, ModelStorage, RequiredValidationError, SizeValidationError, ValidationError, is_leaf) from .modelview import ModelView class ForeignKeyError(ValidationError): pass class SQLConstraintError(ValidationError): pass class Constraint(object): __slots__ = ('_table',) def __init__(self, table): assert isinstance(table, Table) self._table = table @property def table(self): return self._table def __str__(self): raise NotImplementedError @property def params(self): raise NotImplementedError class Check(Constraint): __slots__ = ('_expression',) def __init__(self, table, expression): super(Check, self).__init__(table) assert isinstance(expression, Expression) self._expression = expression @property def expression(self): return self._expression def __str__(self): return 'CHECK(%s)' % self.expression @property def params(self): return self.expression.params class Unique(Constraint): __slots__ = ('_columns',) def __init__(self, table, *columns): super(Unique, self).__init__(table) assert all(isinstance(col, Column) for col in columns) self._columns = tuple(columns) @property def columns(self): return self._columns @property def operators(self): return tuple(Equal for c in self._columns) def __str__(self): return 'UNIQUE(%s)' % (', '.join(map(str, self.columns))) @property def params(self): p = [] for column in self.columns: p.extend(column.params) return tuple(p) class Exclude(Constraint): __slots__ = ('_excludes', '_where') def __init__(self, table, *excludes, **kwargs): super(Exclude, self).__init__(table) assert all(isinstance(c, Expression) and issubclass(o, Operator) for c, o in excludes), excludes self._excludes = tuple(excludes) where = kwargs.get('where') if where is not None: assert isinstance(where, Expression) self._where = where @property def excludes(self): return self._excludes @property def columns(self): return tuple(c for c, _ in self._excludes) @property def 
operators(self): return tuple(o for _, o in self._excludes) @property def where(self): return self._where def __str__(self): exclude = ', '.join('%s WITH %s' % (column, operator._operator) for column, operator in self.excludes) where = '' if self.where: where = ' WHERE ' + str(self.where) return 'EXCLUDE (%s)' % exclude + where @property def params(self): p = [] for column, operator in self._excludes: p.extend(column.params) if self.where: p.extend(self.where.params) return tuple(p) class Index: __slots__ = ('table', 'expressions', 'options') def __init__(self, table, *expressions, **options): self.table = table assert all( isinstance(e, Expression) and isinstance(u, self.Usage) for e, u in expressions) self.expressions = expressions self.options = options def __hash__(self): table_def = ( self.table._name, self.table._schema, self.table._database) expressions = ( (str(e), e.params, hash(u)) for e, u in self.expressions) return hash((table_def, *expressions)) def __eq__(self, other): if not isinstance(other, self.__class__): return NotImplementedError return ( str(self.table) == str(other.table) and len(self.expressions) == len(other.expressions) and all((str(c), u) == (str(oc), ou) for (c, u), (oc, ou) in zip( self.expressions, other.expressions)) and self._options_cmp == other._options_cmp) @property def _options_cmp(self): def _format(value): return str(value) if isinstance(value, Expression) else value return {k: _format(v) for k, v in self.options.items()} def __lt__(self, other): if not isinstance(other, self.__class__): return NotImplementedError if (self.table != other.table or self._options_cmp != other._options_cmp): return False if self == other: return False if len(self.expressions) >= len(other.expressions): return False for (c, u), (oc, ou) in zip(self.expressions, other.expressions): if (str(c), u) != (str(oc), ou): return False return True def __le__(self, other): if not isinstance(other, self.__class__): return NotImplementedError return self < 
other or self == other class Unaccent(Expression): "Unaccent function if database support for index" __slots__ = ('_expression',) def __init__(self, expression): self._expression = expression @property def expression(self): expression = self._expression database = Transaction().database if database.has_unaccent_indexable(): expression = database.unaccent(expression) return expression def __str__(self): return str(self.expression) @property def params(self): return self.expression.params class Usage: __slots__ = ('options',) def __init__(self, **options): self.options = options def __hash__(self): return hash((self.__class__.__name__, *self.options.items())) def __eq__(self, other): return (self.__class__ == other.__class__ and self.options == other.options) class Equality(Usage): __slots__ = () class Range(Usage): __slots__ = () class Similarity(Usage): __slots__ = () def no_table_query(func): @wraps(func) def wrapper(cls, *args, **kwargs): if callable(cls.table_query): raise NotImplementedError("On table_query") return func(cls, *args, **kwargs) return wrapper class ModelSQL(ModelStorage): """ Define a model with storage in database. 
"""
    __slots__ = ()
    _table = None  # The name of the table in database
    _order = None
    _order_name = None  # Use to force order field when sorting on Many2One
    _history = False
    # When callable, the model is backed by the SQL query it returns instead
    # of a real table (see __table__).
    table_query = None

    @classmethod
    def __setup__(cls):
        """Resolve the table name and declare the default SQL constraints
        and indexes of the model.

        The table name may be overridden by the ``table`` configuration
        section keyed by the model name; otherwise it is derived from
        ``__name__`` with dots replaced by underscores.
        """
        cls._table = config.get('table', cls.__name__, default=cls._table)
        if not cls._table:
            cls._table = cls.__name__.replace('.', '_')

        # The "__history" suffix is reserved for the history table.
        assert cls._table[-9:] != '__history', \
            'Model _table %s cannot end with "__history"' % cls._table

        super(ModelSQL, cls).__setup__()

        cls._sql_constraints = []
        cls._sql_indexes = set()
        cls._history_sql_indexes = set()
        if not callable(cls.table_query):
            table = cls.__table__()
            cls._sql_constraints.append(
                ('id_positive', Check(table, table.id >= 0),
                    'ir.msg_id_positive'))
            rec_name_field = getattr(cls, cls._rec_name, None)
            # Stored rec_name fields get a similarity index for searching.
            if (isinstance(rec_name_field, fields.Field)
                    and not hasattr(rec_name_field, 'set')):
                column = Column(table, cls._rec_name)
                if getattr(rec_name_field, 'search_unaccented', False):
                    column = Index.Unaccent(column)
                cls._sql_indexes.add(
                    Index(table, (column, Index.Similarity())))
        cls._order = [('id', None)]
        if issubclass(cls, ModelView):
            cls.__rpc__.update({
                    'history_revisions': RPC(),
                    })
        if cls._history:
            history_table = cls.__table_history__()
            # Indexes used to look up a record's revisions by id and by
            # revision timestamp.
            cls._history_sql_indexes.update({
                    Index(
                        history_table,
                        (history_table.id, Index.Equality())),
                    Index(
                        history_table,
                        (Coalesce(
                                history_table.write_date,
                                history_table.create_date).desc,
                            Index.Range()),
                        include=[
                            Column(history_table, '__id'),
                            history_table.id]),
                    })

    @classmethod
    def __post_setup__(cls):
        """Declare indexes on the columns of other models targeted by this
        model's One2Many and Many2Many fields."""
        super().__post_setup__()

        # Define Range index to optimise with reduce_ids
        for field in cls._fields.values():
            field_names = set()
            if isinstance(field, fields.One2Many):
                Target = field.get_target()
                if field.field:
                    field_names.add(field.field)
            elif isinstance(field, fields.Many2Many):
                Target = field.get_relation()
                if field.origin:
                    field_names.add(field.origin)
                if field.target:
                    field_names.add(field.target)
            else:
                continue
            field_names.discard('id')
            for field_name in field_names:
                target_field = getattr(Target, field_name)
                if (issubclass(Target, ModelSQL)
                        and not callable(Target.table_query)
                        and not hasattr(target_field, 'set')):
                    target = Target.__table__()
                    column = Column(target, field_name)
                    # Optional columns of other models get a partial index
                    # that skips NULL values.
                    if not target_field.required and Target != cls:
                        where = column != Null
                    else:
                        where = None
                    if target_field._type == 'reference':
                        # Reference values hold "model,id" text: index the
                        # full value and the model prefix plus the id part.
                        Target._sql_indexes.update({
                                Index(
                                    target,
                                    (column, Index.Equality()),
                                    where=where),
                                Index(
                                    target,
                                    (column, Index.Similarity(begin=True)),
                                    (target_field.sql_id(column, Target),
                                        Index.Range()),
                                    where=where),
                                })
                    else:
                        Target._sql_indexes.add(
                            Index(
                                target,
                                (column, Index.Range()),
                                where=where))

    @classmethod
    def __table__(cls):
        "Return the table_query result when defined, else a new Table"
        if callable(cls.table_query):
            return cls.table_query()
        else:
            return Table(cls._table)

    @classmethod
    def __table_history__(cls):
        "Return the '<table>__history' Table; ValueError without history"
        if not cls._history:
            raise ValueError('No history table')
        return Table(cls._table + '__history')

    @classmethod
    def __table_handler__(cls, module_name=None, history=False):
        "Return a backend TableHandler (module_name is not used here)"
        return backend.TableHandler(cls, history=history)

    @classmethod
    def __register__(cls, module_name):
        """Create or update the database table of the model.

        Adds missing columns (and history columns), sets column defaults,
        foreign keys and NOT NULL, rebuilds path/tree columns when needed,
        adds the SQL constraints and fills an empty history table.
        """
        cursor = Transaction().connection.cursor()
        super(ModelSQL, cls).__register__(module_name)

        if callable(cls.table_query):
            return

        pool = Pool()

        # Initiate after the callable test to prevent calling table_query which
        # may rely on other model being registered
        sql_table = cls.__table__()

        # create/update table in the database
        table = cls.__table_handler__(module_name)
        if cls._history:
            history_table = cls.__table_handler__(module_name, history=True)

        for field_name, field in cls._fields.items():
            if field_name == 'id':
                continue
            sql_type = field.sql_type()
            if not sql_type:
                continue

            if field_name in cls._defaults:
                # The closure reads field_name/field when called; presumably
                # add_column calls it before the loop moves on — TODO confirm.
                def default():
                    default_ = cls._clean_defaults({
                            field_name: cls._defaults[field_name](),
                            })[field_name]
                    return field.sql_format(default_)
            else:
                default = None

            table.add_column(field_name, field._sql_type, default=default)
            if cls._history:
                history_table.add_column(field_name, field._sql_type)

            if isinstance(field, (fields.Integer, fields.Float)):
                # migration from tryton 2.2
                table.db_default(field_name, None)

            if isinstance(field, (fields.Boolean)):
                table.db_default(field_name, False)

            if isinstance(field, fields.Many2One):
                if field.model_name in ('res.user', 'res.group'):
                    # XXX need to merge ir and res
                    ref = field.model_name.replace('.', '_')
                else:
                    ref_model = pool.get(field.model_name)
                    if (issubclass(ref_model, ModelSQL)
                            and not callable(ref_model.table_query)):
                        ref = ref_model._table
                        # Create foreign key table if missing
                        if not backend.TableHandler.table_exist(ref):
                            backend.TableHandler(ref_model)
                    else:
                        ref = None
                if field_name in ['create_uid', 'write_uid']:
                    # migration from 3.6
                    table.drop_fk(field_name)
                elif ref:
                    table.add_fk(field_name, ref, on_delete=field.ondelete)

            required = field.required
            # Do not set 'NOT NULL' for Binary field as the database column
            # will be left empty if stored in the filestore or filled later by
            # the set method.
            if isinstance(field, fields.Binary):
                required = False
            table.not_null_action(
                field_name, action=required and 'add' or 'remove')

        # Rebuild path / left-right (MPTT) columns of self-referencing
        # Many2One fields when any row still has a default or NULL value.
        for field_name, field in cls._fields.items():
            if (isinstance(field, fields.Many2One)
                    and field.model_name == cls.__name__):
                if field.path:
                    default_path = cls._defaults.get(
                        field.path, lambda: None)()
                    cursor.execute(*sql_table.select(sql_table.id,
                            where=(
                                Column(sql_table, field.path) == default_path)
                            | (Column(sql_table, field.path) == Null),
                            limit=1))
                    if cursor.fetchone():
                        cls._rebuild_path(field_name)
                if field.left and field.right:
                    left_default = cls._defaults.get(
                        field.left, lambda: None)()
                    right_default = cls._defaults.get(
                        field.right, lambda: None)()
                    cursor.execute(*sql_table.select(sql_table.id,
                            where=(
                                Column(sql_table, field.left) == left_default)
                            | (Column(sql_table, field.left) == Null)
                            | (Column(sql_table, field.right) == right_default)
                            | (Column(sql_table, field.right) == Null),
                            limit=1))
                    if cursor.fetchone():
                        cls._rebuild_tree(field_name, None, 0)

        for ident, constraint, _ in cls._sql_constraints:
            # Index-like names are rejected for constraint identifiers.
            assert (
                not ident.startswith('idx_') and not ident.endswith('_index'))
            table.add_constraint(ident, constraint)

        if cls._history:
            cls._update_history_table()
            h_table = cls.__table_history__()
            # Seed an empty history table with a copy of the existing rows
            # (write_date cleared) when the main table already has data.
            cursor.execute(*sql_table.select(sql_table.id, limit=1))
            if cursor.fetchone():
                cursor.execute(
                    *h_table.select(h_table.id, limit=1))
                if not cursor.fetchone():
                    columns = [n for n, f in cls._fields.items()
                        if f.sql_type()]
                    cursor.execute(*h_table.insert(
                            [Column(h_table, c) for c in columns],
                            sql_table.select(*(Column(sql_table, c)
                                    for c in columns))))
                    cursor.execute(*h_table.update(
                            [h_table.write_date], [None]))

    @classmethod
    def _update_sql_indexes(cls, concurrently=False):
        """Apply the declared indexes to the database, skipping any index
        that is a strict prefix (per Index.__lt__) of another one."""
        def no_subset(index):
            for j in cls._sql_indexes:
                if j != index and index < j:
                    return False
            return True
        if not callable(cls.table_query):
            table_h = cls.__table_handler__()
            indexes = filter(no_subset, cls._sql_indexes)
            table_h.set_indexes(indexes, concurrently=concurrently)
            if cls._history:
                history_th = cls.__table_handler__(history=True)
                # NOTE(review): no_subset compares against cls._sql_indexes,
                # not cls._history_sql_indexes, so redundant history indexes
                # are presumably never filtered out — confirm intent.
                indexes = filter(no_subset, cls._history_sql_indexes)
                history_th.set_indexes(indexes, concurrently=concurrently)

    @classmethod
    def _update_history_table(cls):
        "Add to the history table every stored field column"
        if cls._history:
            history_table = cls.__table_handler__(history=True)
            for field_name, field in cls._fields.items():
                if not field.sql_type():
                    continue
                history_table.add_column(field_name, field._sql_type)

    @classmethod
    @without_check_access
    def __raise_integrity_error(
            cls, exception, values, field_names=None, transaction=None):
        """Translate a database integrity error into a ValidationError.

        Checks ``values`` in order for a missing required field
        (RequiredValidationError), a violated SQL constraint whose name
        appears in the exception text (SQLConstraintError) and a Many2One
        pointing to a missing or deleted record (ForeignKeyError).  Returns
        without raising when nothing matches.
        """
        pool = Pool()
        if field_names is None:
            field_names = list(cls._fields.keys())
        if transaction is None:
            transaction = Transaction()
        for field_name in field_names:
            if field_name not in cls._fields:
                continue
            field = cls._fields[field_name]
            # Check required fields
            if (field.required
                    and field.sql_type()
                    and field_name not in ('create_uid', 'create_date')):
                if values.get(field_name) is None:
                    raise RequiredValidationError(
                        gettext('ir.msg_required_validation',
                            **cls.__names__(field_name)))
        for name, _, error in cls._sql_constraints:
            if backend.TableHandler.convert_name(name) in str(exception):
                raise SQLConstraintError(gettext(error))
        # Check foreign key in last because this can raise false positive
        # if the target is created during the same transaction.
        for field_name in field_names:
            if field_name not in cls._fields:
                continue
            field = cls._fields[field_name]
            if isinstance(field, fields.Many2One) and values.get(field_name):
                Model = pool.get(field.model_name)
                # Records created/deleted by the original transaction, which
                # the search below may not see.
                create_records = transaction.create_records[field.model_name]
                delete_records = transaction.delete_records[field.model_name]
                target_records = Model.search([
                        ('id', '=', field.sql_format(values[field_name])),
                        ], order=[])
                if not ((
                            target_records
                            or (values[field_name] in create_records))
                        and (values[field_name] not in delete_records)):
                    error_args = cls.__names__(field_name)
                    error_args['value'] = values[field_name]
                    raise ForeignKeyError(
                        gettext('ir.msg_foreign_model_missing',
                            **error_args))

    @classmethod
    @without_check_access
    def __raise_data_error(
            cls, exception, values, field_names=None, transaction=None):
        """Translate a database data error into a SizeValidationError when a
        value exceeds its field's integer ``size``; otherwise return."""
        if field_names is None:
            field_names = list(cls._fields.keys())
        if transaction is None:
            transaction = Transaction()
        for field_name in field_names:
            if field_name not in cls._fields:
                continue
            field = cls._fields[field_name]
            # Check field size
            if (hasattr(field, 'size')
                    and isinstance(field.size, int)
                    and field.sql_type()):
                value = values.get(field_name) or ''
                size = len(value)
                if size > field.size:
                    error_args = cls.__names__(field_name)
                    error_args['value'] = value
                    error_args['size'] = size
                    error_args['max_size'] = field.size
                    raise SizeValidationError(
                        gettext('ir.msg_size_validation', **error_args))

    @classmethod
    def history_revisions(cls, ids):
        """Return (timestamp, id, user name) tuples of every history
        revision of ``ids``, newest first.

        Requires read access on the model.
        """
        pool = Pool()
        ModelAccess = pool.get('ir.model.access')
        User = pool.get('res.user')
        cursor = Transaction().connection.cursor()

        ModelAccess.check(cls.__name__, 'read')

        table = cls.__table_history__()
        user = User.__table__()
        revisions = []
        for sub_ids in grouped_slice(ids):
            where = reduce_ids(table.id, sub_ids)
            cursor.execute(*table.join(user, 'LEFT',
                    Coalesce(table.write_uid, table.create_uid) == user.id)
                .select(
                    Coalesce(table.write_date, table.create_date),
                    table.id,
                    user.name,
                    where=where))
            revisions.append(cursor.fetchall())
        revisions = list(chain(*revisions))
        revisions.sort(reverse=True)
        # SQLite uses char for COALESCE
        if revisions and isinstance(revisions[0][0], str):
            strptime = datetime.datetime.strptime
            format_ = '%Y-%m-%d %H:%M:%S.%f'
            revisions = [(strptime(timestamp, format_), id_, name)
                for timestamp, id_, name in revisions]
        return revisions

    @classmethod
    def _insert_history(cls, ids, deleted=False):
        """Append a history row for each of ``ids``.

        Copies the current stored columns, or, when ``deleted`` is true,
        only records id, write_date (now) and write_uid (current user) as a
        deletion marker.  No-op for models without history.
        """
        transaction = Transaction()
        cursor = transaction.connection.cursor()
        if not cls._history:
            return
        user = transaction.user
        table = cls.__table__()
        history = cls.__table_history__()
        columns = []
        hcolumns = []
        if not deleted:
            fields = cls._fields
        else:
            fields = {
                'id': cls.id,
                'write_uid': cls.write_uid,
                'write_date': cls.write_date,
                }
        # Sorted by name: for a deletion the column order is id,
        # write_date, write_uid, matching the value rows built below.
        for fname, field in sorted(fields.items()):
            if not field.sql_type():
                continue
            columns.append(Column(table, fname))
            hcolumns.append(Column(history, fname))
        for sub_ids in grouped_slice(ids):
            if not deleted:
                where = reduce_ids(table.id, sub_ids)
                cursor.execute(*history.insert(hcolumns,
                        table.select(*columns, where=where)))
            else:
                if transaction.database.has_multirow_insert():
                    cursor.execute(*history.insert(hcolumns,
                            [[id_, CurrentTimestamp(), user]
                                for id_ in sub_ids]))
                else:
                    for id_ in sub_ids:
                        cursor.execute(*history.insert(hcolumns,
                                [[id_, CurrentTimestamp(), user]]))

    @classmethod
    def _restore_history(cls, ids, datetime, _before=False):
        """Restore ``ids`` to their latest history revision at ``datetime``
        (strictly before it when ``_before`` is true).

        Records whose revision is missing or a deletion marker are deleted;
        the others are updated, or re-inserted when no row exists.  New
        history rows are appended for both.  Note that the ``datetime``
        parameter shadows the module of the same name here.
        """
        if not cls._history:
            return
        transaction = Transaction()
        cursor = transaction.connection.cursor()
        table = cls.__table__()
        history = cls.__table_history__()
        transaction.counter += 1
        # Drop the restored ids from every cached record set of this model.
        for cache in transaction.cache.values():
            if cls.__name__ in cache:
                cache_cls = cache[cls.__name__]
                for id_ in ids:
                    cache_cls.pop(id_, None)

        columns = []
        hcolumns = []
        fnames = sorted(n for n, f in cls._fields.items()
            if f.sql_type())
        for fname in fnames:
            columns.append(Column(table, fname))
            # The restore itself is stamped as a write by the current user.
            if fname == 'write_uid':
                hcolumns.append(Literal(transaction.user))
            elif fname == 'write_date':
                hcolumns.append(CurrentTimestamp())
            else:
                hcolumns.append(Column(history, fname))

        # A revision whose columns other than id/write_uid/write_date are
        # all falsy is treated as the marker written by _insert_history for
        # deletions.
        def is_deleted(values):
            return all(not v for n, v in zip(fnames, values)
                if n not in ['id', 'write_uid', 'write_date'])

        to_delete = []
        to_update = []
        for id_ in ids:
            column_datetime = Coalesce(history.write_date, history.create_date)
            if not _before:
                hwhere = (column_datetime <= datetime)
            else:
                hwhere = (column_datetime < datetime)
            hwhere &= (history.id == id_)
            horder = (column_datetime.desc, Column(history, '__id').desc)
            cursor.execute(*history.select(*hcolumns,
                    where=hwhere, order_by=horder, limit=1))
            values = cursor.fetchone()
            if not values or is_deleted(values):
                to_delete.append(id_)
            else:
                to_update.append(id_)
                values = list(values)
                cursor.execute(*table.update(columns, values,
                        where=table.id == id_))
                rowcount = cursor.rowcount
                # Some drivers do not report rowcount: count explicitly.
                if rowcount == -1 or rowcount is None:
                    cursor.execute(*table.select(table.id,
                            where=table.id == id_))
                    rowcount = len(cursor.fetchall())
                if rowcount < 1:
                    cursor.execute(*table.insert(columns, [values]))

        if to_delete:
            for sub_ids in grouped_slice(to_delete):
                where = reduce_ids(table.id, sub_ids)
                cursor.execute(*table.delete(where=where))
            cls._insert_history(to_delete, True)
        if to_update:
            cls._insert_history(to_update)

    @classmethod
    def restore_history(cls, ids, datetime):
        'Restore record ids from history at the date time'
        cls._restore_history(ids, datetime)

    @classmethod
    def restore_history_before(cls, ids, datetime):
        'Restore record ids from history before the date time'
        cls._restore_history(ids, datetime, _before=True)

    @classmethod
    def __check_timestamp(cls, ids):
        """Raise ConcurrencyException if any of ``ids`` was modified since
        the client-provided timestamp.

        Timestamps are popped from ``transaction.timestamp`` (keyed
        "model,id") and compared with the epoch of write_date or
        create_date, cast to the Char SQL type.
        """
        transaction = Transaction()
        cursor = transaction.connection.cursor()
        table = cls.__table__()
        if not transaction.timestamp:
            return
        for sub_ids in grouped_slice(ids):
            where = Or()
            for id_ in sub_ids:
                try:
                    timestamp = transaction.timestamp.pop(
                        '%s,%s' % (cls.__name__, id_))
                except KeyError:
                    continue
                if timestamp is None:
                    continue
                # NOTE(review): loop-invariant; could be computed once.
                sql_type = fields.Char('timestamp').sql_type().base
                where.append((table.id == id_)
                    & (Extract('EPOCH',
                            Coalesce(table.write_date, table.create_date)
                            ).cast(sql_type) != timestamp))
            # An empty Or means no timestamp to check in this slice.
            if where:
                cursor.execute(*table.select(table.id, where=where, limit=1))
                if cursor.fetchone():
                    raise ConcurrencyException(
                        'Records were modified in the meanwhile')

    @classmethod
    @no_table_query
    def create(cls, vlist):
        """Insert one record per dict of ``vlist`` and return them.

        Fills missing values with defaults, batches rows sharing the same
        column set into inserts, then sets path/MPTT columns, translations
        and fields with a ``set`` method, writes history, checks the
        create rule domain, validates and fires the create triggers.
        Database integrity/data errors are re-raised after trying to turn
        them into a ValidationError.
        """
        transaction = Transaction()
        cursor = transaction.connection.cursor()
        in_max = transaction.database.IN_MAX
        pool = Pool()
        Translation = pool.get('ir.translation')

        super(ModelSQL, cls).create(vlist)

        table = cls.__table__()
        # NOTE(review): modified_fields is collected but not read in this
        # method.
        modified_fields = set()
        defaults_cache = {}  # Store already computed default values
        missing_defaults = {}  # Store missing default values by schema
        new_ids = []
        vlist = [v.copy() for v in vlist]

        def db_insert(columns, vlist, column_names):
            """Insert the value rows of ``vlist`` and yield the new ids.

            Each row starts with create_uid and create_date followed by the
            values of ``column_names``.
            """
            # Batch rows when multi-row INSERT is supported, keeping the
            # number of parameters per statement under IN_MAX.
            if transaction.database.has_multirow_insert():
                vlist = (
                    s for s in grouped_slice(
                        vlist, in_max // (len(column_names) or 1)))
            else:
                vlist = ([v] for v in vlist)

            for values in vlist:
                values = list(values)
                cols = list(columns)
                try:
                    if len(values) > 1:
                        # Reserve all ids up front when the backend can.
                        ids = transaction.database.nextid(
                            transaction.connection, cls._table,
                            count=len(values))
                        if ids is not None:
                            assert len(ids) == len(values)
                            cols.append(table.id)
                            for val, id in zip(values, ids):
                                val.append(id)
                            cursor.execute(*table.insert(cols, values))
                            yield from ids
                            continue
                    for i, val in enumerate(values):
                        if transaction.database.has_returning():
                            cursor.execute(*table.insert(
                                    cols, [val], [table.id]))
                            yield from (r[0] for r in cursor)
                        else:
                            id_new = transaction.database.nextid(
                                transaction.connection, cls._table)
                            if id_new:
                                if i == 0:
                                    cols.append(table.id)
                                val.append(id_new)
                                cursor.execute(*table.insert(cols, [val]))
                            else:
                                cursor.execute(*table.insert(cols, [val]))
                                id_new = transaction.database.lastid(cursor)
                            yield id_new
                except (
                        backend.DatabaseIntegrityError,
                        backend.DatabaseDataError
                        ) as exception:
                    if isinstance(exception, backend.DatabaseIntegrityError):
                        raise_func = cls.__raise_integrity_error
                    elif isinstance(exception, backend.DatabaseDataError):
                        raise_func = cls.__raise_data_error

                    # Diagnose in a new transaction (the current one failed)
                    # while passing the original one for its created/deleted
                    # record sets; re-raise if no ValidationError was raised.
                    with transaction.new_transaction():
                        for value in values:
                            skip = len(['create_uid', 'create_date'])
                            recomposed = dict(zip(column_names, value[skip:]))
                            raise_func(
                                exception, recomposed, transaction=transaction)
                    raise

        to_insert = []
        previous_columns = [table.create_uid, table.create_date]
        previous_column_names = []
        for values in vlist:
            # Clean values
            for key in ('create_uid', 'create_date',
                    'write_uid', 'write_date', 'id'):
                if key in values:
                    del values[key]
            modified_fields |= values.keys()

            # Get default values
            # Missing defaults are computed once per set of provided keys.
            values_schema = tuple(sorted(values))
            if values_schema not in missing_defaults:
                default = []
                missing_defaults[values_schema] = default_values = {}
                for fname, field in cls._fields.items():
                    if fname in values:
                        continue
                    if fname in [
                            'create_uid', 'create_date',
                            'write_uid', 'write_date', 'id']:
                        continue
                    if isinstance(field, fields.Function) and not field.setter:
                        continue
                    if fname in defaults_cache:
                        default_values[fname] = defaults_cache[fname]
                    else:
                        default.append(fname)

                if default:
                    defaults = cls.default_get(default, with_rec_name=False)
                    default_values.update(cls._clean_defaults(defaults))
                    defaults_cache.update(default_values)
            values.update(missing_defaults[values_schema])

            current_column_names = []
            current_columns = [table.create_uid, table.create_date]
            current_values = [transaction.user, CurrentTimestamp()]

            # Insert record
            # Only stored fields (no ``set`` method) become SQL columns.
            for fname, value in sorted(values.items()):
                field = cls._fields[fname]
                if not hasattr(field, 'set'):
                    current_columns.append(Column(table, fname))
                    current_values.append(field.sql_format(value))
                    current_column_names.append(fname)

            # Flush the pending batch whenever the column set changes.
            if current_column_names != previous_column_names:
                if to_insert:
                    new_ids.extend(db_insert(
                            previous_columns, to_insert, previous_column_names))
                    to_insert.clear()
                previous_columns = current_columns
                previous_column_names = current_column_names
            to_insert.append(current_values)
        else:
            # Flush the last batch once vlist is exhausted.
            if to_insert:
                new_ids.extend(db_insert(
                        current_columns, to_insert, current_column_names))

        transaction.create_records[cls.__name__].update(new_ids)

        # Update path before fields_to_set which could create children
        if cls._path_fields:
            field_names = list(sorted(cls._path_fields))
            cls._set_path(field_names, repeat(new_ids, len(field_names)))
        # Update mptt before fields_to_set which could create children
        if cls._mptt_fields:
            field_names = list(sorted(cls._mptt_fields))
            cls._update_mptt(field_names, repeat(new_ids, len(field_names)))

        translation_values = {}
        fields_to_set = {}
        for values, new_id in zip(vlist, new_ids):
            for fname, value in values.items():
                field = cls._fields[fname]
                if (getattr(field, 'translate', False)
                        and not hasattr(field, 'set')):
                    translation_values.setdefault(
                        '%s,%s' % (cls.__name__, fname), {})[new_id] = (
                        field.sql_format(value))
                if hasattr(field, 'set'):
                    # args is a flat [ids, value, ids, value, ...] list:
                    # group new ids that share an equal value.
                    args = fields_to_set.setdefault(fname, [])
                    actions = iter(args)
                    for ids, val in zip(actions, actions):
                        if val == value:
                            ids.append(new_id)
                            break
                    else:
                        args.extend(([new_id], value))

        if translation_values:
            for name, translations in translation_values.items():
                Translation.set_ids(name, 'model', Transaction().language,
                    list(translations.keys()), list(translations.values()))

        for fname in sorted(fields_to_set, key=cls.index_set_field):
            fargs = fields_to_set[fname]
            field = cls._fields[fname]
            field.set(cls, fname, *fargs)

        cls._insert_history(new_ids)

        cls.__check_domain_rule(new_ids, 'create')
        records = cls.browse(new_ids)
        for sub_records in grouped_slice(
                records, record_cache_size(transaction)):
            cls._validate(sub_records)

        cls.trigger_create(records)
        return records

    @classmethod
    def read(cls, ids, fields_names):
        """Return a list of dicts with the values of ``fields_names`` for
        ``ids``.

        Besides field names, ``fields_names`` may contain
        ``'<field>.<related>'`` (read of the target records),
        ``'<field>:string'`` (selection labels) and the pseudo fields
        ``'_timestamp'``, ``'_write'`` and ``'_delete'`` (the latter two
        evaluate the matching ir.rule domain).  When the context has
        ``_datetime`` and the model has history, values are read from the
        history table.  Raises RuntimeError if rows are missing after the
        read rule check.
        """
        pool = Pool()
        Rule = pool.get('ir.rule')
        Translation = pool.get('ir.translation')
        super(ModelSQL, cls).read(ids, fields_names=fields_names)
        transaction = Transaction()
        cursor = Transaction().connection.cursor()

        if not ids:
            return []

        # construct a clause for the rules :
        domain = Rule.domain_get(cls.__name__, mode='read')

        fields_related = defaultdict(set)
        extra_fields = set()
        if 'write_date' not in fields_names:
            extra_fields.add('write_date')
        for field_name in fields_names:
            if field_name in {'_timestamp', '_write', '_delete'}:
                continue
            if '.' in field_name:
                field_name, field_related = field_name.split('.', 1)
                fields_related[field_name].add(field_related)
            if field_name.endswith(':string'):
                field_name = field_name[:-len(':string')]
                # Touch the defaultdict so the field gets a related entry.
                fields_related[field_name]
            field = cls._fields[field_name]
            if hasattr(field, 'datetime_field') and field.datetime_field:
                extra_fields.add(field.datetime_field)
            if field.context:
                extra_fields.update(fields.get_eval_fields(field.context))
        extra_fields.discard('id')
        all_fields = (
            set(fields_names) | fields_related.keys() | extra_fields)

        result = []
        table = cls.__table__()

        in_max = transaction.database.IN_MAX
        history_order = None
        history_clause = None
        history_limit = None
        if (cls._history
                and transaction.context.get('_datetime')
                and not callable(cls.table_query)):
            # Read the latest revision at _datetime, one id per query.
            in_max = 1
            table = cls.__table_history__()
            column = Coalesce(table.write_date, table.create_date)
            history_clause = (column <= Transaction().context['_datetime'])
            history_order = (column.desc, Column(table, '__id').desc)
            history_limit = 1

        columns = {}
        for f in all_fields:
            field = cls._fields.get(f)
            if field and field.sql_type():
                columns[f] = field.sql_column(table).as_(f)
                if backend.name == 'sqlite':
                    columns[f].output_name += ' [%s]' % field.sql_type().base
            elif f in {'_write', '_delete'}:
                if not callable(cls.table_query):
                    rule_domain = Rule.domain_get(
                        cls.__name__, mode=f.lstrip('_'))
                    # No need to compute rule domain if it is the same as the
                    # read rule domain because it is already applied as where
                    # clause.
                    if rule_domain and rule_domain != domain:
                        rule_tables = {None: (table, None)}
                        rule_tables, rule_expression = cls.search_domain(
                            rule_domain, active_test=False, tables=rule_tables)
                        if len(rule_tables) > 1:
                            # The expression uses another table
                            rule_tables, rule_expression = cls.search_domain(
                                rule_domain, active_test=False)
                            rule_from = convert_from(None, rule_tables)
                            rule_table, _ = rule_tables[None]
                            rule_where = rule_table.id == table.id
                            rule_expression = rule_from.select(
                                rule_expression, where=rule_where)
                        columns[f] = rule_expression.as_(f)
                    else:
                        columns[f] = Literal(True).as_(f)
            elif f == '_timestamp' and not callable(cls.table_query):
                sql_type = fields.Char('timestamp').sql_type().base
                columns[f] = Extract(
                    'EPOCH', Coalesce(table.write_date, table.create_date)
                    ).cast(sql_type).as_('_timestamp')

        # write_date alone was only wanted as an extra: skip the query.
        if ('write_date' not in fields_names
                and columns.keys() == {'write_date'}):
            columns.pop('write_date')
            extra_fields.discard('write_date')
        if columns or domain:
            if 'id' not in fields_names:
                columns['id'] = table.id.as_('id')

            tables = {None: (table, None)}
            if domain:
                tables, dom_exp = cls.search_domain(
                    domain, active_test=False, tables=tables)
            from_ = convert_from(None, tables)
            for sub_ids in grouped_slice(ids, in_max):
                sub_ids = list(sub_ids)
                red_sql = reduce_ids(table.id, sub_ids)
                where = red_sql
                if history_clause:
                    where &= history_clause
                if domain:
                    where &= dom_exp
                cursor.execute(*from_.select(*columns.values(), where=where,
                        order_by=history_order, limit=history_limit))
                fetchall = list(cursor_dict(cursor))
                # Fewer rows than distinct ids means the rule domain hid
                # some: let the domain check raise the proper error.
                if not len(fetchall) == len({}.fromkeys(sub_ids)):
                    cls.__check_domain_rule(ids, 'read')
                    raise RuntimeError("Undetected access error")
                result.extend(fetchall)
        else:
            result = [{'id': x} for x in ids]

        cachable_fields = []
        # Used as the freshness bound of cached translations.
        max_write_date = max(
            (r['write_date'] for r in result if r.get('write_date')),
            default=None)
        for fname, column in columns.items():
            if fname.startswith('_'):
                continue
            field = cls._fields[fname]
            if not hasattr(field, 'get'):
                if getattr(field, 'translate', False):
                    translations = Translation.get_ids(
                        cls.__name__ + ',' + fname, 'model',
                        Transaction().language, ids,
                        cached_after=max_write_date)
                    for row in result:
                        row[fname] = translations.get(row['id']) or row[fname]
                if fname != 'id':
                    cachable_fields.append(fname)

        # all fields for which there is a get attribute
        getter_fields = [f for f in all_fields
            if f in cls._fields and hasattr(cls._fields[f], 'get')]
        getter_fields = sorted(getter_fields, key=cls.index_get_field)

        cache = transaction.get_cache()[cls.__name__]
        # Prime the record cache with SQL values before calling getters.
        if getter_fields and cachable_fields:
            for row in result:
                for fname in cachable_fields:
                    cache[row['id']][fname] = row[fname]

        func_fields = {}
        for fname in getter_fields:
            field = cls._fields[fname]
            if isinstance(field, fields.Function):
                # Function fields sharing getter, context flag and
                # datetime field are computed in one call.
                key = (
                    field.getter, field.getter_with_context,
                    getattr(field, 'datetime_field', None))
                func_fields.setdefault(key, [])
                func_fields[key].append(fname)
            elif getattr(field, 'datetime_field', None):
                for row in result:
                    with Transaction().set_context(
                            _datetime=row[field.datetime_field]):
                        date_result = field.get([row['id']], cls, fname,
                            values=[row])
                    row[fname] = date_result[row['id']]
            else:
                # get the value of that field for all records/ids
                getter_result = field.get(ids, cls, fname, values=result)
                for row in result:
                    row[fname] = getter_result[row['id']]

        for key in func_fields:
            field_list = func_fields[key]
            fname = field_list[0]
            field = cls._fields[fname]
            _, getter_with_context, datetime_field = key
            if datetime_field:
                for row in result:
                    with Transaction().set_context(
                            _datetime=row[datetime_field]):
                        date_results = field.get([row['id']], cls, field_list,
                            values=[row])
                    for fname in field_list:
                        date_result = date_results[fname]
                        row[fname] = date_result[row['id']]
            else:
                for sub_results in grouped_slice(
                        result, record_cache_size(transaction)):
                    sub_results = list(sub_results)
                    sub_ids = []
                    sub_values = []
                    # Only call the getter for rows missing a cached value.
                    for row in sub_results:
                        if (row['id'] not in cache
                                or any(f not in cache[row['id']]
                                    for f in field_list)):
                            sub_ids.append(row['id'])
                            sub_values.append(row)
                        else:
                            for fname in field_list:
                                row[fname] = cache[row['id']][fname]
                    getter_results = field.get(
                        sub_ids, cls, field_list, values=sub_values)
                    for fname in field_list:
                        getter_result = getter_results[fname]
                        for row in sub_values:
                            row[fname] = getter_result[row['id']]
                            # Cache only in readonly transactions and when
                            # the getter does not depend on the context.
                            if (transaction.readonly
                                    and not getter_with_context):
                                cache[row['id']][fname] = row[fname]

        def read_related(field, Target, rows, fields):
            """Read ``fields`` of the Target records referenced by ``rows``.

            Ids beyond the ``related_read_limit`` context value are returned
            as bare ``{'id': id}`` dicts.
            """
            name = field.name
            target_ids = []
            if field._type.endswith('2many'):
                add = target_ids.extend
            elif field._type == 'reference':
                # Reference values are "model,id"; skip malformed or
                # negative ids.
                def add(value):
                    try:
                        id_ = int(value.split(',', 1)[1])
                    except ValueError:
                        pass
                    else:
                        if id_ >= 0:
                            target_ids.append(id_)
            else:
                add = target_ids.append
            for row in rows:
                value = row[name]
                if value is not None:
                    add(value)
            related_read_limit = transaction.context.get('related_read_limit')
            rows = Target.read(target_ids[:related_read_limit], fields)
            if related_read_limit is not None:
                rows += [{'id': i} for i in target_ids[related_read_limit:]]
            return rows

        def add_related(field, rows, targets):
            """Store in ``rows`` the related values: ``'<name>.'`` for
            relation fields, ``'<name>:string'`` for selection labels."""
            name = field.name
            key = name + '.'
            if field._type.endswith('2many'):
                for row in rows:
                    row[key] = values = list()
                    for target in row[name]:
                        if target is not None:
                            values.append(targets[target])
            elif field._type in {'selection', 'multiselection'}:
                key = name + ':string'
                for row, target in zip(rows, targets):
                    selection = field.get_selection(cls, name, target)
                    if field._type == 'selection':
                        row[key] = field.get_selection_string(
                            selection, row[name])
                    else:
                        row[key] = [
                            field.get_selection_string(selection, v)
                            for v in row[name]]
            else:
                for row in rows:
                    value = row[name]
                    if isinstance(value, str):
                        try:
                            value = int(value.split(',', 1)[1])
                        except ValueError:
                            value = None
                    if value is not None and value >= 0:
                        row[key] = targets[value]
                    else:
                        row[key] = None

        to_del = set()
        for fname in fields_related.keys() | extra_fields:
            # Helper columns that were not requested are removed at the end.
            if fname not in fields_names:
                to_del.add(fname)
            if fname not in cls._fields:
                continue
            if fname not in fields_related:
                continue
            field = cls._fields[fname]
            datetime_field = getattr(field, 'datetime_field', None)

            def groupfunc(row):
                "Return the target model and context used to read the row"
                ctx = {}
                if field.context:
                    pyson_context = PYSONEncoder().encode(field.context)
                    ctx.update(PYSONDecoder(row).decode(pyson_context))
                if datetime_field:
                    ctx['_datetime'] = row.get(datetime_field)
                if field._type in {'selection', 'multiselection'}:
                    Target = None
                elif field._type == 'reference':
                    value = row[fname]
                    if not value:
                        Target = None
                    else:
                        model, _ = value.split(',', 1)
                        Target = pool.get(model)
                else:
                    Target = field.get_target()
                return Target, ctx

            def orderfunc(row):
                "Sortable version of groupfunc so groupby sees contiguous rows"
                Target, ctx = groupfunc(row)
                return (Target.__name__ if Target else '', freeze(ctx))

            for (Target, ctx), rows in groupby(
                    sorted(result, key=orderfunc), key=groupfunc):
                rows = list(rows)
                with Transaction().set_context(ctx):
                    if Target:
                        targets = read_related(
                            field, Target, rows, list(fields_related[fname]))
                        targets = {t['id']: t for t in targets}
                    else:
                        targets = cls.browse([r['id'] for r in rows])
                add_related(field, rows, targets)

        for row, field in product(result, to_del):
            del row[field]

        return result
@classmethod @no_table_query def write(cls, records, values, *args): transaction = Transaction() cursor = transaction.connection.cursor() pool = Pool() Translation = pool.get('ir.translation') Config = pool.get('ir.configuration') assert not len(args) % 2 # Remove possible duplicates from all records all_records = list(OrderedDict.fromkeys( sum(((records, values) + args)[0:None:2], []))) all_ids = [r.id for r in all_records] all_field_names = set() # Call before cursor cache cleaning trigger_eligibles = cls.trigger_write_get_eligibles(all_records) super(ModelSQL, cls).write(records, values, *args) table = cls.__table__() cls.__check_timestamp(all_ids) cls.__check_domain_rule(all_ids, 'write') fields_to_set = {} actions = iter((records, values) + args) for records, values in zip(actions, actions): ids = [r.id for r in records] values = values.copy() # Clean values for key in ('create_uid', 'create_date', 'write_uid', 'write_date', 'id'): if key in values: del values[key] columns = [table.write_uid, table.write_date] update_values = [transaction.user, CurrentTimestamp()] store_translation = Transaction().language == Config.get_language() for fname, value in values.items(): field = cls._fields[fname] if not hasattr(field, 'set'): if (not getattr(field, 'translate', False) or store_translation): columns.append(Column(table, fname)) update_values.append(field.sql_format(value)) for sub_ids in grouped_slice(ids): red_sql = reduce_ids(table.id, sub_ids) try: cursor.execute(*table.update(columns, update_values, where=red_sql)) except ( backend.DatabaseIntegrityError, backend.DatabaseDataError) as exception: transaction = Transaction() with Transaction().new_transaction(): if isinstance( exception, backend.DatabaseIntegrityError): cls.__raise_integrity_error( exception, values, list(values.keys()), transaction=transaction) elif isinstance(exception, backend.DatabaseDataError): cls.__raise_data_error( exception, values, list(values.keys()), transaction=transaction) raise for 
fname, value in values.items(): field = cls._fields[fname] if (getattr(field, 'translate', False) and not hasattr(field, 'set')): Translation.set_ids( '%s,%s' % (cls.__name__, fname), 'model', transaction.language, ids, [field.sql_format(value)] * len(ids)) if hasattr(field, 'set'): fields_to_set.setdefault(fname, []).extend((ids, value)) path_fields = cls._path_fields & values.keys() if path_fields: cls._update_path( list(sorted(path_fields)), repeat(ids, len(path_fields))) mptt_fields = cls._mptt_fields & values.keys() if mptt_fields: cls._update_mptt( list(sorted(mptt_fields)), repeat(ids, len(mptt_fields)), values) all_field_names |= values.keys() for fname in sorted(fields_to_set, key=cls.index_set_field): fargs = fields_to_set[fname] field = cls._fields[fname] field.set(cls, fname, *fargs) cls._insert_history(all_ids) cls.__check_domain_rule(all_ids, 'write') for sub_records in grouped_slice( all_records, record_cache_size(transaction)): cls._validate(sub_records, field_names=all_field_names) cls.trigger_write(trigger_eligibles) @classmethod @no_table_query def delete(cls, records): transaction = Transaction() in_max = transaction.database.IN_MAX cursor = transaction.connection.cursor() pool = Pool() Translation = pool.get('ir.translation') ids = list(map(int, records)) if not ids: return table = cls.__table__() if cls.__name__ in transaction.delete_records: ids = ids[:] for del_id in transaction.delete_records[cls.__name__]: while ids: try: ids.remove(del_id) except ValueError: break cls.__check_timestamp(ids) cls.__check_domain_rule(ids, 'delete') tree_ids = {} for fname in cls._mptt_fields: field = cls._fields[fname] tree_ids[fname] = [] for sub_ids in grouped_slice(ids): where = reduce_ids(field.sql_column(table), sub_ids) cursor.execute(*table.select(table.id, where=where)) tree_ids[fname] += [x[0] for x in cursor] has_translation = any( getattr(f, 'translate', False) and not hasattr(f, 'set') for f in cls._fields.values()) foreign_keys_tocheck = [] 
foreign_keys_toupdate = [] foreign_keys_todelete = [] for _, model in pool.iterobject(): if (not issubclass(model, ModelStorage) or callable(getattr(model, 'table_query', None))): continue for field_name, field in model._fields.items(): if (isinstance(field, fields.Many2One) and field.model_name == cls.__name__): if field.ondelete == 'CASCADE': foreign_keys_todelete.append((model, field_name)) elif field.ondelete == 'SET NULL': if field.required: foreign_keys_tocheck.append((model, field_name)) else: foreign_keys_toupdate.append((model, field_name)) else: foreign_keys_tocheck.append((model, field_name)) transaction.delete_records[cls.__name__].update(ids) cls.trigger_delete(records) if len(records) > in_max: # Clean self referencing foreign keys # before deleting them by small groups # Use the record id as value instead of NULL # in case the field is required foreign_fields_to_clean = [ fn for m, fn in foreign_keys_tocheck if m == cls] if foreign_fields_to_clean: for sub_ids in grouped_slice(ids): columns = [ Column(table, n) for n in foreign_fields_to_clean] cursor.execute(*table.update( columns, [table.id] * len(foreign_fields_to_clean), where=reduce_ids(table.id, sub_ids))) def get_related_records(Model, field_name, sub_ids): if issubclass(Model, ModelSQL): foreign_table = Model.__table__() foreign_red_sql = reduce_ids( Column(foreign_table, field_name), sub_ids) cursor.execute(*foreign_table.select(foreign_table.id, where=foreign_red_sql)) related_records = Model.browse([x[0] for x in cursor]) else: with without_check_access(), inactive_records(): related_records = Model.search( [(field_name, 'in', sub_ids)], order=[]) if Model == cls: related_records = list(set(related_records) - set(records)) return related_records for sub_ids, sub_records in zip( grouped_slice(ids), grouped_slice(records)): sub_ids = list(sub_ids) red_sql = reduce_ids(table.id, sub_ids) for Model, field_name in foreign_keys_toupdate: related_records = get_related_records( Model, field_name, 
sub_ids) if related_records: Model.write(related_records, { field_name: None, }) for Model, field_name in foreign_keys_todelete: related_records = get_related_records( Model, field_name, sub_ids) if related_records: Model.delete(related_records) for Model, field_name in foreign_keys_tocheck: if get_related_records(Model, field_name, sub_ids): error_args = Model.__names__(field_name) raise ForeignKeyError( gettext('ir.msg_foreign_model_exist', **error_args)) super(ModelSQL, cls).delete(list(sub_records)) try: cursor.execute(*table.delete(where=red_sql)) except backend.DatabaseIntegrityError as exception: transaction = Transaction() with Transaction().new_transaction(): cls.__raise_integrity_error( exception, {}, transaction=transaction) raise if has_translation: Translation.delete_ids(cls.__name__, 'model', ids) cls._insert_history(ids, deleted=True) cls._update_mptt(list(tree_ids.keys()), list(tree_ids.values())) @classmethod def __check_domain_rule(cls, ids, mode): pool = Pool() Lang = pool.get('ir.lang') Rule = pool.get('ir.rule') Model = pool.get('ir.model') try: User = pool.get('res.user') Group = pool.get('res.group') except KeyError: User = Group = None table = cls.__table__() transaction = Transaction() in_max = transaction.database.IN_MAX history_clause = None limit = None if (mode == 'read' and cls._history and transaction.context.get('_datetime') and not callable(cls.table_query)): in_max = 1 table = cls.__table_history__() column = Coalesce(table.write_date, table.create_date) history_clause = (column <= Transaction().context['_datetime']) limit = 1 cursor = transaction.connection.cursor() assert mode in Rule.modes def test_domain(ids, domain): result = [] tables = {None: (table, None)} if domain: tables, dom_exp = cls.search_domain( domain, active_test=False, tables=tables) from_ = convert_from(None, tables) for sub_ids in grouped_slice(ids, in_max): sub_ids = set(sub_ids) where = reduce_ids(table.id, sub_ids) if history_clause: where &= history_clause 
if domain: where &= dom_exp cursor.execute( *from_.select(table.id, where=where, limit=limit)) rowcount = cursor.rowcount if rowcount == -1 or rowcount is None: rowcount = len(cursor.fetchall()) if rowcount != len(sub_ids): cursor.execute( *from_.select(table.id, where=where, limit=limit)) result.extend( sub_ids.difference([x for x, in cursor])) return result domains = [] if mode in {'read', 'write'}: domains.append([]) domain = Rule.domain_get(cls.__name__, mode=mode) if domain: domains.append(domain) for domain in domains: wrong_ids = test_domain(ids, domain) if wrong_ids: model = cls.__name__ if Model: model = Model.get_name(cls.__name__) ids = ', '.join(map(str, wrong_ids[:5])) if len(wrong_ids) > 5: ids += '...' if domain: rules = [] clause, clause_global = Rule.get(cls.__name__, mode=mode) if clause: dom = list(clause.values()) dom.insert(0, 'OR') if test_domain(wrong_ids, dom): rules.extend(clause.keys()) for rule, dom in clause_global.items(): if test_domain(wrong_ids, dom): rules.append(rule) msg = gettext( f'ir.msg_{mode}_rule_error', ids=ids, model=model, rules='\n'.join(r.name for r in rules)) else: msg = gettext( f'ir.msg_{mode}_error', ids=ids, model=model) ctx_msg = [] lang = Lang.get() if cls._history and transaction.context.get('_datetime'): ctx_msg.append(gettext('ir.msg_context_datetime', datetime=lang.strftime( transaction.context['_datetime']))) if domain and User and Group: groups = Group.browse(User.get_groups()) ctx_msg.append(gettext('ir.msg_context_groups', groups=', '.join(g.rec_name for g in groups))) raise AccessError(msg, '\n'.join(ctx_msg)) @classmethod def __search_query(cls, domain, count, query, order): pool = Pool() Rule = pool.get('ir.rule') rule_domain = Rule.domain_get(cls.__name__, mode='read') joined_domains = None if domain and domain[0] == 'OR': local_domains, subquery_domains = split_subquery_domain(domain) if subquery_domains: joined_domains = subquery_domains if local_domains: local_domains.insert(0, 'OR') 
joined_domains.append(local_domains) def get_local_columns(order_exprs): local_columns = [] for order_expr in order_exprs: if (isinstance(order_expr, Column) and isinstance(order_expr._from, Table) and order_expr._from._name == cls._table): local_columns.append(order_expr._name) else: raise NotImplementedError return local_columns # The UNION optimization needs the columns used to order the query extra_columns = set() if order and joined_domains: tables = { None: (cls.__table__(), None), } for oexpr, otype in order: fname = oexpr.partition('.')[0] field = cls._fields[fname] field_orders = field.convert_order(oexpr, tables, cls) try: order_columns = get_local_columns(field_orders) extra_columns.update(order_columns) except NotImplementedError: joined_domains = None break # In case the search uses subqueries it's more efficient to use a UNION # of queries than using clauses with some JOIN because databases can # used indexes if joined_domains is not None: union_tables = [] for sub_domain in joined_domains: sub_domain = [sub_domain] # it may be a clause tables, expression = cls.search_domain(sub_domain) if rule_domain: tables, domain_exp = cls.search_domain( rule_domain, active_test=False, tables=tables) expression &= domain_exp main_table, _ = tables[None] table = convert_from(None, tables) columns = cls.__searched_columns(main_table, eager=not count and not query, extra_columns=extra_columns) union_tables.append(table.select( *columns, where=expression)) expression = None tables = { None: (Union(*union_tables, all_=False), None), } else: tables, expression = cls.search_domain(domain) if rule_domain: tables, domain_exp = cls.search_domain( rule_domain, active_test=False, tables=tables) expression &= domain_exp return tables, expression @classmethod def __searched_columns( cls, table, *, eager=False, history=False, extra_columns=None): if extra_columns is None: extra_columns = [] else: extra_columns = sorted(extra_columns - {'id', '__id', '_datetime'}) columns = 
[table.id.as_('id')] if (cls._history and Transaction().context.get('_datetime') and (eager or history)): columns.append( Coalesce(table.write_date, table.create_date).as_('_datetime')) columns.append(Column(table, '__id').as_('__id')) for column_name in extra_columns: field = cls._fields[column_name] sql_column = field.sql_column(table).as_(column_name) columns.append(sql_column) if eager: columns += [f.sql_column(table).as_(n) for n, f in sorted(cls._fields.items()) if not hasattr(f, 'get') and n not in extra_columns and n != 'id' and not getattr(f, 'translate', False) and f.loading == 'eager'] if not callable(cls.table_query): sql_type = fields.Char('timestamp').sql_type().base columns += [Extract('EPOCH', Coalesce(table.write_date, table.create_date) ).cast(sql_type).as_('_timestamp')] return columns @classmethod def __search_order(cls, order, tables): order_by = [] order_types = { 'DESC': Desc, 'ASC': Asc, } null_ordering_types = { 'NULLS FIRST': NullsFirst, 'NULLS LAST': NullsLast, None: lambda _: _ } for oexpr, otype in order: fname, _, extra_expr = oexpr.partition('.') field = cls._fields[fname] if not otype: otype, null_ordering = 'ASC', None else: otype = otype.upper() try: otype, null_ordering = otype.split(' ', 1) except ValueError: null_ordering = None Order = order_types[otype] NullOrdering = null_ordering_types[null_ordering] forder = field.convert_order(oexpr, tables, cls) order_by.extend((NullOrdering(Order(o)) for o in forder)) return order_by @classmethod def search(cls, domain, offset=0, limit=None, order=None, count=False, query=False): transaction = Transaction() cursor = transaction.connection.cursor() super(ModelSQL, cls).search( domain, offset=offset, limit=limit, order=order, count=count) if order is None or order is False: order = cls._order tables, expression = cls.__search_query(domain, count, query, order) main_table, _ = tables[None] if count: table = convert_from(None, tables) if (limit is not None and limit < cls.estimated_count()) 
or offset: select = table.select( Literal(1), where=expression, limit=limit, offset=offset ).select(Count(Literal('*'))) else: select = table.select(Count(Literal('*')), where=expression) if query: return select else: cursor.execute(*select) return cursor.fetchone()[0] order_by = cls.__search_order(order, tables) # compute it here because __search_order might modify tables table = convert_from(None, tables) if query: columns = [main_table.id.as_('id')] else: columns = cls.__searched_columns(main_table, eager=True) if backend.name == 'sqlite': for column in columns: field = cls._fields.get(column.output_name) if field: column.output_name += ' [%s]' % field.sql_type().base select = table.select( *columns, where=expression, limit=limit, offset=offset, order_by=order_by) if query: return select cursor.execute(*select) rows = list(cursor_dict(cursor, transaction.database.IN_MAX)) cache = transaction.get_cache() delete_records = transaction.delete_records[cls.__name__] # Can not cache the history value if we are not sure to have fetch all # the rows for each records if (not (cls._history and transaction.context.get('_datetime')) or len(rows) < transaction.database.IN_MAX): keys = None for data in islice(rows, 0, cache.size_limit): if data['id'] in delete_records: continue if keys is None: keys = list(data.keys()) for k in keys[:]: if k in ('_timestamp', '_datetime', '__id'): continue field = cls._fields[k] if not getattr(field, 'datetime_field', None): keys.remove(k) continue for k in keys: del data[k] cache[cls.__name__][data['id']]._update(data) return cls.browse([x['id'] for x in rows]) @classmethod def search_domain(cls, domain, active_test=None, tables=None): ''' Return SQL tables and expression Set active_test to add it. 
''' transaction = Transaction() if active_test is None: active_test = transaction.active_records domain = cls._search_domain_active(domain, active_test=active_test) if tables is None: tables = {} if None not in tables: if cls._history and transaction.context.get('_datetime'): tables[None] = (cls.__table_history__(), None) else: tables[None] = (cls.__table__(), None) def convert(domain): if is_leaf(domain): fname = domain[0].split('.', 1)[0] field = cls._fields[fname] expression = field.convert_domain(domain, tables, cls) if not isinstance(expression, (Operator, Expression)): return convert(expression) return expression elif not domain or list(domain) in (['OR'], ['AND']): return Literal(True) elif domain[0] == 'OR': return Or((convert(d) for d in domain[1:])) else: return And((convert(d) for d in ( domain[1:] if domain[0] == 'AND' else domain))) with without_check_access(): expression = convert(domain) if cls._history and transaction.context.get('_datetime'): database = Transaction().database if database.has_window_functions(): table, _ = tables[None] history = cls.__table_history__() last_change = Coalesce(history.write_date, history.create_date) # prefilter the history records for a bit of a speedup selected_h_ids = convert_from(None, tables).select( table.id, where=expression) most_recent = history.select( history.create_date, Column(history, '__id'), RowNumber( window=Window([history.id], order_by=[ last_change.desc, Column(history, '__id').desc])).as_('rank'), where=((last_change <= transaction.context['_datetime']) & history.id.in_(selected_h_ids))) # Filter again as the latest records from most_recent might not # match the expression expression &= Exists(most_recent.select( Literal(1), where=( (Column(table, '__id') == Column(most_recent, '__id')) & (most_recent.create_date != Null) & (most_recent.rank == 1)))) else: table, _ = tables[None] history_1 = cls.__table_history__() history_2 = cls.__table_history__() last_change = Coalesce( history_1.write_date, 
history_1.create_date) latest_change = history_1.select( history_1.id, Max(last_change).as_('date'), where=(last_change <= transaction.context['_datetime']), group_by=[history_1.id]) most_recent = history_2.join( latest_change, condition=( (history_2.id == latest_change.id) & (Coalesce( history_2.write_date, history_2.create_date) == latest_change.date)) ).select( Max(Column(history_2, '__id')).as_('h_id'), where=(history_2.create_date != Null), group_by=[history_2.id]) expression &= Exists(most_recent.select( Literal(1), where=(Column(table, '__id') == most_recent.h_id))) return tables, expression @classmethod def _rebuild_path(cls, field_name): "Rebuild path for the tree." cursor = Transaction().connection.cursor() field = cls._fields[field_name] table = cls.__table__() tree = With('id', 'path', recursive=True) tree.query = table.select( table.id, Concat(table.id, '/'), where=Column(table, field_name) == Null) tree.query |= (table .join(tree, condition=Column(table, field_name) == tree.id) .select(table.id, Concat(Concat(tree.path, table.id), '/'))) query = table.update( [Column(table, field.path)], [tree.path], from_=[tree], where=table.id == tree.id, with_=[tree]) cursor.execute(*query) @classmethod def _set_path(cls, field_names, list_ids): cursor = Transaction().connection.cursor() table = cls.__table__() parent = cls.__table__() for field_name, ids in zip(field_names, list_ids): field = cls._fields[field_name] parent_column = Column(table, field_name) path_column = Column(table, field.path) query = table.update( [path_column], [Concat(Concat(Coalesce( parent.select(parent.path, where=parent.id == parent_column), ''), table.id), '/')]) for sub_ids in grouped_slice(ids): query.where = reduce_ids(table.id, sub_ids) cursor.execute(*query) @classmethod def _update_path(cls, field_names, list_ids): transaction = Transaction() cursor = transaction.connection.cursor() update = transaction.connection.cursor() table = cls.__table__() parent = cls.__table__() def 
update_path(query, column, sub_ids): updated = set() query.where = reduce_ids(table.id, sub_ids) cursor.execute(*query) for old_path, new_path in cursor: if old_path == new_path: continue if any(old_path.startswith(p) for p in updated): return False update.execute(*table.update( [column], [Concat(new_path, Substring(table.path, len(old_path) + 1))], where=table.path.like(old_path + '%'))) updated.add(old_path) return True for field_name, ids in zip(field_names, list_ids): field = cls._fields[field_name] parent_column = Column(table, field_name) parent_path_column = Column(parent, field.path) path_column = Column(table, field.path) query = (table .join(parent, 'LEFT', condition=parent_column == parent.id) .select(path_column, Concat(Concat( Coalesce(parent_path_column, ''), table.id), '/'))) for sub_ids in grouped_slice(ids): sub_ids = list(sub_ids) while not update_path(query, path_column, sub_ids): pass @classmethod def _update_mptt(cls, field_names, list_ids, values=None): for field_name, ids in zip(field_names, list_ids): field = cls._fields[field_name] if (values is not None and (field.left in values or field.right in values)): raise Exception('ValidateError', 'You can not update fields: "%s", "%s"' % (field.left, field.right)) if len(ids) < max(cls.estimated_count() / 4, 4): for id_ in ids: cls._update_tree(id_, field_name, field.left, field.right) else: cls._rebuild_tree(field_name, None, 0) @classmethod def _rebuild_tree(cls, parent, parent_id, left): ''' Rebuild left, right value for the tree. 
''' cursor = Transaction().connection.cursor() table = cls.__table__() right = left + 1 cursor.execute(*table.select(table.id, where=Column(table, parent) == parent_id)) for child_id, in cursor: right = cls._rebuild_tree(parent, child_id, right) field = cls._fields[parent] if parent_id: cursor.execute(*table.update( [Column(table, field.left), Column(table, field.right)], [left, right], where=table.id == parent_id)) return right + 1 @classmethod def _update_tree(cls, record_id, field_name, left, right): ''' Update left, right values for the tree. Remarks: - the value (right - left - 1) / 2 will not give the number of children node ''' cursor = Transaction().connection.cursor() table = cls.__table__() left = Column(table, left) right = Column(table, right) field = Column(table, field_name) cursor.execute(*table.select(left, right, field, where=table.id == record_id)) fetchone = cursor.fetchone() if not fetchone: return old_left, old_right, parent_id = fetchone if old_left == old_right == 0: cursor.execute(*table.select(Max(right), where=field == Null)) old_left, = cursor.fetchone() old_left += 1 old_right = old_left + 1 cursor.execute(*table.update([left, right], [old_left, old_right], where=table.id == record_id)) size = old_right - old_left + 1 parent_right = 1 if parent_id: cursor.execute(*table.select(right, where=table.id == parent_id)) parent_right = cursor.fetchone()[0] else: cursor.execute(*table.select(Max(right), where=field == Null)) fetchone = cursor.fetchone() if fetchone: parent_right = fetchone[0] + 1 cursor.execute(*table.update([left], [left + size], where=left >= parent_right)) cursor.execute(*table.update([right], [right + size], where=right >= parent_right)) if old_left < parent_right: left_delta = parent_right - old_left right_delta = parent_right - old_left left_cond = old_left right_cond = old_right else: left_delta = parent_right - old_left - size right_delta = parent_right - old_left - size left_cond = old_left + size right_cond = old_right 
+ size cursor.execute(*table.update([left, right], [left + left_delta, right + right_delta], where=(left >= left_cond) & (right <= right_cond))) @classmethod def validate(cls, records): super(ModelSQL, cls).validate(records) transaction = Transaction() database = transaction.database connection = transaction.connection has_constraint = database.has_constraint lock = database.lock cursor = transaction.connection.cursor() # Works only for a single transaction ids = list(map(int, records)) for _, sql, error in cls._sql_constraints: if has_constraint(sql): continue table = sql.table if isinstance(sql, (Unique, Exclude)): lock(connection, cls._table) columns = list(sql.columns) columns.insert(0, table.id) in_max = transaction.database.IN_MAX // (len(columns) + 1) for sub_ids in grouped_slice(ids, in_max): where = reduce_ids(table.id, sub_ids) if isinstance(sql, Exclude) and sql.where: where &= sql.where cursor.execute(*table.select(*columns, where=where)) where = Literal(False) for row in cursor: clause = table.id != row[0] for column, operator, value in zip( sql.columns, sql.operators, row[1:]): if value is None: # NULL is always unique clause &= Literal(False) clause &= operator(column, value) where |= clause if isinstance(sql, Exclude) and sql.where: where &= sql.where cursor.execute( *table.select(table.id, where=where, limit=1)) if cursor.fetchone(): raise SQLConstraintError(gettext(error)) elif isinstance(sql, Check): for sub_ids in grouped_slice(ids): red_sql = reduce_ids(table.id, sub_ids) cursor.execute(*table.select(table.id, where=~sql.expression & red_sql, limit=1)) if cursor.fetchone(): raise SQLConstraintError(gettext(error)) @dualmethod def lock(cls, records=None): transaction = Transaction() database = transaction.database connection = transaction.connection table = cls.__table__() if records is not None and database.has_select_for(): for sub_records in grouped_slice(records): where = reduce_ids(table.id, sub_records) query = table.select( Literal(1), 
where=where, for_=For('UPDATE', nowait=True)) with connection.cursor() as cursor: cursor.execute(*query) else: transaction.lock_table(cls._table) def convert_from(table, tables, type_='LEFT'): # Don't nested joins as SQLite doesn't support right, condition = tables[None] if table: table = table.join(right, type_, condition) else: table = right for k, sub_tables in tables.items(): if k is None: continue table = convert_from(table, sub_tables, type_=type_) return table def split_subquery_domain(domain): """ Split a domain in two parts: - the first one contains all the sub-domains with only local fields - the second one contains all the sub-domains using a related field The main operator of the domain will be stripped from the results. """ local_domains, subquery_domains = [], [] for sub_domain in domain: if is_leaf(sub_domain): if '.' in sub_domain[0]: subquery_domains.append(sub_domain) else: local_domains.append(sub_domain) elif (not sub_domain or list(sub_domain) in [['OR'], ['AND']] or sub_domain in ['OR', 'AND']): continue else: sub_ldomains, sub_sqdomains = split_subquery_domain(sub_domain) if sub_sqdomains: subquery_domains.append(sub_domain) else: local_domains.append(sub_domain) return local_domains, subquery_domains