Files
tradon/model/modelsql.py
2025-12-26 13:11:43 +00:00

2246 lines
89 KiB
Python
Executable File

# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import datetime
from collections import OrderedDict, defaultdict
from functools import wraps
from itertools import chain, groupby, islice, product, repeat
from sql import (
Asc, Column, Desc, Expression, For, Literal, Null, NullsFirst, NullsLast,
Table, Union, Window, With)
from sql.aggregate import Count, Max
from sql.conditionals import Coalesce
from sql.functions import CurrentTimestamp, Extract, RowNumber, Substring
from sql.operators import And, Concat, Equal, Exists, Operator, Or
from trytond import backend
from trytond.cache import freeze
from trytond.config import config
from trytond.exceptions import ConcurrencyException
from trytond.i18n import gettext
from trytond.pool import Pool
from trytond.pyson import PYSONDecoder, PYSONEncoder
from trytond.rpc import RPC
from trytond.tools import cursor_dict, grouped_slice, reduce_ids
from trytond.transaction import (
Transaction, inactive_records, record_cache_size, without_check_access)
from . import fields
from .descriptors import dualmethod
from .modelstorage import (
AccessError, ModelStorage, RequiredValidationError, SizeValidationError,
ValidationError, is_leaf)
from .modelview import ModelView
class ForeignKeyError(ValidationError):
    "Raised when a Many2One value points to a missing target record."
    pass
class SQLConstraintError(ValidationError):
    "Raised when a declared SQL constraint of the model is violated."
    pass
class Constraint(object):
    """Abstract base class for SQL table constraints.

    A constraint is bound to a table; concrete subclasses render
    themselves as an SQL fragment via ``__str__`` and expose their
    query parameters through ``params``.
    """
    __slots__ = ('_table',)

    def __init__(self, table):
        assert isinstance(table, Table)
        self._table = table

    @property
    def table(self):
        "The table the constraint applies to."
        return self._table

    @property
    def params(self):
        raise NotImplementedError

    def __str__(self):
        raise NotImplementedError
class Check(Constraint):
    "A CHECK constraint wrapping a boolean SQL expression."
    __slots__ = ('_expression',)

    def __init__(self, table, expression):
        assert isinstance(expression, Expression)
        super().__init__(table)
        self._expression = expression

    @property
    def expression(self):
        "The boolean SQL expression being checked."
        return self._expression

    @property
    def params(self):
        return self.expression.params

    def __str__(self):
        return 'CHECK(%s)' % self.expression
class Unique(Constraint):
    "A UNIQUE constraint over one or more columns."
    __slots__ = ('_columns',)

    def __init__(self, table, *columns):
        super().__init__(table)
        for column in columns:
            assert isinstance(column, Column)
        self._columns = tuple(columns)

    @property
    def columns(self):
        return self._columns

    @property
    def operators(self):
        # A unique constraint is equivalent to an exclusion on equality
        # for every column.
        return (Equal,) * len(self._columns)

    def __str__(self):
        return 'UNIQUE(%s)' % ', '.join(map(str, self._columns))

    @property
    def params(self):
        params = []
        for column in self._columns:
            params.extend(column.params)
        return tuple(params)
class Exclude(Constraint):
    """An EXCLUDE constraint.

    *excludes* is a sequence of ``(expression, operator)`` pairs; the
    optional ``where`` keyword argument restricts the constraint to the
    rows matching the predicate.
    """
    __slots__ = ('_excludes', '_where')

    def __init__(self, table, *excludes, **kwargs):
        super(Exclude, self).__init__(table)
        assert all(isinstance(c, Expression) and issubclass(o, Operator)
            for c, o in excludes), excludes
        self._excludes = tuple(excludes)
        where = kwargs.get('where')
        if where is not None:
            assert isinstance(where, Expression)
        # Always assign the slot: with __slots__, leaving it unset when no
        # 'where' was given would make any later access to self.where raise
        # AttributeError instead of returning None.
        self._where = where

    @property
    def excludes(self):
        return self._excludes

    @property
    def columns(self):
        return tuple(c for c, _ in self._excludes)

    @property
    def operators(self):
        return tuple(o for _, o in self._excludes)

    @property
    def where(self):
        "Optional predicate restricting the constraint, or None."
        return self._where

    def __str__(self):
        exclude = ', '.join('%s WITH %s' % (column, operator._operator)
            for column, operator in self.excludes)
        where = ''
        if self.where:
            where = ' WHERE ' + str(self.where)
        return 'EXCLUDE (%s)' % exclude + where

    @property
    def params(self):
        p = []
        for column, operator in self._excludes:
            p.extend(column.params)
        if self.where:
            p.extend(self.where.params)
        return tuple(p)
class Index:
    """A database index declaration over table expressions.

    Instances are value objects: hashing, equality and ordering compare the
    rendered SQL of the table, the expressions and the options, so indexes
    can be deduplicated in sets and subset indexes detected.
    """
    __slots__ = ('table', 'expressions', 'options')

    def __init__(self, table, *expressions, **options):
        self.table = table
        assert all(
            isinstance(e, Expression) and isinstance(u, self.Usage)
            for e, u in expressions)
        self.expressions = expressions
        self.options = options

    def __hash__(self):
        table_def = (
            self.table._name, self.table._schema, self.table._database)
        expressions = (
            (str(e), e.params, hash(u)) for e, u in self.expressions)
        return hash((table_def, *expressions))

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            # Return the NotImplemented singleton (not the exception class)
            # so Python falls back to the reflected comparison instead of
            # treating the returned class as a truthy result.
            return NotImplemented
        return (
            str(self.table) == str(other.table)
            and len(self.expressions) == len(other.expressions)
            and all((str(c), u) == (str(oc), ou)
                for (c, u), (oc, ou) in zip(
                    self.expressions, other.expressions))
            and self._options_cmp == other._options_cmp)

    @property
    def _options_cmp(self):
        # Render Expression-valued options as SQL strings for comparison.
        def _format(value):
            return str(value) if isinstance(value, Expression) else value
        return {k: _format(v) for k, v in self.options.items()}

    def __lt__(self, other):
        # True when self is a strict prefix of other (same table and
        # options, fewer expressions, all matching pairwise).
        if not isinstance(other, self.__class__):
            return NotImplemented
        if (self.table != other.table
                or self._options_cmp != other._options_cmp):
            return False
        if self == other:
            return False
        if len(self.expressions) >= len(other.expressions):
            return False
        for (c, u), (oc, ou) in zip(self.expressions, other.expressions):
            if (str(c), u) != (str(oc), ou):
                return False
        return True

    def __le__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self < other or self == other

    class Unaccent(Expression):
        "Unaccent function if database support for index"
        __slots__ = ('_expression',)

        def __init__(self, expression):
            self._expression = expression

        @property
        def expression(self):
            expression = self._expression
            database = Transaction().database
            if database.has_unaccent_indexable():
                expression = database.unaccent(expression)
            return expression

        def __str__(self):
            return str(self.expression)

        @property
        def params(self):
            return self.expression.params

    class Usage:
        "Mark how an indexed expression is used (equality, range, …)."
        __slots__ = ('options',)

        def __init__(self, **options):
            self.options = options

        def __hash__(self):
            return hash((self.__class__.__name__, *self.options.items()))

        def __eq__(self, other):
            return (self.__class__ == other.__class__
                and self.options == other.options)

    class Equality(Usage):
        __slots__ = ()

    class Range(Usage):
        __slots__ = ()

    class Similarity(Usage):
        __slots__ = ()
def no_table_query(func):
    """Decorator forbidding *func* on models defined by a table_query.

    Such models are read-only views: the wrapped classmethod raises
    NotImplementedError when the class defines a callable table_query.
    """
    @wraps(func)
    def decorated(cls, *args, **kwargs):
        if callable(cls.table_query):
            raise NotImplementedError("On table_query")
        return func(cls, *args, **kwargs)
    return decorated
class ModelSQL(ModelStorage):
    """
    Define a model with storage in database.
    """
    __slots__ = ()
    _table = None  # The name of the table in database
    _order = None  # Default order used by search, reset in __setup__
    _order_name = None  # Use to force order field when sorting on Many2One
    _history = False  # When True, a companion '<table>__history' table is kept
    table_query = None  # Optional callable returning a query used as table
    @classmethod
    def __setup__(cls):
        # Set up table name, SQL constraints and indexes of the model.
        # The table name may be overridden per model in the configuration.
        cls._table = config.get('table', cls.__name__, default=cls._table)
        if not cls._table:
            cls._table = cls.__name__.replace('.', '_')
        # The '__history' suffix is reserved for the history companion table.
        assert cls._table[-9:] != '__history', \
            'Model _table %s cannot end with "__history"' % cls._table
        super(ModelSQL, cls).__setup__()
        cls._sql_constraints = []
        cls._sql_indexes = set()
        cls._history_sql_indexes = set()
        if not callable(cls.table_query):
            table = cls.__table__()
            cls._sql_constraints.append(
                ('id_positive', Check(table, table.id >= 0),
                    'ir.msg_id_positive'))
            rec_name_field = getattr(cls, cls._rec_name, None)
            # Index the record name for similarity search, unless it is
            # filled by a setter (no stored column to index).
            if (isinstance(rec_name_field, fields.Field)
                    and not hasattr(rec_name_field, 'set')):
                column = Column(table, cls._rec_name)
                if getattr(rec_name_field, 'search_unaccented', False):
                    column = Index.Unaccent(column)
                cls._sql_indexes.add(
                    Index(table, (column, Index.Similarity())))
        cls._order = [('id', None)]
        if issubclass(cls, ModelView):
            cls.__rpc__.update({
                    'history_revisions': RPC(),
                    })
        if cls._history:
            history_table = cls.__table_history__()
            # Index history lookups by record id and by revision datetime
            # (most recent first, covering '__id' and 'id').
            cls._history_sql_indexes.update({
                    Index(
                        history_table,
                        (history_table.id, Index.Equality())),
                    Index(
                        history_table,
                        (Coalesce(
                                history_table.write_date,
                                history_table.create_date).desc,
                            Index.Range()),
                        include=[
                            Column(history_table, '__id'),
                            history_table.id]),
                    })
    @classmethod
    def __post_setup__(cls):
        super().__post_setup__()
        # Define Range index to optimise with reduce_ids
        # For each relational field, declare indexes on the reverse columns
        # of the target (One2Many) or relation (Many2Many) model.
        for field in cls._fields.values():
            field_names = set()
            if isinstance(field, fields.One2Many):
                Target = field.get_target()
                if field.field:
                    field_names.add(field.field)
            elif isinstance(field, fields.Many2Many):
                Target = field.get_relation()
                if field.origin:
                    field_names.add(field.origin)
                if field.target:
                    field_names.add(field.target)
            else:
                continue
            field_names.discard('id')
            for field_name in field_names:
                target_field = getattr(Target, field_name)
                # Only stored columns of SQL-backed models can be indexed.
                if (issubclass(Target, ModelSQL)
                        and not callable(Target.table_query)
                        and not hasattr(target_field, 'set')):
                    target = Target.__table__()
                    column = Column(target, field_name)
                    # Partial index skipping NULLs when the column is
                    # optional on another model.
                    if not target_field.required and Target != cls:
                        where = column != Null
                    else:
                        where = None
                    if target_field._type == 'reference':
                        # Reference columns are 'model,id' strings: index
                        # both the prefix and the extracted id part.
                        Target._sql_indexes.update({
                                Index(
                                    target,
                                    (column, Index.Equality()),
                                    where=where),
                                Index(
                                    target,
                                    (column, Index.Similarity(begin=True)),
                                    (target_field.sql_id(column, Target),
                                        Index.Range()),
                                    where=where),
                                })
                    else:
                        Target._sql_indexes.add(
                            Index(
                                target,
                                (column, Index.Range()),
                                where=where))
@classmethod
def __table__(cls):
if callable(cls.table_query):
return cls.table_query()
else:
return Table(cls._table)
@classmethod
def __table_history__(cls):
if not cls._history:
raise ValueError('No history table')
return Table(cls._table + '__history')
    @classmethod
    def __table_handler__(cls, module_name=None, history=False):
        # Return the backend handler used to create/alter the table.
        # module_name is kept in the signature for API compatibility but
        # is not forwarded to the handler.
        return backend.TableHandler(cls, history=history)
    @classmethod
    def __register__(cls, module_name):
        # Create or update the database table (and its history table) to
        # match the declared fields, constraints and foreign keys, and run
        # data migrations.
        cursor = Transaction().connection.cursor()
        super(ModelSQL, cls).__register__(module_name)
        if callable(cls.table_query):
            return
        pool = Pool()
        # Initiate after the callable test to prevent calling table_query which
        # may rely on other model being registered
        sql_table = cls.__table__()
        # create/update table in the database
        table = cls.__table_handler__(module_name)
        if cls._history:
            history_table = cls.__table_handler__(module_name, history=True)
        for field_name, field in cls._fields.items():
            if field_name == 'id':
                continue
            sql_type = field.sql_type()
            if not sql_type:
                continue
            if field_name in cls._defaults:
                # NOTE(review): 'default' closes over the loop variables
                # field_name/field; this assumes add_column uses it before
                # the loop advances — confirm in TableHandler.
                def default():
                    default_ = cls._clean_defaults({
                            field_name: cls._defaults[field_name](),
                            })[field_name]
                    return field.sql_format(default_)
            else:
                default = None
            table.add_column(field_name, field._sql_type, default=default)
            if cls._history:
                history_table.add_column(field_name, field._sql_type)
            if isinstance(field, (fields.Integer, fields.Float)):
                # migration from tryton 2.2
                table.db_default(field_name, None)
            if isinstance(field, (fields.Boolean)):
                table.db_default(field_name, False)
            if isinstance(field, fields.Many2One):
                if field.model_name in ('res.user', 'res.group'):
                    # XXX need to merge ir and res
                    ref = field.model_name.replace('.', '_')
                else:
                    ref_model = pool.get(field.model_name)
                    if (issubclass(ref_model, ModelSQL)
                            and not callable(ref_model.table_query)):
                        ref = ref_model._table
                        # Create foreign key table if missing
                        if not backend.TableHandler.table_exist(ref):
                            backend.TableHandler(ref_model)
                    else:
                        ref = None
                if field_name in ['create_uid', 'write_uid']:
                    # migration from 3.6
                    table.drop_fk(field_name)
                elif ref:
                    table.add_fk(field_name, ref, on_delete=field.ondelete)
            required = field.required
            # Do not set 'NOT NULL' for Binary field as the database column
            # will be left empty if stored in the filestore or filled later by
            # the set method.
            if isinstance(field, fields.Binary):
                required = False
            table.not_null_action(
                field_name, action=required and 'add' or 'remove')
        # Rebuild path/MPTT data of self-referencing Many2One fields when
        # any row still holds the default (unset) value.
        for field_name, field in cls._fields.items():
            if (isinstance(field, fields.Many2One)
                    and field.model_name == cls.__name__):
                if field.path:
                    default_path = cls._defaults.get(
                        field.path, lambda: None)()
                    cursor.execute(*sql_table.select(sql_table.id,
                            where=(
                                Column(sql_table, field.path) == default_path)
                            | (Column(sql_table, field.path) == Null),
                            limit=1))
                    if cursor.fetchone():
                        cls._rebuild_path(field_name)
                if field.left and field.right:
                    left_default = cls._defaults.get(
                        field.left, lambda: None)()
                    right_default = cls._defaults.get(
                        field.right, lambda: None)()
                    cursor.execute(*sql_table.select(sql_table.id,
                            where=(
                                Column(sql_table, field.left) == left_default)
                            | (Column(sql_table, field.left) == Null)
                            | (Column(sql_table, field.right) == right_default)
                            | (Column(sql_table, field.right) == Null),
                            limit=1))
                    if cursor.fetchone():
                        cls._rebuild_tree(field_name, None, 0)
        for ident, constraint, _ in cls._sql_constraints:
            # Constraint identifiers must not look like index names.
            assert (
                not ident.startswith('idx_') and not ident.endswith('_index'))
            table.add_constraint(ident, constraint)
        if cls._history:
            cls._update_history_table()
            h_table = cls.__table_history__()
            cursor.execute(*sql_table.select(sql_table.id, limit=1))
            if cursor.fetchone():
                cursor.execute(
                    *h_table.select(h_table.id, limit=1))
                if not cursor.fetchone():
                    # Seed an empty history table with a snapshot of the
                    # current rows, clearing write_date on the copies.
                    columns = [n for n, f in cls._fields.items()
                        if f.sql_type()]
                    cursor.execute(*h_table.insert(
                            [Column(h_table, c) for c in columns],
                            sql_table.select(*(Column(sql_table, c)
                                    for c in columns))))
                    cursor.execute(*h_table.update(
                            [h_table.write_date], [None]))
@classmethod
def _update_sql_indexes(cls, concurrently=False):
def no_subset(index):
for j in cls._sql_indexes:
if j != index and index < j:
return False
return True
if not callable(cls.table_query):
table_h = cls.__table_handler__()
indexes = filter(no_subset, cls._sql_indexes)
table_h.set_indexes(indexes, concurrently=concurrently)
if cls._history:
history_th = cls.__table_handler__(history=True)
indexes = filter(no_subset, cls._history_sql_indexes)
history_th.set_indexes(indexes, concurrently=concurrently)
@classmethod
def _update_history_table(cls):
if cls._history:
history_table = cls.__table_handler__(history=True)
for field_name, field in cls._fields.items():
if not field.sql_type():
continue
history_table.add_column(field_name, field._sql_type)
    @classmethod
    @without_check_access
    def __raise_integrity_error(
            cls, exception, values, field_names=None, transaction=None):
        # Translate a database integrity error into a user-facing
        # validation error: missing required value, violated SQL
        # constraint, or dangling foreign key.
        pool = Pool()
        if field_names is None:
            field_names = list(cls._fields.keys())
        if transaction is None:
            transaction = Transaction()
        for field_name in field_names:
            if field_name not in cls._fields:
                continue
            field = cls._fields[field_name]
            # Check required fields
            if (field.required
                    and field.sql_type()
                    and field_name not in ('create_uid', 'create_date')):
                if values.get(field_name) is None:
                    raise RequiredValidationError(
                        gettext('ir.msg_required_validation',
                            **cls.__names__(field_name)))
        # Match the constraint name embedded in the backend error message.
        for name, _, error in cls._sql_constraints:
            if backend.TableHandler.convert_name(name) in str(exception):
                raise SQLConstraintError(gettext(error))
        # Check foreign key in last because this can raise false positive
        # if the target is created during the same transaction.
        for field_name in field_names:
            if field_name not in cls._fields:
                continue
            field = cls._fields[field_name]
            if isinstance(field, fields.Many2One) and values.get(field_name):
                Model = pool.get(field.model_name)
                create_records = transaction.create_records[field.model_name]
                delete_records = transaction.delete_records[field.model_name]
                target_records = Model.search([
                        ('id', '=', field.sql_format(values[field_name])),
                        ], order=[])
                # The target must exist (or be created in this transaction)
                # and not be deleted in this transaction.
                if not ((
                            target_records
                            or (values[field_name] in create_records))
                        and (values[field_name] not in delete_records)):
                    error_args = cls.__names__(field_name)
                    error_args['value'] = values[field_name]
                    raise ForeignKeyError(
                        gettext('ir.msg_foreign_model_missing',
                            **error_args))
@classmethod
@without_check_access
def __raise_data_error(
cls, exception, values, field_names=None, transaction=None):
if field_names is None:
field_names = list(cls._fields.keys())
if transaction is None:
transaction = Transaction()
for field_name in field_names:
if field_name not in cls._fields:
continue
field = cls._fields[field_name]
# Check field size
if (hasattr(field, 'size')
and isinstance(field.size, int)
and field.sql_type()):
value = values.get(field_name) or ''
size = len(value)
if size > field.size:
error_args = cls.__names__(field_name)
error_args['value'] = value
error_args['size'] = size
error_args['max_size'] = field.size
raise SizeValidationError(
gettext('ir.msg_size_validation', **error_args))
    @classmethod
    def history_revisions(cls, ids):
        # Return the list of (timestamp, id, user name) revisions for ids,
        # most recent first.
        pool = Pool()
        ModelAccess = pool.get('ir.model.access')
        User = pool.get('res.user')
        cursor = Transaction().connection.cursor()
        ModelAccess.check(cls.__name__, 'read')
        table = cls.__table_history__()
        user = User.__table__()
        revisions = []
        for sub_ids in grouped_slice(ids):
            where = reduce_ids(table.id, sub_ids)
            cursor.execute(*table.join(user, 'LEFT',
                    Coalesce(table.write_uid, table.create_uid) == user.id)
                .select(
                    Coalesce(table.write_date, table.create_date),
                    table.id,
                    user.name,
                    where=where))
            revisions.append(cursor.fetchall())
        revisions = list(chain(*revisions))
        revisions.sort(reverse=True)
        # SQLite uses char for COALESCE
        if revisions and isinstance(revisions[0][0], str):
            strptime = datetime.datetime.strptime
            format_ = '%Y-%m-%d %H:%M:%S.%f'
            revisions = [(strptime(timestamp, format_), id_, name)
                for timestamp, id_, name in revisions]
        return revisions
    @classmethod
    def _insert_history(cls, ids, deleted=False):
        # Snapshot the rows of ids into the history table. When deleted is
        # True only id/write_uid/write_date are recorded (as a tombstone).
        transaction = Transaction()
        cursor = transaction.connection.cursor()
        if not cls._history:
            return
        user = transaction.user
        table = cls.__table__()
        history = cls.__table_history__()
        columns = []
        hcolumns = []
        if not deleted:
            fields = cls._fields
        else:
            fields = {
                'id': cls.id,
                'write_uid': cls.write_uid,
                'write_date': cls.write_date,
                }
        # Sort field names so columns and hcolumns stay aligned.
        for fname, field in sorted(fields.items()):
            if not field.sql_type():
                continue
            columns.append(Column(table, fname))
            hcolumns.append(Column(history, fname))
        for sub_ids in grouped_slice(ids):
            if not deleted:
                # Copy live rows straight from the model table.
                where = reduce_ids(table.id, sub_ids)
                cursor.execute(*history.insert(hcolumns,
                        table.select(*columns, where=where)))
            else:
                if transaction.database.has_multirow_insert():
                    cursor.execute(*history.insert(hcolumns,
                            [[id_, CurrentTimestamp(), user]
                                for id_ in sub_ids]))
                else:
                    for id_ in sub_ids:
                        cursor.execute(*history.insert(hcolumns,
                                [[id_, CurrentTimestamp(), user]]))
    @classmethod
    def _restore_history(cls, ids, datetime, _before=False):
        # Restore records to their latest history revision at (or, with
        # _before, strictly before) the given datetime. Records whose
        # matching revision is missing or a deletion tombstone are deleted.
        if not cls._history:
            return
        transaction = Transaction()
        cursor = transaction.connection.cursor()
        table = cls.__table__()
        history = cls.__table_history__()
        # Invalidate the transaction caches for the restored ids.
        transaction.counter += 1
        for cache in transaction.cache.values():
            if cls.__name__ in cache:
                cache_cls = cache[cls.__name__]
                for id_ in ids:
                    cache_cls.pop(id_, None)
        columns = []
        hcolumns = []
        fnames = sorted(n for n, f in cls._fields.items()
            if f.sql_type())
        for fname in fnames:
            columns.append(Column(table, fname))
            # Stamp the restoration with the current user and time instead
            # of the historical values.
            if fname == 'write_uid':
                hcolumns.append(Literal(transaction.user))
            elif fname == 'write_date':
                hcolumns.append(CurrentTimestamp())
            else:
                hcolumns.append(Column(history, fname))

        def is_deleted(values):
            # A tombstone revision has every data column empty.
            return all(not v for n, v in zip(fnames, values)
                if n not in ['id', 'write_uid', 'write_date'])

        to_delete = []
        to_update = []
        for id_ in ids:
            column_datetime = Coalesce(history.write_date, history.create_date)
            if not _before:
                hwhere = (column_datetime <= datetime)
            else:
                hwhere = (column_datetime < datetime)
            hwhere &= (history.id == id_)
            # Latest matching revision first ('__id' breaks timestamp ties).
            horder = (column_datetime.desc, Column(history, '__id').desc)
            cursor.execute(*history.select(*hcolumns,
                    where=hwhere, order_by=horder, limit=1))
            values = cursor.fetchone()
            if not values or is_deleted(values):
                to_delete.append(id_)
            else:
                to_update.append(id_)
                values = list(values)
                cursor.execute(*table.update(columns, values,
                        where=table.id == id_))
                rowcount = cursor.rowcount
                # Some drivers do not report rowcount; re-check explicitly.
                if rowcount == -1 or rowcount is None:
                    cursor.execute(*table.select(table.id,
                            where=table.id == id_))
                    rowcount = len(cursor.fetchall())
                if rowcount < 1:
                    # The record no longer exists: re-insert it.
                    cursor.execute(*table.insert(columns, [values]))
        if to_delete:
            for sub_ids in grouped_slice(to_delete):
                where = reduce_ids(table.id, sub_ids)
                cursor.execute(*table.delete(where=where))
            cls._insert_history(to_delete, True)
        if to_update:
            cls._insert_history(to_update)
    @classmethod
    def restore_history(cls, ids, datetime):
        'Restore record ids from history at the date time'
        # Includes revisions made exactly at the given datetime.
        cls._restore_history(ids, datetime)
    @classmethod
    def restore_history_before(cls, ids, datetime):
        'Restore record ids from history before the date time'
        # Excludes revisions made exactly at the given datetime.
        cls._restore_history(ids, datetime, _before=True)
@classmethod
def __check_timestamp(cls, ids):
transaction = Transaction()
cursor = transaction.connection.cursor()
table = cls.__table__()
if not transaction.timestamp:
return
for sub_ids in grouped_slice(ids):
where = Or()
for id_ in sub_ids:
try:
timestamp = transaction.timestamp.pop(
'%s,%s' % (cls.__name__, id_))
except KeyError:
continue
if timestamp is None:
continue
sql_type = fields.Char('timestamp').sql_type().base
where.append((table.id == id_)
& (Extract('EPOCH',
Coalesce(table.write_date, table.create_date)
).cast(sql_type) != timestamp))
if where:
cursor.execute(*table.select(table.id, where=where, limit=1))
if cursor.fetchone():
raise ConcurrencyException(
'Records were modified in the meanwhile')
    @classmethod
    @no_table_query
    def create(cls, vlist):
        # Create records from the list of value dictionaries and return
        # the browse records. Values are completed with defaults, rows
        # sharing a column schema are batch-inserted, then translations,
        # setter fields, path/MPTT data, history and validation are applied.
        transaction = Transaction()
        cursor = transaction.connection.cursor()
        in_max = transaction.database.IN_MAX
        pool = Pool()
        Translation = pool.get('ir.translation')
        super(ModelSQL, cls).create(vlist)
        table = cls.__table__()
        modified_fields = set()
        defaults_cache = {}  # Store already computed default values
        missing_defaults = {}  # Store missing default values by schema
        new_ids = []
        # Copy so callers' dictionaries are not mutated below.
        vlist = [v.copy() for v in vlist]

        def db_insert(columns, vlist, column_names):
            # Insert the value rows and yield the new ids, preferring a
            # single multi-row INSERT with pre-allocated ids when the
            # backend supports it.
            if transaction.database.has_multirow_insert():
                vlist = (
                    s for s in grouped_slice(
                        vlist, in_max // (len(column_names) or 1)))
            else:
                vlist = ([v] for v in vlist)
            for values in vlist:
                values = list(values)
                cols = list(columns)
                try:
                    if len(values) > 1:
                        ids = transaction.database.nextid(
                            transaction.connection, cls._table,
                            count=len(values))
                        if ids is not None:
                            assert len(ids) == len(values)
                            cols.append(table.id)
                            for val, id in zip(values, ids):
                                val.append(id)
                            cursor.execute(*table.insert(cols, values))
                            yield from ids
                            continue
                    for i, val in enumerate(values):
                        if transaction.database.has_returning():
                            cursor.execute(*table.insert(
                                    cols, [val], [table.id]))
                            yield from (r[0] for r in cursor)
                        else:
                            id_new = transaction.database.nextid(
                                transaction.connection, cls._table)
                            if id_new:
                                if i == 0:
                                    cols.append(table.id)
                                val.append(id_new)
                                cursor.execute(*table.insert(cols, [val]))
                            else:
                                cursor.execute(*table.insert(cols, [val]))
                                id_new = transaction.database.lastid(cursor)
                            yield id_new
                except (
                        backend.DatabaseIntegrityError,
                        backend.DatabaseDataError
                        ) as exception:
                    # Re-raise the database error as a user validation
                    # error in a fresh transaction (the current one is
                    # aborted).
                    if isinstance(exception, backend.DatabaseIntegrityError):
                        raise_func = cls.__raise_integrity_error
                    elif isinstance(exception, backend.DatabaseDataError):
                        raise_func = cls.__raise_data_error
                    with transaction.new_transaction():
                        for value in values:
                            # Skip the leading create_uid/create_date values.
                            skip = len(['create_uid', 'create_date'])
                            recomposed = dict(zip(column_names, value[skip:]))
                            raise_func(
                                exception, recomposed, transaction=transaction)
                    raise

        to_insert = []
        previous_columns = [table.create_uid, table.create_date]
        previous_column_names = []
        for values in vlist:
            # Clean values
            for key in ('create_uid', 'create_date',
                    'write_uid', 'write_date', 'id'):
                if key in values:
                    del values[key]
            modified_fields |= values.keys()
            # Get default values
            values_schema = tuple(sorted(values))
            if values_schema not in missing_defaults:
                default = []
                missing_defaults[values_schema] = default_values = {}
                for fname, field in cls._fields.items():
                    if fname in values:
                        continue
                    if fname in [
                            'create_uid', 'create_date',
                            'write_uid', 'write_date', 'id']:
                        continue
                    if isinstance(field, fields.Function) and not field.setter:
                        continue
                    if fname in defaults_cache:
                        default_values[fname] = defaults_cache[fname]
                    else:
                        default.append(fname)
                if default:
                    defaults = cls.default_get(default, with_rec_name=False)
                    default_values.update(cls._clean_defaults(defaults))
                    defaults_cache.update(default_values)
            values.update(missing_defaults[values_schema])
            current_column_names = []
            current_columns = [table.create_uid, table.create_date]
            current_values = [transaction.user, CurrentTimestamp()]
            # Insert record
            for fname, value in sorted(values.items()):
                field = cls._fields[fname]
                # Fields with a setter are stored after the insert.
                if not hasattr(field, 'set'):
                    current_columns.append(Column(table, fname))
                    current_values.append(field.sql_format(value))
                    current_column_names.append(fname)
            # Flush the pending batch when the column schema changes.
            if current_column_names != previous_column_names:
                if to_insert:
                    new_ids.extend(db_insert(
                            previous_columns, to_insert,
                            previous_column_names))
                    to_insert.clear()
                previous_columns = current_columns
                previous_column_names = current_column_names
            to_insert.append(current_values)
        else:
            # Flush the last batch.
            if to_insert:
                new_ids.extend(db_insert(
                        current_columns, to_insert, current_column_names))
        transaction.create_records[cls.__name__].update(new_ids)
        # Update path before fields_to_set which could create children
        if cls._path_fields:
            field_names = list(sorted(cls._path_fields))
            cls._set_path(field_names, repeat(new_ids, len(field_names)))
        # Update mptt before fields_to_set which could create children
        if cls._mptt_fields:
            field_names = list(sorted(cls._mptt_fields))
            cls._update_mptt(field_names, repeat(new_ids, len(field_names)))
        translation_values = {}
        fields_to_set = {}
        for values, new_id in zip(vlist, new_ids):
            for fname, value in values.items():
                field = cls._fields[fname]
                if (getattr(field, 'translate', False)
                        and not hasattr(field, 'set')):
                    translation_values.setdefault(
                        '%s,%s' % (cls.__name__, fname), {})[new_id] = (
                        field.sql_format(value))
                if hasattr(field, 'set'):
                    # Group ids sharing the same value so the setter is
                    # called once per distinct value.
                    args = fields_to_set.setdefault(fname, [])
                    actions = iter(args)
                    for ids, val in zip(actions, actions):
                        if val == value:
                            ids.append(new_id)
                            break
                    else:
                        args.extend(([new_id], value))
        if translation_values:
            for name, translations in translation_values.items():
                Translation.set_ids(name, 'model', Transaction().language,
                    list(translations.keys()), list(translations.values()))
        for fname in sorted(fields_to_set, key=cls.index_set_field):
            fargs = fields_to_set[fname]
            field = cls._fields[fname]
            field.set(cls, fname, *fargs)
        cls._insert_history(new_ids)
        cls.__check_domain_rule(new_ids, 'create')
        records = cls.browse(new_ids)
        for sub_records in grouped_slice(
                records, record_cache_size(transaction)):
            cls._validate(sub_records)
        cls.trigger_create(records)
        return records
@classmethod
def read(cls, ids, fields_names):
pool = Pool()
Rule = pool.get('ir.rule')
Translation = pool.get('ir.translation')
super(ModelSQL, cls).read(ids, fields_names=fields_names)
transaction = Transaction()
cursor = Transaction().connection.cursor()
if not ids:
return []
# construct a clause for the rules :
domain = Rule.domain_get(cls.__name__, mode='read')
fields_related = defaultdict(set)
extra_fields = set()
if 'write_date' not in fields_names:
extra_fields.add('write_date')
for field_name in fields_names:
if field_name in {'_timestamp', '_write', '_delete'}:
continue
if '.' in field_name:
field_name, field_related = field_name.split('.', 1)
fields_related[field_name].add(field_related)
if field_name.endswith(':string'):
field_name = field_name[:-len(':string')]
fields_related[field_name]
field = cls._fields[field_name]
if hasattr(field, 'datetime_field') and field.datetime_field:
extra_fields.add(field.datetime_field)
if field.context:
extra_fields.update(fields.get_eval_fields(field.context))
extra_fields.discard('id')
all_fields = (
set(fields_names) | fields_related.keys() | extra_fields)
result = []
table = cls.__table__()
in_max = transaction.database.IN_MAX
history_order = None
history_clause = None
history_limit = None
if (cls._history
and transaction.context.get('_datetime')
and not callable(cls.table_query)):
in_max = 1
table = cls.__table_history__()
column = Coalesce(table.write_date, table.create_date)
history_clause = (column <= Transaction().context['_datetime'])
history_order = (column.desc, Column(table, '__id').desc)
history_limit = 1
columns = {}
for f in all_fields:
field = cls._fields.get(f)
if field and field.sql_type():
columns[f] = field.sql_column(table).as_(f)
if backend.name == 'sqlite':
columns[f].output_name += ' [%s]' % field.sql_type().base
elif f in {'_write', '_delete'}:
if not callable(cls.table_query):
rule_domain = Rule.domain_get(
cls.__name__, mode=f.lstrip('_'))
# No need to compute rule domain if it is the same as the
# read rule domain because it is already applied as where
# clause.
if rule_domain and rule_domain != domain:
rule_tables = {None: (table, None)}
rule_tables, rule_expression = cls.search_domain(
rule_domain, active_test=False, tables=rule_tables)
if len(rule_tables) > 1:
# The expression uses another table
rule_tables, rule_expression = cls.search_domain(
rule_domain, active_test=False)
rule_from = convert_from(None, rule_tables)
rule_table, _ = rule_tables[None]
rule_where = rule_table.id == table.id
rule_expression = rule_from.select(
rule_expression, where=rule_where)
columns[f] = rule_expression.as_(f)
else:
columns[f] = Literal(True).as_(f)
elif f == '_timestamp' and not callable(cls.table_query):
sql_type = fields.Char('timestamp').sql_type().base
columns[f] = Extract(
'EPOCH', Coalesce(table.write_date, table.create_date)
).cast(sql_type).as_('_timestamp')
if ('write_date' not in fields_names
and columns.keys() == {'write_date'}):
columns.pop('write_date')
extra_fields.discard('write_date')
if columns or domain:
if 'id' not in fields_names:
columns['id'] = table.id.as_('id')
tables = {None: (table, None)}
if domain:
tables, dom_exp = cls.search_domain(
domain, active_test=False, tables=tables)
from_ = convert_from(None, tables)
for sub_ids in grouped_slice(ids, in_max):
sub_ids = list(sub_ids)
red_sql = reduce_ids(table.id, sub_ids)
where = red_sql
if history_clause:
where &= history_clause
if domain:
where &= dom_exp
cursor.execute(*from_.select(*columns.values(), where=where,
order_by=history_order, limit=history_limit))
fetchall = list(cursor_dict(cursor))
if not len(fetchall) == len({}.fromkeys(sub_ids)):
cls.__check_domain_rule(ids, 'read')
raise RuntimeError("Undetected access error")
result.extend(fetchall)
else:
result = [{'id': x} for x in ids]
cachable_fields = []
max_write_date = max(
(r['write_date'] for r in result if r.get('write_date')),
default=None)
for fname, column in columns.items():
if fname.startswith('_'):
continue
field = cls._fields[fname]
if not hasattr(field, 'get'):
if getattr(field, 'translate', False):
translations = Translation.get_ids(
cls.__name__ + ',' + fname, 'model',
Transaction().language, ids,
cached_after=max_write_date)
for row in result:
row[fname] = translations.get(row['id']) or row[fname]
if fname != 'id':
cachable_fields.append(fname)
# all fields for which there is a get attribute
getter_fields = [f for f in all_fields
if f in cls._fields and hasattr(cls._fields[f], 'get')]
getter_fields = sorted(getter_fields, key=cls.index_get_field)
cache = transaction.get_cache()[cls.__name__]
if getter_fields and cachable_fields:
for row in result:
for fname in cachable_fields:
cache[row['id']][fname] = row[fname]
func_fields = {}
for fname in getter_fields:
field = cls._fields[fname]
if isinstance(field, fields.Function):
key = (
field.getter, field.getter_with_context,
getattr(field, 'datetime_field', None))
func_fields.setdefault(key, [])
func_fields[key].append(fname)
elif getattr(field, 'datetime_field', None):
for row in result:
with Transaction().set_context(
_datetime=row[field.datetime_field]):
date_result = field.get([row['id']], cls, fname,
values=[row])
row[fname] = date_result[row['id']]
else:
# get the value of that field for all records/ids
getter_result = field.get(ids, cls, fname, values=result)
for row in result:
row[fname] = getter_result[row['id']]
for key in func_fields:
field_list = func_fields[key]
fname = field_list[0]
field = cls._fields[fname]
_, getter_with_context, datetime_field = key
if datetime_field:
for row in result:
with Transaction().set_context(
_datetime=row[datetime_field]):
date_results = field.get([row['id']], cls, field_list,
values=[row])
for fname in field_list:
date_result = date_results[fname]
row[fname] = date_result[row['id']]
else:
for sub_results in grouped_slice(
result, record_cache_size(transaction)):
sub_results = list(sub_results)
sub_ids = []
sub_values = []
for row in sub_results:
if (row['id'] not in cache
or any(f not in cache[row['id']]
for f in field_list)):
sub_ids.append(row['id'])
sub_values.append(row)
else:
for fname in field_list:
row[fname] = cache[row['id']][fname]
getter_results = field.get(
sub_ids, cls, field_list, values=sub_values)
for fname in field_list:
getter_result = getter_results[fname]
for row in sub_values:
row[fname] = getter_result[row['id']]
if (transaction.readonly
and not getter_with_context):
cache[row['id']][fname] = row[fname]
        def read_related(field, Target, rows, fields):
            # Read the Target records referenced by column ``field`` in
            # rows, honoring the ``related_read_limit`` context key.
            name = field.name
            target_ids = []
            if field._type.endswith('2many'):
                # xxx2many columns hold lists of ids
                add = target_ids.extend
            elif field._type == 'reference':
                def add(value):
                    # reference values are "model,id" strings
                    try:
                        id_ = int(value.split(',', 1)[1])
                    except ValueError:
                        pass
                    else:
                        # negative ids denote unsaved records
                        if id_ >= 0:
                            target_ids.append(id_)
            else:
                add = target_ids.append
            for row in rows:
                value = row[name]
                if value is not None:
                    add(value)
            related_read_limit = transaction.context.get('related_read_limit')
            rows = Target.read(target_ids[:related_read_limit], fields)
            if related_read_limit is not None:
                # Over the limit only the id is returned
                rows += [{'id': i} for i in target_ids[related_read_limit:]]
            return rows
        def add_related(field, rows, targets):
            # Store the related values read for ``field`` into each row
            # under the "<name>." key (or "<name>:string" for selections).
            name = field.name
            key = name + '.'
            if field._type.endswith('2many'):
                for row in rows:
                    row[key] = values = list()
                    for target in row[name]:
                        if target is not None:
                            values.append(targets[target])
            elif field._type in {'selection', 'multiselection'}:
                # targets are the browsed records used to evaluate the
                # (possibly dynamic) selection
                key = name + ':string'
                for row, target in zip(rows, targets):
                    selection = field.get_selection(cls, name, target)
                    if field._type == 'selection':
                        row[key] = field.get_selection_string(
                            selection, row[name])
                    else:
                        row[key] = [
                            field.get_selection_string(selection, v)
                            for v in row[name]]
            else:
                for row in rows:
                    value = row[name]
                    if isinstance(value, str):
                        # reference values are "model,id" strings
                        try:
                            value = int(value.split(',', 1)[1])
                        except ValueError:
                            value = None
                    if value is not None and value >= 0:
                        row[key] = targets[value]
                    else:
                        row[key] = None
to_del = set()
for fname in fields_related.keys() | extra_fields:
if fname not in fields_names:
to_del.add(fname)
if fname not in cls._fields:
continue
if fname not in fields_related:
continue
field = cls._fields[fname]
datetime_field = getattr(field, 'datetime_field', None)
            def groupfunc(row):
                # Return the target model and the context with which the
                # related records of this row must be read.
                ctx = {}
                if field.context:
                    # Evaluate the field's PYSON context against the row
                    pyson_context = PYSONEncoder().encode(field.context)
                    ctx.update(PYSONDecoder(row).decode(pyson_context))
                if datetime_field:
                    ctx['_datetime'] = row.get(datetime_field)
                if field._type in {'selection', 'multiselection'}:
                    Target = None
                elif field._type == 'reference':
                    # The target model is encoded in the "model,id" value
                    value = row[fname]
                    if not value:
                        Target = None
                    else:
                        model, _ = value.split(',', 1)
                        Target = pool.get(model)
                else:
                    Target = field.get_target()
                return Target, ctx
            def orderfunc(row):
                # Sortable proxy of groupfunc's key, so rows can be
                # sorted before being grouped by groupfunc
                Target, ctx = groupfunc(row)
                return (Target.__name__ if Target else '', freeze(ctx))
for (Target, ctx), rows in groupby(
sorted(result, key=orderfunc), key=groupfunc):
rows = list(rows)
with Transaction().set_context(ctx):
if Target:
targets = read_related(
field, Target, rows, list(fields_related[fname]))
targets = {t['id']: t for t in targets}
else:
targets = cls.browse([r['id'] for r in rows])
add_related(field, rows, targets)
for row, field in product(result, to_del):
del row[field]
return result
    @classmethod
    @no_table_query
    def write(cls, records, values, *args):
        # Write each (records, values) pair to the database table, then
        # update translations, run field setters, maintain the path/MPTT
        # trees, record history and re-validate the modified records.
        transaction = Transaction()
        cursor = transaction.connection.cursor()
        pool = Pool()
        Translation = pool.get('ir.translation')
        Config = pool.get('ir.configuration')
        assert not len(args) % 2
        # Remove possible duplicates from all records
        all_records = list(OrderedDict.fromkeys(
                sum(((records, values) + args)[0:None:2], [])))
        all_ids = [r.id for r in all_records]
        all_field_names = set()
        # Call before cursor cache cleaning
        trigger_eligibles = cls.trigger_write_get_eligibles(all_records)
        super(ModelSQL, cls).write(records, values, *args)
        table = cls.__table__()
        cls.__check_timestamp(all_ids)
        cls.__check_domain_rule(all_ids, 'write')
        fields_to_set = {}
        # Iterate pairwise over (records, values, records, values, ...)
        actions = iter((records, values) + args)
        for records, values in zip(actions, actions):
            ids = [r.id for r in records]
            values = values.copy()
            # Clean values
            for key in ('create_uid', 'create_date',
                    'write_uid', 'write_date', 'id'):
                if key in values:
                    del values[key]
            columns = [table.write_uid, table.write_date]
            update_values = [transaction.user, CurrentTimestamp()]
            # Translatable columns are stored in the table only for the
            # main language
            store_translation = Transaction().language == Config.get_language()
            for fname, value in values.items():
                field = cls._fields[fname]
                if not hasattr(field, 'set'):
                    if (not getattr(field, 'translate', False)
                            or store_translation):
                        columns.append(Column(table, fname))
                        update_values.append(field.sql_format(value))
            for sub_ids in grouped_slice(ids):
                red_sql = reduce_ids(table.id, sub_ids)
                try:
                    cursor.execute(*table.update(columns, update_values,
                            where=red_sql))
                except (
                        backend.DatabaseIntegrityError,
                        backend.DatabaseDataError) as exception:
                    transaction = Transaction()
                    # Use a fresh transaction to read the metadata needed
                    # to build the user error, as the current one may be
                    # aborted by the database error
                    with Transaction().new_transaction():
                        if isinstance(
                                exception, backend.DatabaseIntegrityError):
                            cls.__raise_integrity_error(
                                exception, values, list(values.keys()),
                                transaction=transaction)
                        elif isinstance(exception, backend.DatabaseDataError):
                            cls.__raise_data_error(
                                exception, values, list(values.keys()),
                                transaction=transaction)
                    raise
            for fname, value in values.items():
                field = cls._fields[fname]
                if (getattr(field, 'translate', False)
                        and not hasattr(field, 'set')):
                    Translation.set_ids(
                        '%s,%s' % (cls.__name__, fname), 'model',
                        transaction.language, ids,
                        [field.sql_format(value)] * len(ids))
                if hasattr(field, 'set'):
                    # Defer setters until all raw columns are written
                    fields_to_set.setdefault(fname, []).extend((ids, value))
            path_fields = cls._path_fields & values.keys()
            if path_fields:
                cls._update_path(
                    list(sorted(path_fields)), repeat(ids, len(path_fields)))
            mptt_fields = cls._mptt_fields & values.keys()
            if mptt_fields:
                cls._update_mptt(
                    list(sorted(mptt_fields)), repeat(ids, len(mptt_fields)),
                    values)
            all_field_names |= values.keys()
        for fname in sorted(fields_to_set, key=cls.index_set_field):
            fargs = fields_to_set[fname]
            field = cls._fields[fname]
            field.set(cls, fname, *fargs)
        cls._insert_history(all_ids)
        # Re-check as the written values may move records out of the rules
        cls.__check_domain_rule(all_ids, 'write')
        for sub_records in grouped_slice(
                all_records, record_cache_size(transaction)):
            cls._validate(sub_records, field_names=all_field_names)
        cls.trigger_write(trigger_eligibles)
    @classmethod
    @no_table_query
    def delete(cls, records):
        # Delete records from the database after checking access and
        # timestamps, and handle every Many2One pointing at this model
        # according to its ondelete behavior (CASCADE / SET NULL / check).
        transaction = Transaction()
        in_max = transaction.database.IN_MAX
        cursor = transaction.connection.cursor()
        pool = Pool()
        Translation = pool.get('ir.translation')
        ids = list(map(int, records))
        if not ids:
            return
        table = cls.__table__()
        if cls.__name__ in transaction.delete_records:
            # Skip ids already deleted in this transaction
            ids = ids[:]
            for del_id in transaction.delete_records[cls.__name__]:
                while ids:
                    try:
                        ids.remove(del_id)
                    except ValueError:
                        break
        cls.__check_timestamp(ids)
        cls.__check_domain_rule(ids, 'delete')
        # Collect per tree field the ids whose MPTT values must be
        # recomputed after the deletion
        tree_ids = {}
        for fname in cls._mptt_fields:
            field = cls._fields[fname]
            tree_ids[fname] = []
            for sub_ids in grouped_slice(ids):
                where = reduce_ids(field.sql_column(table), sub_ids)
                cursor.execute(*table.select(table.id, where=where))
                tree_ids[fname] += [x[0] for x in cursor]
        has_translation = any(
            getattr(f, 'translate', False) and not hasattr(f, 'set')
            for f in cls._fields.values())
        # Classify the Many2One fields of all models targeting this model
        # by their ondelete behavior
        foreign_keys_tocheck = []
        foreign_keys_toupdate = []
        foreign_keys_todelete = []
        for _, model in pool.iterobject():
            if (not issubclass(model, ModelStorage)
                    or callable(getattr(model, 'table_query', None))):
                continue
            for field_name, field in model._fields.items():
                if (isinstance(field, fields.Many2One)
                        and field.model_name == cls.__name__):
                    if field.ondelete == 'CASCADE':
                        foreign_keys_todelete.append((model, field_name))
                    elif field.ondelete == 'SET NULL':
                        if field.required:
                            # A required reference can not be set to NULL
                            foreign_keys_tocheck.append((model, field_name))
                        else:
                            foreign_keys_toupdate.append((model, field_name))
                    else:
                        foreign_keys_tocheck.append((model, field_name))
        transaction.delete_records[cls.__name__].update(ids)
        cls.trigger_delete(records)
        if len(records) > in_max:
            # Clean self referencing foreign keys
            # before deleting them by small groups
            # Use the record id as value instead of NULL
            # in case the field is required
            foreign_fields_to_clean = [
                fn for m, fn in foreign_keys_tocheck if m == cls]
            if foreign_fields_to_clean:
                for sub_ids in grouped_slice(ids):
                    columns = [
                        Column(table, n) for n in foreign_fields_to_clean]
                    cursor.execute(*table.update(
                            columns, [table.id] * len(foreign_fields_to_clean),
                            where=reduce_ids(table.id, sub_ids)))

        def get_related_records(Model, field_name, sub_ids):
            # Return the Model records whose field_name references one of
            # sub_ids, excluding the records currently being deleted
            if issubclass(Model, ModelSQL):
                foreign_table = Model.__table__()
                foreign_red_sql = reduce_ids(
                    Column(foreign_table, field_name), sub_ids)
                cursor.execute(*foreign_table.select(foreign_table.id,
                        where=foreign_red_sql))
                related_records = Model.browse([x[0] for x in cursor])
            else:
                with without_check_access(), inactive_records():
                    related_records = Model.search(
                        [(field_name, 'in', sub_ids)],
                        order=[])
            if Model == cls:
                related_records = list(set(related_records) - set(records))
            return related_records

        for sub_ids, sub_records in zip(
                grouped_slice(ids), grouped_slice(records)):
            sub_ids = list(sub_ids)
            red_sql = reduce_ids(table.id, sub_ids)
            # SET NULL on non-required back references
            for Model, field_name in foreign_keys_toupdate:
                related_records = get_related_records(
                    Model, field_name, sub_ids)
                if related_records:
                    Model.write(related_records, {
                            field_name: None,
                            })
            # Cascade the deletion
            for Model, field_name in foreign_keys_todelete:
                related_records = get_related_records(
                    Model, field_name, sub_ids)
                if related_records:
                    Model.delete(related_records)
            # Refuse the deletion when restricted references remain
            for Model, field_name in foreign_keys_tocheck:
                if get_related_records(Model, field_name, sub_ids):
                    error_args = Model.__names__(field_name)
                    raise ForeignKeyError(
                        gettext('ir.msg_foreign_model_exist',
                            **error_args))
            super(ModelSQL, cls).delete(list(sub_records))
            try:
                cursor.execute(*table.delete(where=red_sql))
            except backend.DatabaseIntegrityError as exception:
                transaction = Transaction()
                # Use a fresh transaction to build the user error message
                with Transaction().new_transaction():
                    cls.__raise_integrity_error(
                        exception, {}, transaction=transaction)
                raise
        if has_translation:
            Translation.delete_ids(cls.__name__, 'model', ids)
        cls._insert_history(ids, deleted=True)
        cls._update_mptt(list(tree_ids.keys()), list(tree_ids.values()))
    @classmethod
    def __check_domain_rule(cls, ids, mode):
        # Raise AccessError if any of ids does not exist or does not
        # match the access rules defined for mode.
        pool = Pool()
        Lang = pool.get('ir.lang')
        Rule = pool.get('ir.rule')
        Model = pool.get('ir.model')
        try:
            User = pool.get('res.user')
            Group = pool.get('res.group')
        except KeyError:
            # The res module may not be activated
            User = Group = None
        table = cls.__table__()
        transaction = Transaction()
        in_max = transaction.database.IN_MAX
        history_clause = None
        limit = None
        if (mode == 'read'
                and cls._history
                and transaction.context.get('_datetime')
                and not callable(cls.table_query)):
            # Reading in the past is checked against the history table,
            # one id at a time as multiple rows may exist per id
            in_max = 1
            table = cls.__table_history__()
            column = Coalesce(table.write_date, table.create_date)
            history_clause = (column <= Transaction().context['_datetime'])
            limit = 1
        cursor = transaction.connection.cursor()
        assert mode in Rule.modes

        def test_domain(ids, domain):
            # Return the ids that do not match the domain
            result = []
            tables = {None: (table, None)}
            if domain:
                tables, dom_exp = cls.search_domain(
                    domain, active_test=False, tables=tables)
            from_ = convert_from(None, tables)
            for sub_ids in grouped_slice(ids, in_max):
                sub_ids = set(sub_ids)
                where = reduce_ids(table.id, sub_ids)
                if history_clause:
                    where &= history_clause
                if domain:
                    where &= dom_exp
                cursor.execute(
                    *from_.select(table.id, where=where, limit=limit))
                rowcount = cursor.rowcount
                if rowcount == -1 or rowcount is None:
                    # The backend does not report rowcount
                    rowcount = len(cursor.fetchall())
                if rowcount != len(sub_ids):
                    # Re-execute to find which ids are missing
                    cursor.execute(
                        *from_.select(table.id, where=where, limit=limit))
                    result.extend(
                        sub_ids.difference([x for x, in cursor]))
            return result

        domains = []
        if mode in {'read', 'write'}:
            # The empty domain checks the record existence
            domains.append([])
        domain = Rule.domain_get(cls.__name__, mode=mode)
        if domain:
            domains.append(domain)
        for domain in domains:
            wrong_ids = test_domain(ids, domain)
            if wrong_ids:
                model = cls.__name__
                if Model:
                    model = Model.get_name(cls.__name__)
                # Only the first 5 offending ids are displayed
                ids = ', '.join(map(str, wrong_ids[:5]))
                if len(wrong_ids) > 5:
                    ids += '...'
                if domain:
                    # Find which rules rejected the records to name them
                    # in the error message
                    rules = []
                    clause, clause_global = Rule.get(cls.__name__, mode=mode)
                    if clause:
                        dom = list(clause.values())
                        dom.insert(0, 'OR')
                        if test_domain(wrong_ids, dom):
                            rules.extend(clause.keys())
                    for rule, dom in clause_global.items():
                        if test_domain(wrong_ids, dom):
                            rules.append(rule)
                    msg = gettext(
                        f'ir.msg_{mode}_rule_error', ids=ids, model=model,
                        rules='\n'.join(r.name for r in rules))
                else:
                    msg = gettext(
                        f'ir.msg_{mode}_error', ids=ids, model=model)
                # Add contextual hints to help diagnose the failure
                ctx_msg = []
                lang = Lang.get()
                if cls._history and transaction.context.get('_datetime'):
                    ctx_msg.append(gettext('ir.msg_context_datetime',
                            datetime=lang.strftime(
                                transaction.context['_datetime'])))
                if domain and User and Group:
                    groups = Group.browse(User.get_groups())
                    ctx_msg.append(gettext('ir.msg_context_groups',
                            groups=', '.join(g.rec_name for g in groups)))
                raise AccessError(msg, '\n'.join(ctx_msg))
    @classmethod
    def __search_query(cls, domain, count, query, order):
        # Return the tables and where expression for searching domain
        # with the read access rules applied.
        pool = Pool()
        Rule = pool.get('ir.rule')
        rule_domain = Rule.domain_get(cls.__name__, mode='read')
        joined_domains = None
        if domain and domain[0] == 'OR':
            local_domains, subquery_domains = split_subquery_domain(domain)
            if subquery_domains:
                # Each branch will become one member of a UNION
                joined_domains = subquery_domains
                if local_domains:
                    local_domains.insert(0, 'OR')
                    joined_domains.append(local_domains)

        def get_local_columns(order_exprs):
            # Return the names of the order expressions that are plain
            # columns of the model table, else raise NotImplementedError
            local_columns = []
            for order_expr in order_exprs:
                if (isinstance(order_expr, Column)
                        and isinstance(order_expr._from, Table)
                        and order_expr._from._name == cls._table):
                    local_columns.append(order_expr._name)
                else:
                    raise NotImplementedError
            return local_columns

        # The UNION optimization needs the columns used to order the query
        extra_columns = set()
        if order and joined_domains:
            tables = {
                None: (cls.__table__(), None),
                }
            for oexpr, otype in order:
                fname = oexpr.partition('.')[0]
                field = cls._fields[fname]
                field_orders = field.convert_order(oexpr, tables, cls)
                try:
                    order_columns = get_local_columns(field_orders)
                    extra_columns.update(order_columns)
                except NotImplementedError:
                    # Ordering on a non-local column prevents the UNION
                    # optimization
                    joined_domains = None
                    break
        # In case the search uses subqueries it's more efficient to use a UNION
        # of queries than using clauses with some JOIN because databases can
        # used indexes
        if joined_domains is not None:
            union_tables = []
            for sub_domain in joined_domains:
                sub_domain = [sub_domain]  # it may be a clause
                tables, expression = cls.search_domain(sub_domain)
                if rule_domain:
                    tables, domain_exp = cls.search_domain(
                        rule_domain, active_test=False, tables=tables)
                    expression &= domain_exp
                main_table, _ = tables[None]
                table = convert_from(None, tables)
                columns = cls.__searched_columns(main_table,
                    eager=not count and not query,
                    extra_columns=extra_columns)
                union_tables.append(table.select(
                        *columns, where=expression))
            expression = None
            tables = {
                None: (Union(*union_tables, all_=False), None),
                }
        else:
            tables, expression = cls.search_domain(domain)
            if rule_domain:
                tables, domain_exp = cls.search_domain(
                    rule_domain, active_test=False, tables=tables)
                expression &= domain_exp
        return tables, expression
    @classmethod
    def __searched_columns(
            cls, table, *, eager=False, history=False, extra_columns=None):
        # Return the columns to select when searching on table: the id,
        # history metadata when needed, extra_columns and, when eager,
        # all eager stored fields plus the concurrency timestamp.
        if extra_columns is None:
            extra_columns = []
        else:
            # id, __id and _datetime are handled explicitly below
            extra_columns = sorted(extra_columns - {'id', '__id', '_datetime'})
        columns = [table.id.as_('id')]
        if (cls._history and Transaction().context.get('_datetime')
                and (eager or history)):
            columns.append(
                Coalesce(table.write_date, table.create_date).as_('_datetime'))
            columns.append(Column(table, '__id').as_('__id'))
        for column_name in extra_columns:
            field = cls._fields[column_name]
            sql_column = field.sql_column(table).as_(column_name)
            columns.append(sql_column)
        if eager:
            # Fetch all eager, stored, non-translated fields at once
            columns += [f.sql_column(table).as_(n)
                for n, f in sorted(cls._fields.items())
                if not hasattr(f, 'get')
                and n not in extra_columns
                and n != 'id'
                and not getattr(f, 'translate', False)
                and f.loading == 'eager']
            if not callable(cls.table_query):
                # _timestamp feeds the optimistic concurrency check
                sql_type = fields.Char('timestamp').sql_type().base
                columns += [Extract('EPOCH',
                        Coalesce(table.write_date, table.create_date)
                        ).cast(sql_type).as_('_timestamp')]
        return columns
@classmethod
def __search_order(cls, order, tables):
order_by = []
order_types = {
'DESC': Desc,
'ASC': Asc,
}
null_ordering_types = {
'NULLS FIRST': NullsFirst,
'NULLS LAST': NullsLast,
None: lambda _: _
}
for oexpr, otype in order:
fname, _, extra_expr = oexpr.partition('.')
field = cls._fields[fname]
if not otype:
otype, null_ordering = 'ASC', None
else:
otype = otype.upper()
try:
otype, null_ordering = otype.split(' ', 1)
except ValueError:
null_ordering = None
Order = order_types[otype]
NullOrdering = null_ordering_types[null_ordering]
forder = field.convert_order(oexpr, tables, cls)
order_by.extend((NullOrdering(Order(o)) for o in forder))
return order_by
    @classmethod
    def search(cls, domain, offset=0, limit=None, order=None, count=False,
            query=False):
        # Return the records matching domain, or the number of matches
        # when count is true, or the SQL query itself when query is true.
        transaction = Transaction()
        cursor = transaction.connection.cursor()
        super(ModelSQL, cls).search(
            domain, offset=offset, limit=limit, order=order, count=count)
        if order is None or order is False:
            order = cls._order
        tables, expression = cls.__search_query(domain, count, query, order)
        main_table, _ = tables[None]
        if count:
            table = convert_from(None, tables)
            if (limit is not None and limit < cls.estimated_count()) or offset:
                # Counting a limited subquery avoids scanning every
                # matching row
                select = table.select(
                    Literal(1), where=expression, limit=limit, offset=offset
                    ).select(Count(Literal('*')))
            else:
                select = table.select(Count(Literal('*')), where=expression)
            if query:
                return select
            else:
                cursor.execute(*select)
                return cursor.fetchone()[0]
        order_by = cls.__search_order(order, tables)
        # compute it here because __search_order might modify tables
        table = convert_from(None, tables)
        if query:
            columns = [main_table.id.as_('id')]
        else:
            columns = cls.__searched_columns(main_table, eager=True)
            if backend.name == 'sqlite':
                # sqlite relies on column name type hints to cast values
                for column in columns:
                    field = cls._fields.get(column.output_name)
                    if field:
                        column.output_name += ' [%s]' % field.sql_type().base
        select = table.select(
            *columns, where=expression, limit=limit, offset=offset,
            order_by=order_by)
        if query:
            return select
        cursor.execute(*select)
        rows = list(cursor_dict(cursor, transaction.database.IN_MAX))
        cache = transaction.get_cache()
        delete_records = transaction.delete_records[cls.__name__]
        # Can not cache the history value if we are not sure to have fetch all
        # the rows for each records
        if (not (cls._history and transaction.context.get('_datetime'))
                or len(rows) < transaction.database.IN_MAX):
            keys = None
            for data in islice(rows, 0, cache.size_limit):
                if data['id'] in delete_records:
                    continue
                if keys is None:
                    # Compute once the keys that must be stripped from
                    # data before caching
                    keys = list(data.keys())
                    for k in keys[:]:
                        if k in ('_timestamp', '_datetime', '__id'):
                            continue
                        field = cls._fields[k]
                        if not getattr(field, 'datetime_field', None):
                            keys.remove(k)
                            continue
                for k in keys:
                    del data[k]
                cache[cls.__name__][data['id']]._update(data)
        return cls.browse([x['id'] for x in rows])
    @classmethod
    def search_domain(cls, domain, active_test=None, tables=None):
        '''
        Return SQL tables and expression for the domain.
        Set active_test to add it.
        '''
        transaction = Transaction()
        if active_test is None:
            active_test = transaction.active_records
        domain = cls._search_domain_active(domain, active_test=active_test)
        if tables is None:
            tables = {}
        if None not in tables:
            # Searching in the past uses the history table
            if cls._history and transaction.context.get('_datetime'):
                tables[None] = (cls.__table_history__(), None)
            else:
                tables[None] = (cls.__table__(), None)

        def convert(domain):
            # Recursively convert the domain into a SQL expression
            if is_leaf(domain):
                fname = domain[0].split('.', 1)[0]
                field = cls._fields[fname]
                expression = field.convert_domain(domain, tables, cls)
                if not isinstance(expression, (Operator, Expression)):
                    # The field returned a new domain to convert
                    return convert(expression)
                return expression
            elif not domain or list(domain) in (['OR'], ['AND']):
                return Literal(True)
            elif domain[0] == 'OR':
                return Or((convert(d) for d in domain[1:]))
            else:
                return And((convert(d) for d in (
                        domain[1:] if domain[0] == 'AND' else domain)))

        with without_check_access():
            expression = convert(domain)
        if cls._history and transaction.context.get('_datetime'):
            # Keep only the latest history row per id at _datetime that
            # is not a deletion
            database = Transaction().database
            if database.has_window_functions():
                table, _ = tables[None]
                history = cls.__table_history__()
                last_change = Coalesce(history.write_date, history.create_date)
                # prefilter the history records for a bit of a speedup
                selected_h_ids = convert_from(None, tables).select(
                    table.id, where=expression)
                # Rank the history rows of each id from newest to oldest
                most_recent = history.select(
                    history.create_date, Column(history, '__id'),
                    RowNumber(
                        window=Window([history.id],
                            order_by=[
                                last_change.desc,
                                Column(history, '__id').desc])).as_('rank'),
                    where=((last_change <= transaction.context['_datetime'])
                        & history.id.in_(selected_h_ids)))
                # Filter again as the latest records from most_recent might not
                # match the expression
                expression &= Exists(most_recent.select(
                        Literal(1),
                        where=(
                            (Column(table, '__id')
                                == Column(most_recent, '__id'))
                            & (most_recent.create_date != Null)
                            & (most_recent.rank == 1))))
            else:
                # Fallback without window functions: compute the latest
                # modification date per id, then the matching history row
                table, _ = tables[None]
                history_1 = cls.__table_history__()
                history_2 = cls.__table_history__()
                last_change = Coalesce(
                    history_1.write_date, history_1.create_date)
                latest_change = history_1.select(
                    history_1.id, Max(last_change).as_('date'),
                    where=(last_change <= transaction.context['_datetime']),
                    group_by=[history_1.id])
                most_recent = history_2.join(
                    latest_change,
                    condition=(
                        (history_2.id == latest_change.id)
                        & (Coalesce(
                                history_2.write_date, history_2.create_date)
                            == latest_change.date))
                    ).select(
                        Max(Column(history_2, '__id')).as_('h_id'),
                        where=(history_2.create_date != Null),
                        group_by=[history_2.id])
                expression &= Exists(most_recent.select(
                        Literal(1),
                        where=(Column(table, '__id') == most_recent.h_id)))
        return tables, expression
    @classmethod
    def _rebuild_path(cls, field_name):
        "Rebuild path for the tree."
        cursor = Transaction().connection.cursor()
        field = cls._fields[field_name]
        table = cls.__table__()
        # Recursive CTE walking the tree from the roots, concatenating
        # the ids separated by '/' into the path
        tree = With('id', 'path', recursive=True)
        tree.query = table.select(
            table.id, Concat(table.id, '/'),
            where=Column(table, field_name) == Null)
        tree.query |= (table
            .join(tree,
                condition=Column(table, field_name) == tree.id)
            .select(table.id, Concat(Concat(tree.path, table.id), '/')))
        # Copy the computed paths back into the table
        query = table.update(
            [Column(table, field.path)],
            [tree.path],
            from_=[tree], where=table.id == tree.id,
            with_=[tree])
        cursor.execute(*query)
@classmethod
def _set_path(cls, field_names, list_ids):
cursor = Transaction().connection.cursor()
table = cls.__table__()
parent = cls.__table__()
for field_name, ids in zip(field_names, list_ids):
field = cls._fields[field_name]
parent_column = Column(table, field_name)
path_column = Column(table, field.path)
query = table.update(
[path_column],
[Concat(Concat(Coalesce(
parent.select(parent.path,
where=parent.id == parent_column),
''), table.id), '/')])
for sub_ids in grouped_slice(ids):
query.where = reduce_ids(table.id, sub_ids)
cursor.execute(*query)
    @classmethod
    def _update_path(cls, field_names, list_ids):
        # Recompute the path of the given records and propagate the
        # change to all their descendants.
        transaction = Transaction()
        cursor = transaction.connection.cursor()
        update = transaction.connection.cursor()
        table = cls.__table__()
        parent = cls.__table__()

        def update_path(query, column, sub_ids):
            # Replace the old path prefix by the new one on each record
            # and all its descendants.
            # Return False when a path already rewritten in this pass is
            # a prefix of the current one: the fetched paths are then
            # stale and the whole batch must be retried.
            updated = set()
            query.where = reduce_ids(table.id, sub_ids)
            cursor.execute(*query)
            for old_path, new_path in cursor:
                if old_path == new_path:
                    continue
                if any(old_path.startswith(p) for p in updated):
                    return False
                update.execute(*table.update(
                        [column],
                        [Concat(new_path,
                                Substring(table.path, len(old_path) + 1))],
                        where=table.path.like(old_path + '%')))
                updated.add(old_path)
            return True

        for field_name, ids in zip(field_names, list_ids):
            field = cls._fields[field_name]
            parent_column = Column(table, field_name)
            parent_path_column = Column(parent, field.path)
            path_column = Column(table, field.path)
            # Select the current path and the path recomputed from the
            # parent's path
            query = (table
                .join(parent, 'LEFT',
                    condition=parent_column == parent.id)
                .select(path_column,
                    Concat(Concat(
                            Coalesce(parent_path_column, ''), table.id), '/')))
            for sub_ids in grouped_slice(ids):
                sub_ids = list(sub_ids)
                # Retry until a pass completes without stale paths
                while not update_path(query, path_column, sub_ids):
                    pass
@classmethod
def _update_mptt(cls, field_names, list_ids, values=None):
for field_name, ids in zip(field_names, list_ids):
field = cls._fields[field_name]
if (values is not None
and (field.left in values or field.right in values)):
raise Exception('ValidateError',
'You can not update fields: "%s", "%s"' %
(field.left, field.right))
if len(ids) < max(cls.estimated_count() / 4, 4):
for id_ in ids:
cls._update_tree(id_, field_name,
field.left, field.right)
else:
cls._rebuild_tree(field_name, None, 0)
@classmethod
def _rebuild_tree(cls, parent, parent_id, left):
'''
Rebuild left, right value for the tree.
'''
cursor = Transaction().connection.cursor()
table = cls.__table__()
right = left + 1
cursor.execute(*table.select(table.id,
where=Column(table, parent) == parent_id))
for child_id, in cursor:
right = cls._rebuild_tree(parent, child_id, right)
field = cls._fields[parent]
if parent_id:
cursor.execute(*table.update(
[Column(table, field.left), Column(table, field.right)],
[left, right],
where=table.id == parent_id))
return right + 1
    @classmethod
    def _update_tree(cls, record_id, field_name, left, right):
        '''
        Update left, right values for the tree.
        Remarks:
            - the value (right - left - 1) / 2 will not give
                the number of children node
        '''
        cursor = Transaction().connection.cursor()
        table = cls.__table__()
        left = Column(table, left)
        right = Column(table, right)
        field = Column(table, field_name)
        cursor.execute(*table.select(left, right, field,
                where=table.id == record_id))
        fetchone = cursor.fetchone()
        if not fetchone:
            return
        old_left, old_right, parent_id = fetchone
        if old_left == old_right == 0:
            # New record: give it an interval after the last root
            cursor.execute(*table.select(Max(right),
                    where=field == Null))
            old_left, = cursor.fetchone()
            old_left += 1
            old_right = old_left + 1
            cursor.execute(*table.update([left, right],
                    [old_left, old_right],
                    where=table.id == record_id))
        size = old_right - old_left + 1
        # Find the insertion point: just before the parent's right value,
        # or after the last root when there is no parent
        parent_right = 1
        if parent_id:
            cursor.execute(*table.select(right, where=table.id == parent_id))
            parent_right = cursor.fetchone()[0]
        else:
            cursor.execute(*table.select(Max(right), where=field == Null))
            fetchone = cursor.fetchone()
            if fetchone:
                parent_right = fetchone[0] + 1
        # Open a hole of the subtree size at the insertion point
        cursor.execute(*table.update([left], [left + size],
                where=left >= parent_right))
        cursor.execute(*table.update([right], [right + size],
                where=right >= parent_right))
        # Shift the subtree into the hole; if the subtree was located
        # after the insertion point its interval was moved by the hole
        if old_left < parent_right:
            left_delta = parent_right - old_left
            right_delta = parent_right - old_left
            left_cond = old_left
            right_cond = old_right
        else:
            left_delta = parent_right - old_left - size
            right_delta = parent_right - old_left - size
            left_cond = old_left + size
            right_cond = old_right + size
        cursor.execute(*table.update([left, right],
                [left + left_delta, right + right_delta],
                where=(left >= left_cond) & (right <= right_cond)))
    @classmethod
    def validate(cls, records):
        # Enforce in Python the SQL constraints that the database does
        # not support natively.
        super(ModelSQL, cls).validate(records)
        transaction = Transaction()
        database = transaction.database
        connection = transaction.connection
        has_constraint = database.has_constraint
        lock = database.lock
        cursor = transaction.connection.cursor()
        # Works only for a single transaction
        ids = list(map(int, records))
        for _, sql, error in cls._sql_constraints:
            if has_constraint(sql):
                # The database enforces it natively
                continue
            table = sql.table
            if isinstance(sql, (Unique, Exclude)):
                # Lock the table to prevent concurrent violations
                lock(connection, cls._table)
                columns = list(sql.columns)
                columns.insert(0, table.id)
                in_max = transaction.database.IN_MAX // (len(columns) + 1)
                for sub_ids in grouped_slice(ids, in_max):
                    where = reduce_ids(table.id, sub_ids)
                    if isinstance(sql, Exclude) and sql.where:
                        where &= sql.where
                    cursor.execute(*table.select(*columns, where=where))
                    # Search for another record with the same values
                    where = Literal(False)
                    for row in cursor:
                        clause = table.id != row[0]
                        for column, operator, value in zip(
                                sql.columns, sql.operators, row[1:]):
                            if value is None:
                                # NULL is always unique
                                clause &= Literal(False)
                            clause &= operator(column, value)
                        where |= clause
                    if isinstance(sql, Exclude) and sql.where:
                        where &= sql.where
                    cursor.execute(
                        *table.select(table.id, where=where, limit=1))
                    if cursor.fetchone():
                        raise SQLConstraintError(gettext(error))
            elif isinstance(sql, Check):
                # Look for any record violating the check expression
                for sub_ids in grouped_slice(ids):
                    red_sql = reduce_ids(table.id, sub_ids)
                    cursor.execute(*table.select(table.id,
                            where=~sql.expression & red_sql,
                            limit=1))
                    if cursor.fetchone():
                        raise SQLConstraintError(gettext(error))
@dualmethod
def lock(cls, records=None):
transaction = Transaction()
database = transaction.database
connection = transaction.connection
table = cls.__table__()
if records is not None and database.has_select_for():
for sub_records in grouped_slice(records):
where = reduce_ids(table.id, sub_records)
query = table.select(
Literal(1), where=where, for_=For('UPDATE', nowait=True))
with connection.cursor() as cursor:
cursor.execute(*query)
else:
transaction.lock_table(cls._table)
def convert_from(table, tables, type_='LEFT'):
    "Chain all the tables of the nested mapping into a single join."
    # Joins are chained instead of nested because SQLite does not
    # support nested joins.
    main, condition = tables[None]
    if table:
        joined = table.join(main, type_, condition)
    else:
        joined = main
    for name, sub_tables in tables.items():
        if name is None:
            continue
        joined = convert_from(joined, sub_tables, type_=type_)
    return joined
def split_subquery_domain(domain):
    """
    Split a domain in two parts:
        - the first one contains all the sub-domains with only local fields
        - the second one contains all the sub-domains using a related field
    The main operator of the domain will be stripped from the results.
    """
    local_domains = []
    subquery_domains = []
    for sub_domain in domain:
        if is_leaf(sub_domain):
            # A dotted field name targets a related record
            if '.' in sub_domain[0]:
                subquery_domains.append(sub_domain)
            else:
                local_domains.append(sub_domain)
            continue
        if (not sub_domain or list(sub_domain) in [['OR'], ['AND']]
                or sub_domain in ['OR', 'AND']):
            # Skip empty domains and bare operators
            continue
        # A nested domain is a subquery domain as soon as any of its
        # own sub-domains uses a related field
        _, nested_subqueries = split_subquery_domain(sub_domain)
        if nested_subqueries:
            subquery_domains.append(sub_domain)
        else:
            local_domains.append(sub_domain)
    return local_domains, subquery_domains