Files
tradon/ir/rule.py
2025-12-26 13:11:43 +00:00

350 lines
12 KiB
Python
Executable File

# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
from collections import defaultdict
from sql import Literal
from trytond.cache import Cache
from trytond.i18n import gettext
from trytond.model import Check, Index, ModelSQL, ModelView, fields
from trytond.model.exceptions import ValidationError
from trytond.pool import Pool
from trytond.pyson import PYSONDecoder
from trytond.transaction import Transaction, inactive_records
class DomainError(ValidationError):
    "Validation error raised when an ir.rule domain is invalid."
def _get_access_models(Model, names=None, model2field=None, path=None):
    """Return names and model2field for Model and its access targets.

    Walks ``Model.__access__`` recursively: for each listed field, the
    field's target model is visited with a dotted path from the root model.

    :param Model: model class to start from (or to continue from when
        recursing)
    :param names: set of already visited model names, filled in place
    :param model2field: mapping of model name to the list of dotted field
        paths leading to it from the root model, filled in place
    :param path: dotted path from the root model to Model; None for the
        root model itself, which is therefore never put in model2field
    :return: the ``(names, model2field)`` tuple, on every code path
    """
    if names is None:
        names = set()
    if model2field is None:
        model2field = defaultdict(list)
    if Model.__name__ in names:
        # Already visited (shared target or cycle): stop so the walk
        # terminates. Only the first path reaching a model is recorded.
        return names, model2field
    names.add(Model.__name__)
    if path:
        model2field[Model.__name__].append(path)
    for field_name in Model.__access__:
        Target = getattr(Model, field_name).get_target()
        target_path = path + '.' + field_name if path else field_name
        _get_access_models(Target, names, model2field, target_path)
    return names, model2field
class RuleGroup(
        fields.fmany2one(
            'model_ref', 'model', 'ir.model,model', "Model",
            required=True, ondelete='CASCADE'),
        ModelSQL, ModelView):
    "Rule group"
    # A named set of ir.rule tests attached to one model and selected per
    # access mode through the perm_* flags.
    __name__ = 'ir.rule.group'
    name = fields.Char(
        "Name", translate=True, required=True,
        help="Displayed to users when access error is raised for this rule.")
    # Name of the model (ir.model "model" value) the rules apply to
    model = fields.Char("Model", required=True)
    global_p = fields.Boolean('Global',
        help="Make the rule global \nso every users must follow this rule.")
    default_p = fields.Boolean('Default',
        help="Add this rule to all users by default.")
    rules = fields.One2Many('ir.rule', 'rule_group', 'Tests',
        help="The rule is satisfied if at least one test is True.")
    # Access modes for which the rules of this group are enforced
    perm_read = fields.Boolean('Read Access')
    perm_write = fields.Boolean('Write Access')
    perm_create = fields.Boolean('Create Access')
    perm_delete = fields.Boolean('Delete Access')

    @classmethod
    def __setup__(cls):
        super().__setup__()
        sql_table = cls.__table__()
        # Sort by model, then by the global and default flags, ahead of
        # any inherited ordering
        cls._order[:0] = [
            ('model', 'ASC'),
            ('global_p', 'ASC'),
            ('default_p', 'ASC'),
            ]
        cls._sql_constraints += [
            ('global_default_exclusive',
                Check(sql_table, (sql_table.global_p == Literal(False))
                    | (sql_table.default_p == Literal(False))),
                'Global and Default are mutually exclusive!'),
            ]
        cls._sql_indexes.add(
            Index(sql_table, (sql_table.model, Index.Equality())))

    @classmethod
    def __register__(cls, module):
        pool = Pool()
        IrModel = pool.get('ir.model')
        transaction = Transaction()
        cursor = transaction.connection.cursor()
        table_h = cls.__table_handler__(module)
        group_sql = cls.__table__()
        model_sql = IrModel.__table__()

        # Migration from 7.0: model as char
        model_is_integer = (
            table_h.column_exist('model')
            and table_h.column_is_type('model', 'INTEGER'))
        if model_is_integer:
            # Swap the ir.model id column for a column holding the name
            table_h.column_rename('model', '_temp_model')
            table_h.add_column('model', 'VARCHAR')
            cursor.execute(*group_sql.update(
                    [group_sql.model], [model_sql.model],
                    from_=[model_sql],
                    where=group_sql._temp_model == model_sql.id))
            table_h.drop_column('_temp_model')

        super().__register__(module)

    # New groups are global, not default, and enforced for every mode

    @staticmethod
    def default_global_p():
        return True

    @staticmethod
    def default_default_p():
        return False

    @staticmethod
    def default_perm_read():
        return True

    @staticmethod
    def default_perm_write():
        return True

    @staticmethod
    def default_perm_create():
        return True

    @staticmethod
    def default_perm_delete():
        return True

    @classmethod
    def delete(cls, groups):
        super().delete(groups)
        # Restart the cache on the domain_get method of ir.rule
        Rule = Pool().get('ir.rule')
        Rule._domain_get_cache.clear()

    @classmethod
    def create(cls, vlist):
        records = super().create(vlist)
        # Restart the cache on the domain_get method of ir.rule
        Rule = Pool().get('ir.rule')
        Rule._domain_get_cache.clear()
        return records

    @classmethod
    def write(cls, groups, vals, *args):
        super().write(groups, vals, *args)
        # Restart the cache on the domain_get method of ir.rule
        Rule = Pool().get('ir.rule')
        Rule._domain_get_cache.clear()
class Rule(ModelSQL, ModelView):
    "Rule"
    # One test (a PYSON-encoded domain) of an ir.rule.group
    __name__ = 'ir.rule'
    rule_group = fields.Many2One('ir.rule.group', 'Group',
        required=True, ondelete="CASCADE")
    domain = fields.Char('Domain', required=True,
        help="Domain is evaluated with a PYSON context containing:"
        '\n- "groups" as list of ids from the current user')
    # Results of domain_get; cleared whenever rules are created, written
    # or deleted
    _domain_get_cache = Cache('ir_rule.domain_get', context=False)
    # Access modes matching the perm_* fields of ir.rule.group
    modes = {'read', 'write', 'create', 'delete'}

    @classmethod
    def __setup__(cls):
        super().__setup__()
        cls.__access__.add('rule_group')
        table = cls.__table__()
        cls._sql_indexes.add(
            Index(table, (table.rule_group, Index.Equality())))

    @classmethod
    def validate_fields(cls, rules, field_names):
        super().validate_fields(rules, field_names)
        cls.check_domain(rules, field_names)

    @classmethod
    def check_domain(cls, rules, field_names=None):
        """Raise DomainError if the domain of a rule is not valid.

        The domain must decode as PYSON with the context returned by
        _get_context, decode to a list and pass fields.domain_validate.
        Nothing is checked when field_names is given without 'domain'.
        """
        if field_names and 'domain' not in field_names:
            return
        for rule in rules:
            ctx = cls._get_context(rule.rule_group.model)
            try:
                value = PYSONDecoder(ctx).decode(rule.domain)
            except Exception as exc:
                raise DomainError(gettext(
                        'ir.msg_rule_invalid_domain', name=rule.rec_name)
                    ) from exc
            if not isinstance(value, list):
                raise DomainError(gettext(
                        'ir.msg_rule_invalid_domain', name=rule.rec_name))
            try:
                fields.domain_validate(value)
            except Exception as exc:
                raise DomainError(gettext(
                        'ir.msg_rule_invalid_domain', name=rule.rec_name)
                    ) from exc

    @classmethod
    def _get_context(cls, model_name):
        """Return the PYSON context used to decode rule domains.

        model_name is not used by this implementation.
        """
        pool = Pool()
        User = pool.get('res.user')
        return {
            'groups': User.get_groups()
            }

    @classmethod
    def _get_cache_key(cls, model_names):
        """Return the user-dependent part of the domain_get cache key.

        model_names is not used by this implementation.
        """
        pool = Pool()
        User = pool.get('res.user')
        # _datetime value will be added to the domain
        return (Transaction().context.get('_datetime'), User.get_groups())

    @classmethod
    def get(cls, model_name, mode='read'):
        """Return dictionary of non-global and global rules

        Both dictionaries map an ir.rule.group record to an ``['OR', ...]``
        domain built from its rules that apply to mode for the current
        user. Rules of models reached through ``__access__`` are wrapped
        in ``where`` clauses on every access path. A non-global group that
        grants mode but has no rule is mapped to an empty domain. The root
        user (id 0) gets two empty dictionaries.
        """
        pool = Pool()
        RuleGroup = pool.get('ir.rule.group')
        RuleGroup_Group = pool.get('ir.rule.group-res.group')
        User = pool.get('res.user')
        rule_table = cls.__table__()
        rule_group = RuleGroup.__table__()
        rule_group_group = RuleGroup_Group.__table__()
        transaction = Transaction()

        assert mode in cls.modes

        # root user above constraint
        if transaction.user == 0:
            return {}, {}

        groups = User.get_groups()
        model_names, model2field = _get_access_models(pool.get(model_name))
        model_names = list(model_names)
        # Only rule groups enforced for the requested mode are considered
        perm = getattr(rule_group, 'perm_%s' % mode)
        cursor = transaction.connection.cursor()

        cursor.execute(*rule_table.join(rule_group,
                condition=rule_group.id == rule_table.rule_group
                ).join(rule_group_group, 'LEFT',
                condition=rule_group_group.rule_group == rule_group.id
                ).select(rule_table.id,
                where=(rule_group.model.in_(model_names))
                & (perm == Literal(True))
                & (rule_group_group.group.in_(groups or [-1])
                    | (rule_group.default_p == Literal(True))
                    | (rule_group.global_p == Literal(True))
                    )))
        # The LEFT join returns a rule once per matching group link:
        # keep each id once, in query order
        ids = list(dict.fromkeys(x for x, in cursor))

        # Test if there is a rule_group that has no rule; like above, it
        # must be enforced for the requested mode to be taken into account
        cursor.execute(*rule_group.join(rule_group_group, 'LEFT',
                condition=rule_group_group.rule_group == rule_group.id
                ).select(rule_group.id,
                where=(rule_group.model.in_(model_names))
                & (perm == Literal(True))
                & ~rule_group.id.in_(rule_table.select(rule_table.rule_group))
                & rule_group_group.group.in_(groups or [-1])))
        no_rules = cursor.fetchone()

        clause = defaultdict(lambda: ['OR'])
        clause_global = defaultdict(lambda: ['OR'])
        # Use root user without context to prevent recursion
        with transaction.set_user(0), transaction.set_context(user=0):
            rules = cls.browse(ids)
            empty_group = RuleGroup(no_rules[0]) if no_rules else None
        # Domains are decoded outside the root block so that the PYSON
        # context ("groups") is the one of the current user; one decoder
        # is built per model
        decoders = {}
        for rule in rules:
            target_model = rule.rule_group.model
            if target_model not in decoders:
                decoders[target_model] = PYSONDecoder(
                    cls._get_context(target_model))
            assert rule.domain, ('Rule domain empty,'
                'check if migration was done')
            dom = decoders[target_model].decode(rule.domain)
            if target_model in model2field:
                # Rule of a related model: apply it through each path
                target_dom = ['OR']
                for field in model2field[target_model]:
                    target_dom.append((field, 'where', dom))
                dom = target_dom
            if rule.rule_group.global_p:
                clause_global[rule.rule_group].append(dom)
            else:
                clause[rule.rule_group].append(dom)

        if empty_group is not None:
            clause[empty_group] = []
        return clause, clause_global

    @classmethod
    def domain_get(cls, model_name, mode='read'):
        """Return the domain restricting model_name records for mode.

        Global rule groups are combined with AND, non-global ones with OR,
        and both parts are combined with AND when present. Returns [] for
        the root user or when access checking is disabled. The result is
        cached per model, mode and _get_cache_key.
        """
        pool = Pool()
        transaction = Transaction()
        # root user above constraint
        if transaction.user == 0 or not transaction.check_access:
            return []

        assert mode in cls.modes

        model_names, _ = _get_access_models(pool.get(model_name))
        key = (model_name, mode) + cls._get_cache_key(model_names)
        domain = cls._domain_get_cache.get(key, False)
        if domain is not False:
            return domain

        clause, clause_global = cls.get(model_name, mode=mode)

        clause = list(clause.values())
        if clause:
            clause.insert(0, 'OR')

        clause_global = list(clause_global.values())
        if clause_global:
            clause_global.insert(0, 'AND')

        if clause and clause_global:
            clause = ['AND', clause_global, clause]
        elif clause_global:
            clause = clause_global

        cls._domain_get_cache.set(key, clause)
        return clause

    @classmethod
    def query_get(cls, model_name, mode='read'):
        """Return Model.search(..., query=True) for the rule domain.

        The search runs as root inside inactive_records().
        """
        pool = Pool()
        Model = pool.get(model_name)

        domain = cls.domain_get(model_name, mode=mode)

        # Use root to prevent infinite recursion
        with Transaction().set_user(0, set_context=True), inactive_records():
            return Model.search(domain, order=[], query=True)

    @classmethod
    def delete(cls, rules):
        super().delete(rules)
        # Restart the cache on the domain_get method of ir.rule
        cls._domain_get_cache.clear()

    @classmethod
    def create(cls, vlist):
        res = super().create(vlist)
        # Restart the cache on the domain_get method of ir.rule
        cls._domain_get_cache.clear()
        return res

    @classmethod
    def write(cls, rules, vals, *args):
        super().write(rules, vals, *args)
        # Restart the cache on the domain_get method
        cls._domain_get_cache.clear()