Files
tradon/model/fields/many2one.py
2025-12-26 13:11:43 +00:00

355 lines
14 KiB
Python
Executable File

# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
from sql import As, Column, Expression, Literal, Query, With
from sql.aggregate import Max
from sql.conditionals import Coalesce
from sql.operators import Or
from trytond.config import config
from trytond.pool import Pool
from trytond.pyson import PYSONEncoder
from trytond.tools import cached_property, reduce_ids
from trytond.transaction import Transaction, inactive_records
from .field import (
Field, context_validate, domain_method, instantiate_context, order_method,
search_order_validate)
# Integer read from the ``subquery_threshold`` option of the ``database``
# configuration section.  ``Many2One.convert_domain`` compares it with the
# estimated row count of the target model: below the threshold the target is
# filtered through an ``IN`` sub-query, otherwise through a joined table.
_subquery_threshold = config.getint('database', 'subquery_threshold')
class Many2One(Field):
    '''
    Define many2one field (``int``).
    '''
    _type = 'many2one'
    _sql_type = 'INTEGER'

    def __init__(self, model_name, string='', left=None, right=None, path=None,
            ondelete='SET NULL', datetime_field=None,
            search_order=None, search_context=None, help='', required=False,
            readonly=False, domain=None, states=None,
            on_change=None, on_change_with=None, depends=None, context=None,
            loading='eager'):
        '''
        :param model_name: The name of the target model.
        :param left: The name of the field to store the left value for
            Modified Preorder Tree Traversal.
            See http://en.wikipedia.org/wiki/Tree_traversal
        :param right: The name of the field to store the right value. See left
        :param path: The name of the field used to store the path.
        :param ondelete: Define the behavior of the record when the target
            record is deleted. (``CASCADE``, ``RESTRICT``, ``SET NULL``)
            ``SET NULL`` will be changed into ``RESTRICT`` if required is set.
        :param datetime_field: The name of the field that contains the datetime
            value to read the target record.
        :param search_order: The order to use when searching for a record
        :param search_context: The context to use when searching for a record
        '''
        # The private attribute is written directly: the ``required`` property
        # setter reads ``self.ondelete`` which is only assigned below.
        self.__required = required
        if ondelete not in ('CASCADE', 'RESTRICT', 'SET NULL'):
            # ValueError is still caught by callers handling ``Exception``.
            raise ValueError(
                'Bad arguments: invalid ondelete %r' % (ondelete,))
        self.ondelete = ondelete
        super().__init__(string=string, help=help,
            required=required, readonly=readonly, domain=domain, states=states,
            on_change=on_change, on_change_with=on_change_with,
            depends=depends, context=context, loading=loading)
        self.model_name = model_name
        self.left = left
        self.right = right
        self.path = path
        self.datetime_field = datetime_field
        # Initialise the backing attributes before going through the
        # validating property setters.
        self.__search_order = None
        self.search_order = search_order
        self.__search_context = None
        self.search_context = search_context or {}

    # Append the documentation of the generic Field parameters.
    __init__.__doc__ += Field.__init__.__doc__

    def __get_required(self):
        'Return the required flag'
        return self.__required

    def __set_required(self, value):
        'Store the required flag and turn ``SET NULL`` into ``RESTRICT``'
        self.__required = value
        # A required reference can not be set to NULL on target deletion.
        if value and self.ondelete == 'SET NULL':
            self.ondelete = 'RESTRICT'

    required = property(__get_required, __set_required)

    @property
    def search_order(self):
        'The order to use when searching for a record'
        return self.__search_order

    @search_order.setter
    def search_order(self, value):
        # Validate before storing so a rejected value is never kept.
        search_order_validate(value)
        self.__search_order = value

    @property
    def search_context(self):
        'The context to use when searching for a record'
        return self.__search_context

    @search_context.setter
    def search_context(self, value):
        # Validate before storing so a rejected value is never kept.
        context_validate(value)
        self.__search_context = value

    @cached_property
    def display_depends(self):
        'Return the parent display depends plus the datetime field if set'
        depends = super().display_depends
        if self.datetime_field:
            depends.add(self.datetime_field)
        return depends

    def get_target(self):
        'Return the target Model'
        return Pool().get(self.model_name)

    def __set__(self, inst, value):
        '''
        Set the value on the instance.

        A ``dict`` is used as keyword arguments and an ``int`` as positional
        argument to instantiate the target model, under the context returned
        by ``instantiate_context``.  Any other value must already be a target
        instance or ``None``.
        '''
        Target = self.get_target()
        if isinstance(value, (dict, int)):
            ctx = instantiate_context(self, inst)
            with Transaction().set_context(ctx):
                if isinstance(value, dict):
                    value = Target(**value)
                else:
                    # Only ``int`` remains after the outer isinstance check.
                    value = Target(value)
        assert isinstance(value, (Target, type(None)))
        super().__set__(inst, value)

    def sql_format(self, value):
        '''
        Convert the value for the database.

        Truthy values that are neither ``Query`` nor ``Expression`` are
        converted with ``int`` before being delegated to the parent
        implementation.
        '''
        from ..model import Model
        assert value is not False
        assert (
            not isinstance(value, Model) or value.__name__ == self.model_name)
        if value and not isinstance(value, (Query, Expression)):
            value = int(value)
        return super().sql_format(value)

    def convert_domain_path(self, domain, tables):
        '''
        Return the SQL expression of a ``child_of``/``parent_of`` clause
        using the ``path`` field.

        The paths of the records listed in the clause are read from the
        database.  For ``child_of`` the result matches the rows whose path
        starts with one of them; otherwise it matches the ids found in the
        paths (all the ``/`` separated parts except the last one).
        The expression is negated for operators starting with ``not``.
        '''
        cursor = Transaction().connection.cursor()
        table, _ = tables[None]
        name, operator, ids = domain
        red_sql = reduce_ids(table.id, (i for i in ids if i is not None))
        Target = self.get_target()
        path_column = getattr(Target, self.path).sql_column(table)
        # Rows without path are handled as having an empty path.
        path_column = Coalesce(path_column, '')
        cursor.execute(*table.select(path_column, where=red_sql))
        if operator.endswith('child_of'):
            where = Or()
            for path, in cursor:
                where.append(path_column.like(path + '%'))
        else:
            ids = [int(x) for path, in cursor for x in path.split('/')[:-1]]
            where = reduce_ids(table.id, ids)
        if not where:
            # Nothing was collected so nothing can match.
            where = Literal(False)
        if operator.startswith('not'):
            return ~where
        return where

    def convert_domain_mptt(self, domain, tables):
        '''
        Return the SQL expression of a ``child_of``/``parent_of`` clause
        using the ``left`` and ``right`` fields.

        The bounds of the records listed in the clause are read from the
        database.  For ``child_of`` the result matches the rows whose bounds
        are inside one of them; otherwise the rows whose bounds enclose one
        of them.
        The expression is negated for operators starting with ``not``.
        '''
        cursor = Transaction().connection.cursor()
        table, _ = tables[None]
        name, operator, ids = domain
        red_sql = reduce_ids(table.id, (i for i in ids if i is not None))
        Target = self.get_target()
        left = getattr(Target, self.left).sql_column(table)
        right = getattr(Target, self.right).sql_column(table)
        cursor.execute(*table.select(left, right, where=red_sql))
        # The direction does not depend on the fetched rows.
        child_of = operator.endswith('child_of')
        where = Or()
        for left_value, right_value in cursor:
            if child_of:
                where.append((left >= left_value) & (right <= right_value))
            else:
                where.append((left <= left_value) & (right >= right_value))
        if not where:
            # Nothing was fetched so nothing can match.
            where = Literal(False)
        if operator.startswith('not'):
            return ~where
        return where

    def convert_domain_tree(self, domain, tables):
        '''
        Return the SQL expression of a ``child_of``/``parent_of`` clause
        using a recursive query on the target table.

        The recursion starts from the records listed in the clause and
        follows the column named like the clause field: towards the rows
        pointing to the collected ids for ``child_of``, towards the rows
        pointed by the collected rows otherwise.
        The expression is negated for operators starting with ``not``.
        '''
        Target = self.get_target()
        target = Target.__table__()
        table, _ = tables[None]
        name, operator, ids = domain
        red_sql = reduce_ids(target.id, (i for i in ids if i is not None))

        if operator.endswith('child_of'):
            tree = With('id', recursive=True)
            tree.query = target.select(target.id, where=red_sql)
            tree.query |= (target
                .join(tree, condition=Column(target, name) == tree.id)
                .select(target.id))
        else:
            # The parent column must be carried to walk up the tree.
            tree = With('id', name, recursive=True)
            tree.query = target.select(
                target.id, Column(target, name), where=red_sql)
            tree.query |= (target
                .join(tree, condition=target.id == Column(tree, name))
                .select(target.id, Column(target, name)))

        expression = table.id.in_(tree.select(tree.id, with_=[tree]))

        if operator.startswith('not'):
            return ~expression
        return expression

    @inactive_records
    @domain_method
    def convert_domain(self, domain, tables, Model):
        '''
        Return the SQL expression for the domain clause.

        For a clause on the field itself (no ``.`` in the name):

        - ``child_of``/``parent_of`` operators are resolved by a search on
          the target when it is not ``Model`` (the 4th item of the clause is
          then used as target field name), else by the mptt, path or tree
          conversion depending on the field definition.
        - ``where`` operators are converted into a sub-query on the target
          restricted to the column value.
        - other operators with a non string value are delegated to the
          parent implementation.
        - a string value is searched on the ``rec_name`` of the target.

        For a dotted name the part after the first ``.`` is searched on the
        target, either with a sub-query or with a joined table depending on
        the estimated size of the target.
        '''
        pool = Pool()
        Rule = pool.get('ir.rule')
        Target = self.get_target()

        table, _ = tables[None]
        name, operator, value = domain[:3]
        column = self.sql_column(table)
        if '.' not in name:
            if operator.endswith(('child_of', 'parent_of')):
                if Target != Model:
                    if operator.endswith('child_of'):
                        target_operator = 'child_of'
                    else:
                        target_operator = 'parent_of'
                    query = Target.search([
                            (domain[3], target_operator, value),
                            ], order=[], query=True)
                    expression = column.in_(query)
                    if operator.startswith('not'):
                        return ~expression
                    return expression

                # Normalise the value into a list of ids.
                if isinstance(value, str):
                    targets = Target.search([('rec_name', 'ilike', value)],
                        order=[])
                    ids = [t.id for t in targets]
                elif not isinstance(value, (list, tuple)):
                    ids = [value]
                else:
                    ids = value
                if not ids:
                    expression = Literal(False)
                    if operator.startswith('not'):
                        return ~expression
                    return expression
                elif self.left and self.right:
                    return self.convert_domain_mptt(
                        (name, operator, ids), tables)
                elif self.path:
                    return self.convert_domain_path(
                        (name, operator, ids), tables)
                else:
                    return self.convert_domain_tree(
                        (name, operator, ids), tables)

            # Used for Many2Many where clause
            if operator.endswith('where'):
                query = Target.search(value, order=[], query=True)
                target_id, = query.columns
                if isinstance(target_id, As):
                    target_id = target_id.expression
                # Correlate the sub-query with the outer column.
                query.where &= target_id == column
                expression = column.in_(query)
                if operator.startswith('not'):
                    return ~expression
                return expression

            if not isinstance(value, str):
                return super().convert_domain(domain, tables, Model)
            else:
                target_name = 'rec_name'
        else:
            _, target_name = name.split('.', 1)
        target_domain = [(target_name,) + tuple(domain[1:])]
        rule_domain = Rule.domain_get(Target.__name__, mode='read')
        if not rule_domain and target_name == 'id':
            # No need to join with the target table
            return super().convert_domain(
                (self.name, operator, value), tables, Model)
        elif Target.estimated_count() < _subquery_threshold:
            query = Target.search(target_domain, order=[], query=True)
            return column.in_(query)
        else:
            # The rule domain is added explicitly as the target is joined
            # instead of being searched.
            target_domain = [target_domain, rule_domain]
            target_tables = self._get_target_tables(tables)
            _, expression = Target.search_domain(
                target_domain, tables=target_tables)
            return expression

    @order_method
    def convert_order(self, name, tables, Model):
        '''
        Return the list of SQL expressions to order by the field.

        With a dotted name the order is delegated to the target field named
        after the first ``.``.  Otherwise the target is ordered by its
        ``_order_name`` field, else its ``_rec_name`` field when they are
        sortable fields, and by the column value as last resort.
        '''
        _, _, oexpr = name.partition('.')
        Target = self.get_target()

        if oexpr:
            oname, _, _ = oexpr.partition('.')
        else:
            oname = 'id'
            if (Target._rec_name in Target._fields
                    and Target._fields[Target._rec_name].sortable(Target)):
                oname = Target._rec_name
            # Evaluated last so _order_name takes precedence over _rec_name.
            if (Target._order_name in Target._fields
                    and Target._fields[Target._order_name].sortable(Target)):
                oname = Target._order_name
            oexpr = oname

        table, _ = tables[None]
        if oname == 'id':
            # The column already contains the target id, no join is needed.
            return [self.sql_column(table)]

        ofield = Target._fields[oname]
        target_tables = self._get_target_tables(tables)
        return ofield.convert_order(oexpr, target_tables, Target)

    def _get_target_tables(self, tables):
        '''
        Return the tables entry of the target, stored under the field name.

        The entry is created and added to ``tables`` when missing.  It joins
        the target on its id being equal to the column of the field.  When
        the target is historized and the context contains ``_datetime``, the
        history table is used, limited for each id to the highest ``__id``
        among the rows written (or created) at or before that datetime.
        '''
        Target = self.get_target()
        table, _ = tables[None]
        target_tables = tables.get(self.name)
        context = Transaction().context
        if target_tables is None:
            if Target._history and context.get('_datetime'):
                target = Target.__table_history__()
                target_history = Target.__table_history__()
                history_condition = Column(target, '__id').in_(
                    target_history.select(
                        Max(Column(target_history, '__id')),
                        where=Coalesce(
                            target_history.write_date,
                            target_history.create_date)
                        <= context['_datetime'],
                        group_by=target_history.id))
            else:
                target = Target.__table__()
                history_condition = None
            condition = target.id == self.sql_column(table)
            # Test against None explicitly instead of relying on the
            # truthiness of an SQL expression object.
            if history_condition is not None:
                condition &= history_condition
            target_tables = {
                None: (target, condition),
                }
            tables[self.name] = target_tables
        return target_tables

    def definition(self, model, language):
        '''
        Return the definition of the field.

        The parent definition is completed with the datetime field, the name
        of the target model, the encoded search context and order, and the
        name of the reverse one2many field of the target when there is
        exactly one.
        '''
        encoder = PYSONEncoder()
        target = self.get_target()
        relation_fields = [fname for fname, field in target._fields.items()
            if field._type == 'one2many'
            and field.model_name == model.__name__
            and field.field == self.name]
        definition = super().definition(model, language)
        definition['datetime_field'] = self.datetime_field
        definition['relation'] = target.__name__
        # The reverse field is only set when it is not ambiguous.
        if len(relation_fields) == 1:
            definition['relation_field'], = relation_fields
        definition['search_context'] = encoder.encode(self.search_context)
        definition['search_order'] = encoder.encode(self.search_order)
        return definition