# This file is part of Tryton.  The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.

__all__ = ['Wizard',
    'StateView', 'StateTransition', 'StateAction', 'StateReport',
    'Button']

import copy
import json

from trytond.i18n import gettext
from trytond.model import ModelSQL
from trytond.model.exceptions import AccessError
from trytond.model.fields import states_validate
from trytond.pool import Pool, PoolBase
from trytond.protocols.jsonrpc import JSONDecoder, JSONEncoder
from trytond.pyson import PYSONEncoder
from trytond.rpc import RPC
from trytond.tools import cached_property
from trytond.transaction import Transaction, check_access
from trytond.url import URLMixin


class Button(object):
    '''
    Define a button on wizard.

    string is the button label (also the translation source used by
        StateView.get_buttons)
    state is the name of the wizard state the button leads to
    icon is the icon name sent with the button definition
    default marks the default button (coerced to bool; a StateView accepts
        at most one default button)
    states is a dictionary of PYSON conditions, checked by states_validate
        on assignment and PYSON-encoded by StateView.get_buttons
    validate is the 'validate' flag sent with the button definition; when
        None, StateView.get_buttons uses True unless the button leads to the
        wizard's end state
    '''

    def __init__(self, string, state, icon='', default=False, states=None,
            validate=None):
        self.string = string
        self.state = state
        self.icon = icon
        self.default = bool(default)
        self.__states = None
        # Goes through the property setter so the value is validated
        self.states = states or {}
        self.validate = validate

    @property
    def states(self):
        return self.__states

    @states.setter
    def states(self, value):
        states_validate(value)
        self.__states = value


class State(object):
    '''
    A State of a wizard.
    '''
    name = None


class StateView(State):
    '''
    A view state of a wizard.
    '''

    def __init__(self, model_name, view, buttons):
        '''
        model_name is the name of the model
        view is the xml id of the view (``module.xml_id``) or a false value
            to not pass any view id
        buttons is a list of Button
        '''
        self.model_name = model_name
        self.view = view
        self.buttons = buttons
        # Each button must lead to a different state
        assert len(self.buttons) == len(set(b.state for b in self.buttons))
        # At most one button may be the default one
        assert len([b for b in self.buttons if b.default]) <= 1

    def get_view(self, wizard, state_name):
        '''
        Returns the form view definition of model_name for this state.

        The xml id in ``view`` is resolved through ir.model.data; when
        ``view`` is empty, fields_view_get is called with view_id=None.
        '''
        pool = Pool()
        Model_ = pool.get(self.model_name)
        ModelData = pool.get('ir.model.data')
        if self.view:
            module, fs_id = self.view.split('.')
            view_id = ModelData.get_id(module, fs_id)
        else:
            view_id = None
        return Model_.fields_view_get(view_id=view_id, view_type='form')

    def get_defaults(self, wizard, state_name, fields):
        '''
        Returns defaults values for the fields

        The model defaults are updated with the result of the wizard's
        ``default_<state_name>(fields)`` method when it exists.
        '''
        pool = Pool()
        Model_ = pool.get(self.model_name)
        defaults = Model_.default_get(fields)
        default = getattr(wizard, 'default_%s' % state_name, None)
        if default:
            defaults.update(default(fields))
        self._complete_values(defaults)
        return defaults

    def get_values(self, wizard, state_name, fields):
        '''
        Return values for the fields

        The values come from the wizard's ``value_<state_name>(fields)``
        method when it exists, otherwise the result is empty.
        '''
        values = {}
        value = getattr(wizard, 'value_%s' % state_name, None)
        if value:
            values.update(value(fields))
        self._complete_values(values)
        return values

    def _complete_values(self, values):
        '''
        Add, in place, the rec_name of the targets of Many2One values.

        For each non-empty Many2One value whose target model has a
        ``rec_name`` field, ``values['<field>.']['rec_name']`` is set.
        Keys containing a dot are already such entries and are skipped.
        '''
        pool = Pool()
        Model_ = pool.get(self.model_name)
        # Iterate over a snapshot because setdefault adds new keys
        for field_name, value in list(values.items()):
            if '.' in field_name:
                continue
            field = Model_._fields[field_name]
            if value and field._type == 'many2one':
                Target = field.get_target()
                if 'rec_name' in Target._fields:
                    values.setdefault(
                        field_name + '.', {})['rec_name'] = Target(
                        value).rec_name

    def get_buttons(self, wizard, state_name):
        '''
        Returns button definitions translated

        Each definition is a dictionary with the keys 'state', 'icon',
        'default', 'validate', 'string' and 'states' (PYSON-encoded).
        '''
        Translation = Pool().get('ir.translation')

        def translation_key(button):
            return (','.join((wizard.__name__, state_name, button.state)),
                'wizard_button', Transaction().language, button.string)

        translation_keys = [translation_key(button) for button in self.buttons]
        translations = Translation.get_sources(translation_keys)
        encoder = PYSONEncoder()
        result = []
        for button, key in zip(self.buttons, translation_keys):
            # An explicit validate flag wins; otherwise only buttons leading
            # to the end state skip validation
            validate = (button.validate
                if button.validate is not None
                else button.state != wizard.end_state)
            result.append({
                    'state': button.state,
                    'icon': button.icon,
                    'default': button.default,
                    'validate': validate,
                    # Fall back to the source string when not translated
                    'string': translations.get(key) or button.string,
                    'states': encoder.encode(button.states),
                    })
        return result


class StateTransition(State):
    '''
    A transition state of a wizard.
    '''


class StateAction(StateTransition):
    '''
    An action state of a wizard.
    '''

    def __init__(self, action_id):
        '''
        action_id is a string containing ``module.xml_id``
        '''
        super(StateAction, self).__init__()
        self.action_id = action_id

    def get_action(self):
        '''
        Returns action definition

        The xml id is resolved through ir.model.data and ir.action, and
        the value of the resulting ir.action record is returned.
        '''
        pool = Pool()
        ModelData = pool.get('ir.model.data')
        Action = pool.get('ir.action')
        module, fs_id = self.action_id.split('.')
        action_id = Action.get_action_id(
            ModelData.get_id(module, fs_id))
        action = Action(action_id)
        return action.get_action_value()


class StateReport(StateAction):
    'An report state of a wizard'

    def __init__(self, report_name):
        super(StateReport, self).__init__(None)
        self.report_name = report_name

    def get_action(self):
        '''
        Return report definition

        The first ir.action.report whose report_name matches is used.
        '''
        pool = Pool()
        ActionReport = pool.get('ir.action.report')
        action_reports = ActionReport.search([
                ('report_name', '=', self.report_name),
                ])
        assert action_reports, '%s not found' % self.report_name
        action_report = action_reports[0]
        action = action_report.action
        return action.get_action_value()


class Wizard(URLMixin, PoolBase):
    '''
    Base class of wizards.

    The states are the public class attributes that are State instances.
    The records of the view states are stored between calls as JSON in an
    ir.session.wizard record.
    '''
    __no_slots__ = True  # To allow setting State
    start_state = 'start'
    end_state = 'end'

    @classmethod
    def __setup__(cls):
        super(Wizard, cls).__setup__()
        cls.__rpc__ = {
            'create': RPC(readonly=False),
            'delete': RPC(readonly=False),
            'execute': RPC(readonly=False, check_access=False),
            }

        # Copy states
        for attr in dir(cls):
            if not isinstance(getattr(cls, attr), State):
                continue
            state_name = attr
            state = getattr(cls, state_name)
            # Copy the original state definition to prevent side-effect with
            # the mutable attributes.
            # The loop keeps the State found on the last class of the MRO
            # that exposes this attribute.
            for parent_cls in cls.__mro__:
                parent_state = getattr(parent_cls, state_name, None)
                if isinstance(parent_state, State):
                    state = parent_state
            state = copy.deepcopy(state)
            setattr(cls, attr, state)

    @classmethod
    def __post_setup__(cls):
        super(Wizard, cls).__post_setup__()
        # Set states: map every public State attribute name to its State
        cls.states = {}
        for attr in dir(cls):
            if attr.startswith('_'):
                continue
            if isinstance(getattr(cls, attr), State):
                cls.states[attr] = getattr(cls, attr)

    @classmethod
    def __register__(cls, module_name):
        super(Wizard, cls).__register__(module_name)
        pool = Pool()
        Translation = pool.get('ir.translation')
        Translation.register_wizard(cls, module_name)

    @classmethod
    def check_access(cls):
        '''
        Check that the current user may run the wizard on the context.

        Nothing is checked for user 0. Otherwise it checks:
        - read access on active_model (except for ir.ui.menu),
        - that active_model is one of the models of the wizard action,
        - that the user belongs to the groups of the wizard action or,
          when there are none, write access on active_model under the
          conditions below,
        - read access on the active_id/active_ids records.

        Raises AccessError when the model or the groups do not match.
        '''
        pool = Pool()
        ModelAccess = pool.get('ir.model.access')
        ActionWizard = pool.get('ir.action.wizard')
        User = pool.get('res.user')
        Group = pool.get('res.group')
        context = Transaction().context

        if Transaction().user == 0:
            return

        with check_access():
            model = context.get('active_model')
            if model:
                Model = pool.get(model)
            if model and model != 'ir.ui.menu':
                ModelAccess.check(model, 'read')
            models = ActionWizard.get_models(
                cls.__name__, action_id=context.get('action_id'))
            if model and models and model not in models:
                groups = Group.browse(User.get_groups())
                raise AccessError(
                    gettext(
                        'ir.msg_access_wizard_model_error',
                        wizard=cls.__name__,
                        model=model),
                    gettext(
                        'ir.msg_context_groups',
                        groups=', '.join(g.rec_name for g in groups)))
            groups = set(User.get_groups())
            wizard_groups = ActionWizard.get_groups(cls.__name__,
                action_id=context.get('action_id'))
            if wizard_groups:
                if not groups & wizard_groups:
                    groups = Group.browse(User.get_groups())
                    raise AccessError(
                        gettext(
                            'ir.msg_access_wizard_error',
                            wizard=cls.__name__),
                        gettext(
                            'ir.msg_context_groups',
                            groups=', '.join(g.rec_name for g in groups)))
            elif model and model != 'ir.ui.menu':
                # Write access is not required for table_query models that
                # keep the default ModelSQL.write
                if (not callable(getattr(Model, 'table_query', None))
                        or Model.write.__func__ != ModelSQL.write.__func__):
                    ModelAccess.check(model, 'write')

            if model:
                # Work on a copy: appending active_id must not mutate the
                # active_ids list stored in the transaction context
                ids = list(context.get('active_ids') or [])
                id_ = context.get('active_id')
                if id_ not in ids and id_ is not None:
                    ids.append(id_)
                # Check read access
                Model.read(ids, ['id'])

    @classmethod
    def create(cls):
        '''
        Create a session

        Returns a tuple with the session id, the start state and the end
        state.
        '''
        Session = Pool().get('ir.session.wizard')
        cls.check_access()
        return (Session.create([{}])[0].id, cls.start_state, cls.end_state)

    @classmethod
    def delete(cls, session_id):
        '''
        Delete the session

        When the wizard defines a method named after end_state, it is
        called on the wizard of the session and its result is returned;
        otherwise None is returned.
        '''
        Session = Pool().get('ir.session.wizard')
        end = getattr(cls, cls.end_state, None)
        if end:
            wizard = cls(session_id)
            action = end(wizard)
        else:
            action = None
        Session.delete([Session(session_id)])
        return action

    @classmethod
    def execute(cls, session_id, data, state_name):
        '''
        Execute the wizard state.

        session_id is a Session id
        data is a dictionary with the session data to update; its keys
            must be names of view states
        state_name is the name of state to execute

        Returns a dictionary with:
            - ``actions``: a list of Action to execute
            - ``view``: a dictionary with:
                - ``fields_view``: a fields/view definition
                - ``defaults``: a dictionary with default values
                - ``values``: a dictionary with values
                - ``buttons``: a list of buttons

        Raises ValueError when a key of data is not a view state.
        '''
        transaction = Transaction()
        counter = transaction.counter
        cls.check_access()
        wizard = cls(session_id)
        for key, values in data.items():
            # data comes from the client: only the records of view states
            # may be updated, never arbitrary wizard attributes (e.g. the
            # model class returned by ``model``)
            if not isinstance(cls.states.get(key), StateView):
                raise ValueError(
                    f'{cls.__name__} has no view state {key!r}')
            record = getattr(wizard, key)
            for field, value in values.items():
                if field == 'id':
                    continue
                setattr(record, field, value)
        result = wizard._execute(state_name)
        # Log before _save increases the counter
        if transaction.counter != counter and wizard.model:
            wizard.model.log(
                wizard.records, 'wizard', f'{cls.__name__}:{state_name}')
        wizard._save()
        return result

    def _execute(self, state_name):
        '''
        Run state_name and return the result dictionary described in
        execute.

        A view state fills ``view``. A transition state chains into the
        state returned by its ``transition_<name>`` method (if any); an
        action state also appends the result of ``do_<name>(action)``, or
        ``(action, {})`` without such method, to ``actions``.
        '''
        if state_name == self.end_state:
            return {}

        state = self.states[state_name]
        result = {}

        if isinstance(state, StateView):
            view = state.get_view(self, state_name)
            fields = list(view['fields'].keys())
            defaults = state.get_defaults(self, state_name, fields)
            values = state.get_values(self, state_name, fields)
            buttons = state.get_buttons(self, state_name)
            result['view'] = {
                'fields_view': view,
                'defaults': defaults,
                'values': values,
                'buttons': buttons,
                'state': state_name,
                }
        elif isinstance(state, StateTransition):
            do_result = None
            if isinstance(state, StateAction):
                action = state.get_action()
                do = getattr(self, 'do_%s' % state_name, None)
                if do:
                    do_result = do(action)
                else:
                    do_result = action, {}
            transition = getattr(self, 'transition_%s' % state_name, None)
            if transition:
                result = self._execute(transition())
            if do_result:
                result.setdefault('actions', []).append(do_result)
        return result

    def __init__(self, session_id):
        '''
        Load the wizard of the session.

        Each view state attribute is set to an instance of its model built
        from the data stored in the session (empty when missing).
        '''
        pool = Pool()
        Session = pool.get('ir.session.wizard')
        self._session_id = session_id
        session = Session(session_id)
        data = json.loads(session.data, object_hook=JSONDecoder())
        for state_name, state in self.states.items():
            if isinstance(state, StateView):
                Target = pool.get(state.model_name)
                data.setdefault(state_name, {})
                setattr(self, state_name, Target(**data[state_name]))

    def _save(self):
        '''
        Save the session in database

        The _default_values of every view state record are serialized to
        JSON and written only when they differ from the stored data.
        '''
        Session = Pool().get('ir.session.wizard')
        data = {}
        for state_name, state in self.states.items():
            if isinstance(state, StateView):
                data[state_name] = getattr(self, state_name)._default_values
        session = Session(self._session_id)
        data = json.dumps(data, cls=JSONEncoder, separators=(',', ':'))
        # json.dumps returns str, so compare with the stored str; comparing
        # with bytes would never be equal and would always rewrite
        if data != session.data:
            Session.write([session], {
                    'data': data,
                    })

    @cached_property
    def model(self):
        "The pool class of the context active_model, or None"
        pool = Pool()
        context = Transaction().context
        if context.get('active_model'):
            return pool.get(context['active_model'])
        return None

    @cached_property
    def record(self):
        "The instance of model for the context active_id, or None"
        context = Transaction().context
        model = self.model
        if model is not None and context.get('active_id') is not None:
            return model(context['active_id'])
        return None

    @cached_property
    def records(self):
        "The instances of model for the context active_ids, or []"
        context = Transaction().context
        model = self.model
        if model is not None and context.get('active_ids'):
            return model.browse(context['active_ids'])
        return []