# -*- coding: utf-8 -*- # This file is part of Tryton. The COPYRIGHT file at the top level of # this repository contains the full copyright notices and license terms. import doctest import glob import hashlib import inspect import json import multiprocessing import operator import os import pathlib import re import subprocess import sys import time import unittest import unittest.mock from configparser import ConfigParser from fnmatch import fnmatchcase from functools import reduce, wraps from itertools import chain from lxml import etree from sql import Table from werkzeug.test import Client from trytond import backend from trytond.cache import Cache from trytond.config import config, parse_uri from trytond.model import ( ModelSingleton, ModelSQL, ModelStorage, ModelView, Workflow, fields) from trytond.model.fields import Function from trytond.pool import Pool, isregisteredby from trytond.protocols.wrappers import Response from trytond.pyson import PYSONDecoder, PYSONEncoder from trytond.tools import file_open, find_dir, is_instance_method from trytond.transaction import Transaction, TransactionError from trytond.wizard import StateAction, StateView from trytond.wsgi import app __all__ = [ 'CONTEXT', 'Client', 'DB_NAME', 'ModuleTestCase', 'RouteTestCase', 'USER', 'activate_module', 'doctest_checker', 'doctest_setup', 'doctest_teardown', 'load_doc_tests', 'with_transaction', ] Pool.test = True Pool.start() USER = 1 CONTEXT = {} if 'DB_NAME' in os.environ: DB_NAME = os.environ['DB_NAME'] elif backend.name == 'sqlite': DB_NAME = ':memory:' else: DB_NAME = 'test_' + str(int(time.time())) os.environ['DB_NAME'] = DB_NAME DB_CACHE = os.environ.get('DB_CACHE') def _cpu_count(): try: return multiprocessing.cpu_count() except NotImplementedError: return 1 DB_CACHE_JOBS = os.environ.get('DB_CACHE_JOBS', str(_cpu_count())) def activate_module(modules, lang='en'): ''' Activate modules for the tested database ''' if isinstance(modules, str): modules = [modules] name = '-'.join(modules) if lang != 'en': name += '--' + lang if not db_exist(DB_NAME) and restore_db_cache(name): return create_db(lang=lang) with Transaction().start(DB_NAME, 1, close=True) as transaction: pool = Pool() Module = pool.get('ir.module') records = Module.search([ ('name', 'in', modules), ]) assert len(records) == len(modules) records = Module.search([ ('name', 'in', modules), ('state', '!=', 'activated'), ]) if records: Module.activate(records) transaction.commit() ActivateUpgrade = pool.get('ir.module.activate_upgrade', type='wizard') instance_id, _, _ = ActivateUpgrade.create() transaction.commit() ActivateUpgrade(instance_id).transition_upgrade() ActivateUpgrade.delete(instance_id) transaction.commit() backup_db_cache(name) def restore_db_cache(name): result = False if DB_CACHE: cache_file = _db_cache_file(DB_CACHE, name) if backend.name == 'sqlite': result = _sqlite_copy(cache_file, restore=True) elif backend.name == 'postgresql': result = _pg_restore(cache_file) if result: Pool(DB_NAME).init() return result def backup_db_cache(name): if DB_CACHE: if not DB_CACHE.startswith('postgresql://'): os.makedirs(DB_CACHE, exist_ok=True) cache_file = _db_cache_file(DB_CACHE, name) if backend.name == 'sqlite': _sqlite_copy(cache_file) elif backend.name == 'postgresql': _pg_dump(cache_file) def _db_cache_file(path, name): hash_name = hashlib.shake_128(name.encode('utf8')).hexdigest(40 // 2) if DB_CACHE.startswith('postgresql://'): return f"{DB_CACHE}/test-{hash_name}" else: return os.path.join(path, f'{hash_name}-{backend.name}.dump') def _sqlite_copy(file_, restore=False): import sqlite3 as sqlite if ((restore and not os.path.exists(file_)) or (not restore and os.path.exists(file_))): return False with Transaction().start(DB_NAME, 0) as transaction, \ sqlite.connect(file_) as conn2: conn1 = transaction.connection if restore: conn2, conn1 = conn1, conn2 if hasattr(conn1, 'backup'): conn1.backup(conn2) else: try: import sqlitebck except ImportError: return False sqlitebck.copy(conn1, conn2) return True def _pg_options(): uri = parse_uri(config.get('database', 'uri')) options = [] env = os.environ.copy() if uri.hostname: options.extend(['-h', uri.hostname]) if uri.port: options.extend(['-p', str(uri.port)]) if uri.username: options.extend(['-U', uri.username]) if uri.password: env['PGPASSWORD'] = uri.password return options, env def _pg_restore(cache_file): def restore_from_template(): cache_name = cache_file[len(DB_CACHE) + 1:] if not db_exist(cache_name): return False with Transaction().start( None, 0, close=True, autocommit=True) as transaction: if db_exist(DB_NAME): transaction.database.drop(transaction.connection, DB_NAME) transaction.database.create( transaction.connection, DB_NAME, cache_name) return True def restore_from_file(): if not os.path.exists(cache_file): return False with Transaction().start( None, 0, close=True, autocommit=True) as transaction: transaction.database.create(transaction.connection, DB_NAME) cmd = ['pg_restore', '-d', DB_NAME, '-j', DB_CACHE_JOBS] options, env = _pg_options() cmd.extend(options) cmd.append(cache_file) return not subprocess.call(cmd, env=env) if cache_file.startswith('postgresql://'): return restore_from_template() else: try: return restore_from_file() except OSError: return restore_from_template() def _pg_dump(cache_file): def dump_on_template(): cache_name = cache_file[len(DB_CACHE) + 1:] if db_exist(cache_name): return False # Ensure any connection is left open backend.Database(DB_NAME).close() with Transaction().start( None, 0, close=True, autocommit=True) as transaction: transaction.database.create( transaction.connection, cache_name, DB_NAME) return True def dump_on_file(): if os.path.exists(cache_file): return False # Use directory format to support multiple processes cmd = ['pg_dump', '-f', cache_file, '-F', 'd', '-j', DB_CACHE_JOBS] options, env = _pg_options() cmd.extend(options) cmd.append(DB_NAME) return not subprocess.call(cmd, env=env) if cache_file.startswith('postgresql://'): dump_on_template() else: try: return dump_on_file() except OSError: return dump_on_template() def with_transaction(user=1, context=None): def decorator(func): @wraps(func) def wrapper(*args, **kwargs): extras = {} while True: with Transaction().start( DB_NAME, user, context=context, **extras) as transaction: try: result = func(*args, **kwargs) except TransactionError as e: transaction.rollback() transaction.tasks.clear() e.fix(extras) continue finally: transaction.rollback() # Clear remaining tasks transaction.tasks.clear() # Drop the cache as the transaction is rollbacked Cache.drop(DB_NAME) return result return wrapper return decorator class _DBTestCase(unittest.TestCase): module = None extras = None language = 'en' @classmethod def setUpClass(cls): drop_db() modules = [cls.module] if cls.extras: modules.extend(cls.extras) activate_module(modules, lang=cls.language) super().setUpClass() @classmethod def tearDownClass(cls): super().tearDownClass() drop_db() class ModuleTestCase(_DBTestCase): "Tryton Module Test Case" @with_transaction() def test_rec_name(self): for mname, model in Pool().iterobject(): if not isregisteredby(model, self.module): continue # Skip testing default value even if the field doesn't exist # as there is a fallback to id if model._rec_name == 'name': continue with self.subTest(model=mname): self.assertIn(model._rec_name, model._fields.keys(), msg='Wrong _rec_name "%s" for %s' % ( model._rec_name, mname)) field = model._fields[model._rec_name] self.assertIn(field._type, {'char', 'text'}, msg="Wrong '%s' type for _rec_name of %s'" % ( field._type, mname)) @with_transaction() def test_model__access__(self): "Test existing model __access__" pool = Pool() for mname, Model in pool.iterobject(): if not isregisteredby(Model, self.module): continue for field_name in Model.__access__: with self.subTest(model=mname, field=field_name): self.assertIn(field_name, Model._fields.keys(), msg="Wrong __access__ '%s' for %s" % ( field_name, mname)) field = Model._fields[field_name] Target = field.get_target() self.assertTrue( Target, msg='Missing target for __access__ "%s" of %s' % ( field_name, mname)) @with_transaction() def test_view(self): 'Test validity of all views of the module' pool = Pool() View = pool.get('ir.ui.view') views = View.search([ ('module', '=', self.module), ]) directory = find_dir( self.module, subdir='modules' if self.module not in {'ir', 'res'} else '') view_files = set(glob.glob(os.path.join(directory, 'view', '*.xml'))) for view in views: if view.name: view_files.discard(os.path.join( directory, 'view', view.name + '.xml')) if not view.model: continue with self.subTest(view=view): if not view.inherit or view.inherit.model == view.model: self.assertTrue(view.arch, msg='missing architecture for view "%(name)s" ' 'of model "%(model)s"' % { 'name': view.name or str(view.id), 'model': view.model, }) if view.inherit and view.inherit.model == view.model: view_id = view.inherit.id else: view_id = view.id model = view.model Model = pool.get(model) res = Model.fields_view_get(view_id) self.assertEqual(res['model'], model) tree = etree.fromstring(res['arch']) validator = etree.RelaxNG(etree=View.get_rng(res['type'])) validator.assertValid(tree) tree_root = tree.getroottree().getroot() for element in tree_root.iter(): with self.subTest(element=element): if element.tag in { 'field', 'label', 'separator', 'group', 'page'}: attrs = ['name'] if element.tag == 'field': attrs += ['icon', 'symbol'] for attr in attrs: field = element.get(attr) if field: self.assertIn(field, res['fields'].keys(), msg='Missing field: %s in %s' % ( field, Model.__name__)) if element.tag == 'button': button_name = element.get('name') self.assertIn(button_name, Model._buttons.keys(), msg="Button '%s' is not in %s._buttons" % (button_name, Model.__name__)) self.assertFalse(view_files, msg="unused view files") @with_transaction() def test_icon(self): "Test icons of the module" pool = Pool() Icon = pool.get('ir.ui.icon') icons = Icon.search([('module', '=', self.module)]) directory = find_dir( self.module, subdir='modules' if self.module not in {'ir', 'res'} else '') icon_files = set(glob.glob(os.path.join(directory, 'icons', '*.svg'))) for icon in icons: icon_files.discard(os.path.join( directory, icon.path.replace('/', os.sep))) with self.subTest(icon=icon): self.assertTrue(icon.icon) self.assertFalse(icon_files, msg="unused icon files") @with_transaction() def test_rpc_callable(self): 'Test that RPC methods are callable' for _, model in Pool().iterobject(): for method_name in model.__rpc__: with self.subTest(model=model, method=method_name): self.assertTrue( callable(getattr(model, method_name, None)), msg="'%s' is not callable on '%s'" % (method_name, model.__name__)) @with_transaction() def test_missing_depends(self): 'Test for missing depends' for mname, model in Pool().iterobject(): if not isregisteredby(model, self.module): continue for fname, field in model._fields.items(): depends = { f for f in field.depends if not f.startswith('_parent_')} with self.subTest(model=mname, field=fname): self.assertLessEqual(depends, set(model._fields), msg='Unknown depends %s in "%s"."%s"' % ( list(depends - set(model._fields)), mname, fname)) if issubclass(model, ModelView): for bname, button in model._buttons.items(): depends = set(button.get('depends', [])) with self.subTest(model=mname, button=bname): self.assertLessEqual(depends, set(model._fields), msg='Unknown depends %s in button "%s"."%s"' % ( list(depends - set(model._fields)), mname, bname)) @with_transaction() def test_depends(self): "Test depends" def test_missing_relation(depend, depends, qualname): prefix = [] for d in depend.split('.'): if d.startswith('_parent_'): relation = '.'.join( prefix + [d[len('_parent_'):]]) self.assertIn(relation, depends, msg='Missing "%s" in %s' % (relation, qualname)) prefix.append(d) def test_parent_empty(depend, qualname): if depend.startswith('_parent_'): self.assertIn('.', depend, msg='Invalid empty "%s" in %s' % (depend, qualname)) def test_missing_parent(model, depend, depends, qualname): dfield = model._fields.get(depend) parent_depends = {d.split('.', 1)[0] for d in depends} if dfield and dfield._type == 'many2one': target = dfield.get_target() for tfield in target._fields.values(): if (tfield._type == 'one2many' and tfield.model_name == mname and tfield.field == depend): self.assertIn('_parent_%s' % depend, parent_depends, msg='Missing "_parent_%s" in %s' % ( depend, qualname)) def test_depend_exists(model, depend, qualname): try: depend, nested = depend.split('.', 1) except ValueError: nested = None if depend.startswith('_parent_'): depend = depend[len('_parent_'):] self.assertIsInstance(getattr(model, depend, None), fields.Field, msg='Unknown "%s" in %s' % (depend, qualname)) if nested: target = getattr(model, depend).get_target() test_depend_exists(target, nested, qualname) for mname, model in Pool().iterobject(): if not isregisteredby(model, self.module): continue for fname, field in model._fields.items(): with self.subTest(model=mname, field=fname): for attribute in [ 'depends', 'on_change', 'on_change_with', 'selection_change_with', 'autocomplete']: depends = getattr(field, attribute, set()) if attribute == 'depends': depends |= field.display_depends depends |= field.edition_depends depends |= field.validation_depends qualname = '"%s"."%s"."%s"' % (mname, fname, attribute) for depend in depends: test_depend_exists(model, depend, qualname) test_missing_relation(depend, depends, qualname) test_parent_empty(depend, qualname) if attribute != 'depends': test_missing_parent( model, depend, depends, qualname) @with_transaction() def test_field_methods(self): 'Test field methods' def test_methods(mname, model, attr): for prefixes in [['default_'], ['on_change_', 'on_change_with_'], ['order_'], ['domain_'], ['autocomplete_']]: if attr in {'on_change_with', 'on_change_notify'}: continue # TODO those method should be renamed if attr == 'default_get': continue if mname == 'ir.rule' and attr == 'domain_get': continue # Skip if it is a field if attr in model._fields: continue fnames = [attr[len(prefix):] for prefix in prefixes if attr.startswith(prefix)] if not fnames: continue self.assertTrue(any(f in model._fields for f in fnames), msg='Field method "%s"."%s" for unknown field' % ( mname, attr)) if attr.startswith('default_'): fname = attr[len('default_'):] if isinstance(model._fields[fname], fields.MultiValue): try: getattr(model, attr)(pattern=None) # get_multivalue may raise an AttributeError # if pattern is not defined on the model except AttributeError: pass else: getattr(model, attr)() elif attr.startswith('order_'): model.search([], order=[(attr[len('order_'):], None)]) elif attr.startswith('domain_'): model.search([(attr[len('domain_'):], '=', None)]) elif any(attr.startswith(p) for p in [ 'on_change_', 'on_change_with_', 'autocomplete_']): record = model() getattr(record, attr)() for mname, model in Pool().iterobject(): if not isregisteredby(model, self.module): continue for attr in dir(model): with self.subTest(model=mname, attr=attr): test_methods(mname, model, attr) @with_transaction() def test_field_relation_target(self): "Test field relation and target" pool = Pool() def test_relation_target(mname, model, fname, field): if isinstance(field, fields.One2Many): Relation = field.get_target() rfield = field.field elif isinstance(field, fields.Many2Many): Relation = field.get_relation() rfield = field.origin else: return if rfield: self.assertIn(rfield, Relation._fields.keys(), msg=('Missing relation field "%s" on "%s" ' 'for "%s"."%s"') % ( rfield, Relation.__name__, mname, fname)) reverse_field = Relation._fields[rfield] self.assertIn( reverse_field._type, [ 'reference', 'many2one', 'one2one'], msg=('Wrong type for relation field "%s" on "%s" ' 'for "%s"."%s"') % ( rfield, Relation.__name__, mname, fname)) if (reverse_field._type == 'many2one' and issubclass(model, ModelSQL) # Do not test table_query models # as they can manipulate their id and not callable(model.table_query)): self.assertEqual( reverse_field.model_name, model.__name__, msg=('Wrong model for relation field "%s" on "%s" ' 'for "%s"."%s"') % ( rfield, Relation.__name__, mname, fname)) Target = field.get_target() self.assertTrue( Target, msg='Missing target for "%s"."%s"' % (mname, fname)) for mname, model in pool.iterobject(): if not isregisteredby(model, self.module): continue for fname, field in model._fields.items(): with self.subTest(model=mname, field=fname): test_relation_target(mname, model, fname, field) @with_transaction() def test_field_relation_domain(self): "Test domain of relation fields" pool = Pool() for mname, model in pool.iterobject(): if not isregisteredby(model, self.module): continue for fname, field in model._fields.items(): if not field.domain: continue if hasattr(field, 'get_target'): Target = field.get_target() else: continue if not issubclass(Target, ModelStorage): continue with self.subTest(model=mname, field=fname): domain = PYSONDecoder({}).decode( PYSONEncoder().encode(field.domain)) Target.search(domain, limit=1) @with_transaction() def test_menu_action(self): 'Test that menu actions are accessible to menu\'s group' pool = Pool() Menu = pool.get('ir.ui.menu') ModelData = pool.get('ir.model.data') module_menus = ModelData.search([ ('model', '=', 'ir.ui.menu'), ('module', '=', self.module), ]) menus = Menu.browse([mm.db_id for mm in module_menus]) for menu, module_menu in zip(menus, module_menus): if not menu.action_keywords: continue menu_groups = set(menu.groups) actions_groups = reduce(operator.or_, (set(k.action.groups) for k in menu.action_keywords if k.keyword == 'tree_open')) if not actions_groups: continue with self.subTest(menu=menu): self.assertLessEqual(menu_groups, actions_groups, msg='Menu "%(menu_xml_id)s" actions are not accessible to ' '%(groups)s' % { 'menu_xml_id': module_menu.fs_id, 'groups': ','.join(g.name for g in menu_groups - actions_groups), }) @with_transaction() def test_model_access(self): 'Test missing default model access' pool = Pool() Access = pool.get('ir.model.access') no_groups = {a.model for a in Access.search([ ('group', '=', None), ])} def has_access(Model, models): if Model.__name__ in models: return True for field_name in Model.__access__: Target = Model._fields[field_name].get_target() if has_access(Target, models): return True for mname, Model in pool.iterobject(): if has_access(Model, no_groups): no_groups.add(mname) with_groups = {a.model for a in Access.search([ ('group', '!=', None), ])} self.assertGreaterEqual(no_groups, with_groups, msg='Model "%(models)s" are missing a default access' % { 'models': list(with_groups - no_groups), }) @with_transaction() def test_workflow_transitions(self): 'Test all workflow transitions exist' for mname, model in Pool().iterobject(): if not isregisteredby(model, self.module): continue if not issubclass(model, Workflow): continue field = getattr(model, model._transition_state) if isinstance(field.selection, (tuple, list)): values = field.selection else: # instance method may not return all the possible values if is_instance_method(model, field.selection): continue values = getattr(model, field.selection)() states = set(dict(values)) transition_states = set(chain(*model._transitions)) with self.subTest(model=mname): self.assertLessEqual(transition_states, states, msg='Unknown transition states "%(states)s" ' 'in model "%(model)s". ' % { 'states': list(transition_states - states), 'model': model.__name__, }) @with_transaction() def test_wizards(self): 'Test wizards are correctly defined' for wizard_name, wizard in Pool().iterobject(type='wizard'): if not isregisteredby(wizard, self.module, type_='wizard'): continue session_id, start_state, _ = wizard.create() with self.subTest(wizard=wizard_name): self.assertIn(start_state, wizard.states.keys(), msg='Unknown start state ' '"%(state)s" on wizard "%(wizard)s"' % { 'state': start_state, 'wizard': wizard_name, }) wizard_instance = wizard(session_id) for state_name, state in wizard_instance.states.items(): with self.subTest(wizard=wizard_name, state=state_name): if isinstance(state, StateView): # Don't test defaults as they may depend on context view = state.get_view(wizard_instance, state_name) self.assertEqual( view.get('type'), 'form', msg='Wrong view type for "%(state)s" ' 'on wizard "%(wizard)s"' % { 'state': state_name, 'wizard': wizard_name, }) for button in state.get_buttons( wizard_instance, state_name): if button['state'] == wizard.end_state: continue self.assertIn( button['state'], wizard_instance.states.keys(), msg='Unknown button state from "%(state)s" ' 'on wizard "%(wizard)s' % { 'state': state_name, 'wizard': wizard_name, }) if isinstance(state, StateAction): state.get_action() @with_transaction() def test_selection_fields(self): 'Test selection values' for mname, model in Pool().iterobject(): if not isregisteredby(model, self.module): continue for field_name, field in model._fields.items(): selection = getattr(field, 'selection', None) if selection is None: continue selection_values = field.selection if not isinstance(selection_values, (tuple, list)): sel_func = getattr(model, field.selection) if not is_instance_method(model, field.selection): selection_values = sel_func() else: record = model() selection_values = sel_func(record) with self.subTest(model=mname, field=field_name): self.assertTrue(all(len(v) == 2 for v in selection_values), msg='Invalid selection values "%(values)s" on field ' '"%(field)s" of model "%(model)s"' % { 'values': selection_values, 'field': field_name, 'model': model.__name__, }) if field._type == 'multiselection': self.assertNotIn(None, dict(selection_values).keys()) @with_transaction() def test_function_fields(self): "Test function fields methods" for mname, model in Pool().iterobject(): if not isregisteredby(model, self.module): continue for field_name, field in model._fields.items(): if not isinstance(field, Function): continue for func_name in [field.getter, field.setter, field.searcher]: if not func_name: continue with self.subTest( model=mname, field=field_name, function=func_name): self.assertTrue(getattr(model, func_name, None), msg="Missing method '%(func_name)s' " "on model '%(model)s' for field '%(field)s'" % { 'func_name': func_name, 'model': model.__name__, 'field': field_name, }) if func_name == field.searcher: getattr(model, field.searcher)( field_name, (field_name, '=', None)) @with_transaction() def test_ir_action_window(self): 'Test action windows are correctly defined' pool = Pool() ModelData = pool.get('ir.model.data') ActionWindow = pool.get('ir.action.act_window') def test_action_window(action_window): if not action_window.res_model: return Model = pool.get(action_window.res_model) for active_id, active_ids in [ (None, []), (1, [1]), (1, [1, 2]), ]: decoder = PYSONDecoder({ 'active_id': active_id, 'active_ids': active_ids, 'active_model': action_window.res_model, }) domain = decoder.decode(action_window.pyson_domain) order = decoder.decode(action_window.pyson_order) context = decoder.decode(action_window.pyson_context) search_value = decoder.decode(action_window.pyson_search_value) if action_window.context_domain: domain = ['AND', domain, decoder.decode(action_window.context_domain)] with Transaction().set_context(context): Model.search( domain, order=order, limit=action_window.limit) if search_value: Model.search(search_value) for action_domain in action_window.act_window_domains: if not action_domain.domain: continue Model.search(decoder.decode(action_domain.domain)) if action_window.context_model: pool.get(action_window.context_model) for model_data in ModelData.search([ ('module', '=', self.module), ('model', '=', 'ir.action.act_window'), ]): action_window = ActionWindow(model_data.db_id) with self.subTest(action_window=action_window): test_action_window(action_window) @with_transaction() def test_modelsingleton_inherit_order(self): 'Test ModelSingleton, ModelSQL, ModelStorage order in the MRO' for mname, model in Pool().iterobject(): if not isregisteredby(model, self.module): continue if (not issubclass(model, ModelSingleton) or not issubclass(model, ModelSQL)): continue mro = inspect.getmro(model) singleton_index = mro.index(ModelSingleton) sql_index = mro.index(ModelSQL) with self.subTest(model=mname): self.assertLess(singleton_index, sql_index, msg="ModelSingleton must appear before ModelSQL " "in the parent classes of '%s'." % mname) @with_transaction() def test_pool_slots(self): "Test pool object has __slots__" for type_ in ['model', 'wizard', 'report']: for name, cls in Pool().iterobject(type_): if not isregisteredby(cls, self.module): continue if getattr(cls, '__no_slots__', None): continue with self.subTest(type=type_, name=name): for kls in cls.__mro__: if kls is object: continue self.assertTrue(hasattr(kls, '__slots__'), msg="The %s of %s '%s' has no __slots__" % (kls, type_, name)) @with_transaction() def test_buttons_registered(self): 'Test all buttons are registered in ir.model.button' pool = Pool() Button = pool.get('ir.model.button') for mname, model in Pool().iterobject(): if not isregisteredby(model, self.module): continue if not issubclass(model, ModelView): continue ir_buttons = {b.name for b in Button.search([ ('model.model', '=', model.__name__), ])} buttons = set(model._buttons) with self.subTest(model=mname): self.assertGreaterEqual(ir_buttons, buttons, msg='The buttons "%(buttons)s" of Model "%(model)s" ' 'are not registered in ir.model.button.' % { 'buttons': list(buttons - ir_buttons), 'model': model.__name__, }) @with_transaction() def test_buttons_states(self): "Test the states of buttons" pool = Pool() keys = {'readonly', 'invisible', 'icon', 'pre_validate', 'depends'} for mname, model in pool.iterobject(): if not isregisteredby(model, self.module): continue if not issubclass(model, ModelView): continue for button, states in model._buttons.items(): with self.subTest(model=mname, button=button): self.assertTrue(set(states).issubset(keys), msg='The button "%(button)s" of Model "%(model)s" has ' 'extra keys "%(keys)s".' % { 'button': button, 'model': mname, 'keys': set(states) - keys, }) @with_transaction() def test_xml_files(self): "Test validity of the xml files of the module" config = ConfigParser() with file_open('%s/tryton.cfg' % self.module, subdir='modules', mode='r', encoding='utf-8') as fp: config.read_file(fp) if not config.has_option('tryton', 'xml'): return with file_open('tryton.rng', subdir='', mode='rb') as fp: rng = etree.parse(fp) validator = etree.RelaxNG(etree=rng) for xml_file in filter(None, config.get('tryton', 'xml').splitlines()): with self.subTest(xml=xml_file): with file_open('%s/%s' % (self.module, xml_file), subdir='modules', mode='rb') as fp: tree = etree.parse(fp) validator.assertValid(tree) class RouteTestCase(_DBTestCase): "Tryton Route Test Case" @classmethod def setUpClass(cls): super().setUpClass() with Transaction().start(DB_NAME, 1): cls.setUpDatabase() @classmethod def setUpDatabase(cls): pass @property def db_name(self): return DB_NAME def client(self): return Client(app, Response) def db_exist(name=DB_NAME): database = backend.Database().connect() return name in database.list() def create_db(name=DB_NAME, lang='en'): if not db_exist(name): database = backend.Database() database.connect() connection = database.get_connection(autocommit=True) try: database.create(connection, name) finally: database.put_connection(connection, True) database = backend.Database(name) connection = database.get_connection() try: with connection.cursor() as cursor: database.init() ir_configuration = Table('ir_configuration') cursor.execute(*ir_configuration.insert( [ir_configuration.language], [[lang]])) connection.commit() finally: database.put_connection(connection) pool = Pool(name) pool.init(update=['res', 'ir'], lang=[lang]) with Transaction().start(name, 0): User = pool.get('res.user') Lang = pool.get('ir.lang') language, = Lang.search([('code', '=', lang)]) language.translatable = True language.save() users = User.search([('login', '!=', 'root')]) User.write(users, { 'language': language.id, }) Module = pool.get('ir.module') Module.update_list() else: pool = Pool(name) pool.init() class ExtensionTestCase(unittest.TestCase): extension = None @classmethod def setUpClass(cls): super().setUpClass() cls._activate_extension() @classmethod def tearDownClass(cls): super().tearDownClass() cls._deactivate_extension() @classmethod @with_transaction() def _activate_extension(cls): connection = Transaction().connection cursor = connection.cursor() cursor.execute('CREATE EXTENSION "%s"' % cls.extension) connection.commit() cls._clear_cache() @classmethod @with_transaction() def _deactivate_extension(cls): connection = Transaction().connection cursor = connection.cursor() cursor.execute('DROP EXTENSION "%s"' % cls.extension) connection.commit() cls._clear_cache() @classmethod def _clear_cache(cls): backend.Database._has_proc.clear() def drop_db(name=DB_NAME): if db_exist(name): database = backend.Database(name) database.close() with Transaction().start( None, 0, close=True, autocommit=True) as transaction: database.drop(transaction.connection, name) Pool.stop(name) Cache.drop(name) def drop_create(name=DB_NAME, lang='en'): if db_exist(name): drop_db(name) create_db(name, lang) def doctest_setup(test): return drop_create() def doctest_teardown(test): unittest.mock.patch.stopall() return drop_db() STRIP_DECIMAL = doctest.register_optionflag('STRIP_DECIMAL') class OutputChecker(doctest.OutputChecker): def check_output(self, want, got, optionflags): if optionflags & STRIP_DECIMAL: want = self._strip_decimal(want) got = self._strip_decimal(got) return super().check_output(want, got, optionflags) def _strip_decimal(self, value): return re.sub( r"Decimal\s*\('(\d*\.\d*?)0+'\)", r"Decimal('\1')", value) doctest_checker = OutputChecker() def load_doc_tests(name, path, loader, tests, pattern): def shouldIncludeScenario(path): return ( loader.testNamePatterns is None or any( fnmatchcase(path, pattern) for pattern in loader.testNamePatterns)) directory = os.path.dirname(path) # TODO: replace by glob root_dir in Python 3.10 cwd = os.getcwd() optionflags = ( doctest.REPORT_ONLY_FIRST_FAILURE | doctest.ELLIPSIS | doctest.IGNORE_EXCEPTION_DETAIL) if backend.name == 'sqlite': optionflags |= STRIP_DECIMAL try: os.chdir(directory) for scenario in filter( shouldIncludeScenario, glob.glob('*.rst')): config = pathlib.Path(scenario).with_suffix('.json') if os.path.exists(config): with config.open() as fp: configs = json.load(fp) else: configs = [{}] for globs in configs: tests.addTests(doctest.DocFileSuite( scenario, package=name, globs=globs, tearDown=doctest_teardown, encoding='utf-8', checker=doctest_checker, optionflags=optionflags)) finally: os.chdir(cwd) return tests class TestSuite(unittest.TestSuite): def run(self, *args, **kwargs): while True: try: exist = db_exist() break except backend.DatabaseOperationalError as err: # Retry on connection error sys.stderr.write(str(err)) time.sleep(1) result = super(TestSuite, self).run(*args, **kwargs) if not exist: drop_db() return result def load_tests(loader, tests, pattern): ''' Return test suite for other modules ''' return TestSuite()