# This file is part of Tryton. The COPYRIGHT file at the top level of # this repository contains the full copyright notices and license terms. from sql import As, Column, Expression, Literal, Query, With from sql.aggregate import Max from sql.conditionals import Coalesce from sql.operators import Or from trytond.config import config from trytond.pool import Pool from trytond.pyson import PYSONEncoder from trytond.tools import cached_property, reduce_ids from trytond.transaction import Transaction, inactive_records from .field import ( Field, context_validate, domain_method, instantiate_context, order_method, search_order_validate) _subquery_threshold = config.getint('database', 'subquery_threshold') class Many2One(Field): ''' Define many2one field (``int``). ''' _type = 'many2one' _sql_type = 'INTEGER' def __init__(self, model_name, string='', left=None, right=None, path=None, ondelete='SET NULL', datetime_field=None, search_order=None, search_context=None, help='', required=False, readonly=False, domain=None, states=None, on_change=None, on_change_with=None, depends=None, context=None, loading='eager'): ''' :param model_name: The name of the target model. :param left: The name of the field to store the left value for Modified Preorder Tree Traversal. See http://en.wikipedia.org/wiki/Tree_traversal :param right: The name of the field to store the right value. See left :param path: The name of the field used to store the path. :param ondelete: Define the behavior of the record when the target record is deleted. (``CASCADE``, ``RESTRICT``, ``SET NULL``) ``SET NULL`` will be changed into ``RESTRICT`` if required is set. :param datetime_field: The name of the field that contains the datetime value to read the target record. :param target_search: The kind of target search 'subquery' or 'join' :param search_order: The order to use when searching for a record :param search_context: The context to use when searching for a record ''' self.__required = required if ondelete not in ('CASCADE', 'RESTRICT', 'SET NULL'): raise Exception('Bad arguments') self.ondelete = ondelete super(Many2One, self).__init__(string=string, help=help, required=required, readonly=readonly, domain=domain, states=states, on_change=on_change, on_change_with=on_change_with, depends=depends, context=context, loading=loading) self.model_name = model_name self.left = left self.right = right self.path = path self.datetime_field = datetime_field self.__search_order = None self.search_order = search_order self.__search_context = None self.search_context = search_context or {} __init__.__doc__ += Field.__init__.__doc__ def __get_required(self): return self.__required def __set_required(self, value): self.__required = value if value and self.ondelete == 'SET NULL': self.ondelete = 'RESTRICT' required = property(__get_required, __set_required) @property def search_order(self): return self.__search_order @search_order.setter def search_order(self, value): search_order_validate(value) self.__search_order = value @property def search_context(self): return self.__search_context @search_context.setter def search_context(self, value): context_validate(value) self.__search_context = value @cached_property def display_depends(self): depends = super().display_depends if self.datetime_field: depends.add(self.datetime_field) return depends def get_target(self): 'Return the target Model' return Pool().get(self.model_name) def __set__(self, inst, value): Target = self.get_target() if isinstance(value, (dict, int)): ctx = instantiate_context(self, inst) with Transaction().set_context(ctx): if isinstance(value, dict): value = Target(**value) elif isinstance(value, int): value = Target(value) assert isinstance(value, (Target, type(None))) super(Many2One, self).__set__(inst, value) def sql_format(self, value): from ..model import Model assert value is not False assert ( not isinstance(value, Model) or value.__name__ == self.model_name) if value and not isinstance(value, (Query, Expression)): value = int(value) return super().sql_format(value) def convert_domain_path(self, domain, tables): cursor = Transaction().connection.cursor() table, _ = tables[None] name, operator, ids = domain red_sql = reduce_ids(table.id, (i for i in ids if i is not None)) Target = self.get_target() path_column = getattr(Target, self.path).sql_column(table) path_column = Coalesce(path_column, '') cursor.execute(*table.select(path_column, where=red_sql)) if operator.endswith('child_of'): where = Or() for path, in cursor: where.append(path_column.like(path + '%')) else: ids = [int(x) for path, in cursor for x in path.split('/')[:-1]] where = reduce_ids(table.id, ids) if not where: where = Literal(False) if operator.startswith('not'): return ~where return where def convert_domain_mptt(self, domain, tables): cursor = Transaction().connection.cursor() table, _ = tables[None] name, operator, ids = domain red_sql = reduce_ids(table.id, (i for i in ids if i is not None)) Target = self.get_target() left = getattr(Target, self.left).sql_column(table) right = getattr(Target, self.right).sql_column(table) cursor.execute(*table.select(left, right, where=red_sql)) where = Or() for l, r in cursor: if operator.endswith('child_of'): where.append((left >= l) & (right <= r)) else: where.append((left <= l) & (right >= r)) if not where: where = Literal(False) if operator.startswith('not'): return ~where return where def convert_domain_tree(self, domain, tables): Target = self.get_target() target = Target.__table__() table, _ = tables[None] name, operator, ids = domain red_sql = reduce_ids(target.id, (i for i in ids if i is not None)) if operator.endswith('child_of'): tree = With('id', recursive=True) tree.query = target.select(target.id, where=red_sql) tree.query |= (target .join(tree, condition=Column(target, name) == tree.id) .select(target.id)) else: tree = With('id', name, recursive=True) tree.query = target.select( target.id, Column(target, name), where=red_sql) tree.query |= (target .join(tree, condition=target.id == Column(tree, name)) .select(target.id, Column(target, name))) expression = table.id.in_(tree.select(tree.id, with_=[tree])) if operator.startswith('not'): return ~expression return expression @inactive_records @domain_method def convert_domain(self, domain, tables, Model): pool = Pool() Rule = pool.get('ir.rule') Target = self.get_target() table, _ = tables[None] name, operator, value = domain[:3] column = self.sql_column(table) if '.' not in name: if operator.endswith('child_of') or operator.endswith('parent_of'): if Target != Model: if operator.endswith('child_of'): target_operator = 'child_of' else: target_operator = 'parent_of' query = Target.search([ (domain[3], target_operator, value), ], order=[], query=True) expression = column.in_(query) if operator.startswith('not'): return ~expression return expression if isinstance(value, str): targets = Target.search([('rec_name', 'ilike', value)], order=[]) ids = [t.id for t in targets] elif not isinstance(value, (list, tuple)): ids = [value] else: ids = value if not ids: expression = Literal(False) if operator.startswith('not'): return ~expression return expression elif self.left and self.right: return self.convert_domain_mptt( (name, operator, ids), tables) elif self.path: return self.convert_domain_path( (name, operator, ids), tables) else: return self.convert_domain_tree( (name, operator, ids), tables) # Used for Many2Many where clause if operator.endswith('where'): query = Target.search(value, order=[], query=True) target_id, = query.columns if isinstance(target_id, As): target_id = target_id.expression query.where &= target_id == column expression = column.in_(query) if operator.startswith('not'): return ~expression return expression if not isinstance(value, str): return super(Many2One, self).convert_domain(domain, tables, Model) else: target_name = 'rec_name' else: _, target_name = name.split('.', 1) target_domain = [(target_name,) + tuple(domain[1:])] rule_domain = Rule.domain_get(Target.__name__, mode='read') if not rule_domain and target_name == 'id': # No need to join with the target table return super().convert_domain( (self.name, operator, value), tables, Model) elif Target.estimated_count() < _subquery_threshold: query = Target.search(target_domain, order=[], query=True) return column.in_(query) else: target_domain = [target_domain, rule_domain] target_tables = self._get_target_tables(tables) target_table, _ = target_tables[None] _, expression = Target.search_domain( target_domain, tables=target_tables) return expression @order_method def convert_order(self, name, tables, Model): fname, _, oexpr = name.partition('.') Target = self.get_target() if oexpr: oname, _, _ = oexpr.partition('.') else: oname = 'id' if (Target._rec_name in Target._fields and Target._fields[Target._rec_name].sortable(Target)): oname = Target._rec_name if (Target._order_name in Target._fields and Target._fields[Target._order_name].sortable(Target)): oname = Target._order_name oexpr = oname table, _ = tables[None] if oname == 'id': return [self.sql_column(table)] ofield = Target._fields[oname] target_tables = self._get_target_tables(tables) return ofield.convert_order(oexpr, target_tables, Target) def _get_target_tables(self, tables): Target = self.get_target() table, _ = tables[None] target_tables = tables.get(self.name) context = Transaction().context if target_tables is None: if Target._history and context.get('_datetime'): target = Target.__table_history__() target_history = Target.__table_history__() history_condition = Column(target, '__id').in_( target_history.select( Max(Column(target_history, '__id')), where=Coalesce( target_history.write_date, target_history.create_date) <= context['_datetime'], group_by=target_history.id)) else: target = Target.__table__() history_condition = None condition = target.id == self.sql_column(table) if history_condition: condition &= history_condition target_tables = { None: (target, condition), } tables[self.name] = target_tables return target_tables def definition(self, model, language): encoder = PYSONEncoder() target = self.get_target() relation_fields = [fname for fname, field in target._fields.items() if field._type == 'one2many' and field.model_name == model.__name__ and field.field == self.name] definition = super().definition(model, language) definition['datetime_field'] = self.datetime_field definition['relation'] = target.__name__ if len(relation_fields) == 1: definition['relation_field'], = relation_fields definition['search_context'] = encoder.encode(self.search_context) definition['search_order'] = encoder.encode(self.search_order) return definition