Initial import from Docker volume

This commit is contained in:
root
2025-12-26 13:11:43 +00:00
commit 4998dc066a
13336 changed files with 1767801 additions and 0 deletions

BIN
.admin.py.swp Executable file

Binary file not shown.

42
__init__.py Executable file
View File

@@ -0,0 +1,42 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import os
import time
import warnings
from email import charset
import __main__
from lxml import etree, objectify
try:
from requests import utils as requests_utils
except ImportError:
requests_utils = None
__version__ = "7.2.9"
os.environ.setdefault(
'TRYTOND_APPNAME',
os.path.basename(getattr(__main__, '__file__', 'trytond')))
os.environ.setdefault('TRYTOND_TZ', os.environ.get('TZ', 'UTC'))
os.environ['TZ'] = 'UTC'
if hasattr(time, 'tzset'):
time.tzset()
if time.tzname[0] != 'UTC':
warnings.warn('Timezone must be set to UTC instead of %s' % time.tzname[0])
# set email encoding for utf-8 to 'quoted-printable'
charset.add_charset('utf-8', charset.QP, charset.QP)
# prevent XML vulnerabilities by default
etree.set_default_parser(etree.XMLParser(resolve_entities=False))
objectify.set_default_parser(objectify.makeparser(resolve_entities=False))
def default_user_agent(name="Tryton"):
return f"{name}/{__version__}"
if requests_utils:
requests_utils.default_user_agent = default_user_agent

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

183
admin.py Executable file
View File

@@ -0,0 +1,183 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import logging
import os
import random
import sys
from getpass import getpass
from sql import Literal, Table
from trytond import backend
from trytond.config import config
from trytond.pool import Pool
from trytond.sendmail import send_test_email
from trytond.transaction import Transaction, TransactionError, inactive_records
__all__ = ['run']
logger = logging.getLogger(__name__)
def run(options):
main_lang = config.get('database', 'language')
init = {}
if options.test_email:
send_test_email(options.test_email)
for db_name in options.database_names:
init[db_name] = False
database = backend.Database(db_name)
database.connect()
if options.update:
if not database.test():
logger.info("init db")
database.init()
init[db_name] = True
elif not database.test():
raise Exception('"%s" is not a Tryton database.' % db_name)
for db_name in options.database_names:
if options.update:
with Transaction().start(db_name, 0) as transaction, \
transaction.connection.cursor() as cursor:
database = backend.Database(db_name)
database.connect()
if not database.test():
raise Exception('"%s" is not a Tryton database.' % db_name)
lang = Table('ir_lang')
cursor.execute(*lang.select(lang.code,
where=lang.translatable == Literal(True)))
lang = set([x[0] for x in cursor])
lang.add(main_lang)
else:
lang = set()
lang |= set(options.languages)
pool = Pool(db_name)
pool.init(update=options.update, lang=list(lang),
activatedeps=options.activatedeps,
indexes=options.indexes)
if options.update_modules_list:
with Transaction().start(db_name, 0) as transaction:
Module = pool.get('ir.module')
Module.update_list()
if lang:
with Transaction().start(db_name, 0) as transaction:
pool = Pool()
Lang = pool.get('ir.lang')
languages = Lang.search([
('code', 'in', lang),
])
Lang.write(languages, {
'translatable': True,
})
for db_name in options.database_names:
if options.email is not None:
email = options.email
elif init[db_name]:
email = input(
'"admin" email for "%s": ' % db_name)
else:
email = None
password = ''
if init[db_name] or options.password:
# try to read password from environment variable
# TRYTONPASSFILE, empty TRYTONPASSFILE ignored
passpath = os.getenv('TRYTONPASSFILE')
if passpath:
try:
with open(passpath) as passfile:
password, = passfile.read().splitlines()
except Exception as err:
sys.stderr.write('Can not read password '
'from "%s": "%s"\n' % (passpath, err))
if not password and not options.reset_password:
while True:
password = getpass(
'"admin" password for "%s": ' % db_name)
password2 = getpass('"admin" password confirmation: ')
if password != password2:
sys.stderr.write('"admin" password confirmation '
'doesn\'t match "admin" password.\n')
continue
if not password:
sys.stderr.write('"admin" password is required.\n')
continue
break
transaction_extras = {}
while True:
with Transaction().start(
db_name, 0, **transaction_extras) as transaction:
try:
pool = Pool()
User = pool.get('res.user')
Configuration = pool.get('ir.configuration')
configuration = Configuration(1)
with inactive_records():
admin, = User.search([('login', '=', 'admin')])
if email is not None:
admin.email = email
if init[db_name] or options.password:
configuration.language = main_lang
if not options.reset_password:
admin.password = password
admin.save()
if options.reset_password:
User.reset_password([admin])
if options.hostname is not None:
configuration.hostname = options.hostname or None
configuration.save()
except TransactionError as e:
transaction.rollback()
e.fix(transaction_extras)
continue
break
with Transaction().start(db_name, 0, readonly=True):
if options.validate is not None:
validate(options.validate, options.validate_percentage)
def validate(models, percentage=100):
from trytond.model import ModelSingleton, ModelStorage
from trytond.model.exceptions import ValidationError
logger = logging.getLogger('validate')
pool = Pool()
if not models:
models = sorted([n for n, _ in pool.iterobject()])
ratio = min(100, percentage) / 100
in_max = Transaction().database.IN_MAX
for name in models:
logger.info("validate: %s", name)
Model = pool.get(name)
if not issubclass(Model, ModelStorage):
continue
offset = 0
limit = in_max
while True:
records = Model.search(
[], order=[('id', 'ASC')], offset=offset, limit=limit)
if not records:
break
records = Model.browse(
random.sample(records, int(len(records) * ratio)))
try:
for record in records:
try:
Model._validate([record])
except ValidationError as exception:
logger.error("%s: KO '%s'", record, exception)
else:
logger.info("%s: OK", record)
except TransactionError:
logger.info("%s: SKIPPED", name)
break
if issubclass(Model, ModelSingleton):
break
offset += limit

45
application.py Executable file
View File

@@ -0,0 +1,45 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import csv
import logging.config
import os
import threading
from io import StringIO
__all__ = ['app']
# Logging must be set before importing
logging_config = os.environ.get('TRYTOND_LOGGING_CONFIG')
logging_level = int(os.environ.get(
'TRYTOND_LOGGING_LEVEL', default=logging.ERROR))
if logging_config:
logging.config.fileConfig(logging_config)
else:
logformat = ('%(process)s %(thread)s [%(asctime)s] '
'%(levelname)s %(name)s %(message)s')
level = max(logging_level, logging.NOTSET)
logging.basicConfig(level=level, format=logformat)
logging.captureWarnings(True)
if os.environ.get('TRYTOND_COROUTINE'):
from gevent import monkey
monkey.patch_all()
from trytond.pool import Pool # noqa: E402
from trytond.wsgi import app # noqa: E402
Pool.start()
# TRYTOND_CONFIG it's managed by importing config
db_names = os.environ.get('TRYTOND_DATABASE_NAMES')
if db_names:
# Read with csv so database name can include special chars
reader = csv.reader(StringIO(db_names))
threads = []
for name in next(reader):
thread = threading.Thread(target=lambda: Pool(name).init())
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
assert len(threads := threading.enumerate()) == 1, f"len({threads}) != 1"

39
backend/__init__.py Executable file
View File

@@ -0,0 +1,39 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import importlib
import urllib.parse
try:
from backports.entry_points_selectable import entry_points
except ImportError:
from importlib.metadata import entry_points
from trytond.config import config
__all__ = [
'name', 'Database', 'TableHandler',
'DatabaseIntegrityError', 'DatabaseDataError', 'DatabaseOperationalError',
'DatabaseTimeoutError']
name = urllib.parse.urlparse(config.get('database', 'uri', default='')).scheme
_modname = 'trytond.backend.%s' % name
try:
_module = importlib.import_module(_modname)
except ImportError:
for ep in entry_points().select(group='trytond.backend', name=name):
try:
_module = ep.load()
break
except ImportError:
continue
else:
raise
Database = _module.Database
DatabaseIntegrityError = _module.DatabaseIntegrityError
DatabaseDataError = _module.DatabaseDataError
DatabaseOperationalError = _module.DatabaseOperationalError
DatabaseTimeoutError = _module.DatabaseTimeoutError
TableHandler = _module.TableHandler

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

183
backend/database.py Executable file
View File

@@ -0,0 +1,183 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
from collections import namedtuple
from sql import For
DatabaseIntegrityError = None
DatabaseOperationalError = None
DatabaseTimeoutError = None
SQLType = namedtuple('SQLType', 'base type')
class DatabaseInterface(object):
'''
Define generic interface for database connection
'''
flavor = None
IN_MAX = 1000
def __new__(cls, name=''):
return object.__new__(cls)
def __init__(self, name=''):
self.name = name
def connect(self):
raise NotImplementedError
def get_connection(
self, autocommit=False, readonly=False, statement_timeout=None):
raise NotImplementedError
def put_connection(self, connection, close=False):
raise NotImplementedError
def close(self):
raise NotImplementedError
@classmethod
def create(cls, connection, database_name):
raise NotImplementedError
@classmethod
def drop(cls, connection, database_name):
raise NotImplementedError
def list(self, hostname=None):
raise NotImplementedError
def init(self):
raise NotImplementedError
def test(self, hostname=None):
'''
Test if it is a Tryton database.
'''
raise NotImplementedError
def nextid(self, connection, table, count=1):
pass
def setnextid(self, connection, table, value):
pass
def currid(self, connection, table):
pass
@classmethod
def lock(cls, connection, table):
raise NotImplementedError
def lock_id(self, id, timeout=None):
raise NotImplementedError
def has_constraint(self, constraint):
raise NotImplementedError
def has_returning(self):
return False
def has_multirow_insert(self):
return False
def has_select_for(self):
return False
def get_select_for_skip_locked(self):
return For
def has_window_functions(self):
return False
def has_unaccent(self):
return False
def has_unaccent_indexable(self):
return False
def unaccent(self, value):
return value
def has_similarity(self):
return False
def similarity(self, column, value):
raise NotImplementedError
def has_search_full_text(self):
return False
def format_full_text(self, *documents, language=None):
return '\n'.join(documents)
def format_full_text_query(self, query, language=None):
raise NotImplementedError
def search_full_text(self, document, query):
raise NotImplementedError
def rank_full_text(self, document, query, normalize=None):
"Return the expression that ranks query on document"
raise NotImplementedError
@classmethod
def has_sequence(cls):
return False
def sequence_exist(self, connection, name):
if not self.has_sequence():
return
raise NotImplementedError
def sequence_create(
self, connection, name, number_increment=1, start_value=1):
if not self.has_sequence():
return
raise NotImplementedError
def sequence_update(
self, connection, name, number_increment=1, start_value=1):
if not self.has_sequence():
return
raise NotImplementedError
def sequence_rename(self, connection, old_name, new_name):
if not self.has_sequence():
return
raise NotImplementedError
def sequence_delete(self, connection, name):
if not self.has_sequence():
return
raise NotImplementedError
def sequence_next_number(self, connection, name):
if not self.has_sequence():
return
raise NotImplementedError
def has_channel(self):
return False
def sql_type(self, type_):
pass
def sql_format(self, type_, value):
pass
def json_get(self, column, key=None):
raise NotImplementedError
def json_key_exists(self, column, key):
raise NotImplementedError
def json_any_keys_exist(self, column, keys):
raise NotImplementedError
def json_all_keys_exist(self, column, keys):
raise NotImplementedError
def json_contains(self, column, json):
raise NotImplementedError

12
backend/postgresql/__init__.py Executable file
View File

@@ -0,0 +1,12 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
from .database import (
Database, DatabaseDataError, DatabaseIntegrityError,
DatabaseOperationalError, DatabaseTimeoutError)
from .table import TableHandler
__all__ = [
Database, TableHandler,
DatabaseIntegrityError, DatabaseDataError, DatabaseOperationalError,
DatabaseTimeoutError]

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

814
backend/postgresql/database.py Executable file
View File

@@ -0,0 +1,814 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import json
import logging
import os
import time
import warnings
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from itertools import chain, repeat
from threading import RLock
from psycopg2 import Binary, connect
from psycopg2.extensions import (
ISOLATION_LEVEL_REPEATABLE_READ, UNICODE, AsIs, cursor, register_adapter,
register_type)
from psycopg2.pool import PoolError, ThreadedConnectionPool
from psycopg2.sql import SQL, Identifier
try:
from psycopg2.extensions import PYDATE, PYDATETIME, PYINTERVAL, PYTIME
except ImportError:
PYDATE, PYDATETIME, PYTIME, PYINTERVAL = None, None, None, None
from psycopg2 import DataError as DatabaseDataError
from psycopg2 import IntegrityError as DatabaseIntegrityError
from psycopg2 import InterfaceError
from psycopg2 import OperationalError as DatabaseOperationalError
from psycopg2 import ProgrammingError
from psycopg2.errors import QueryCanceled as DatabaseTimeoutError
from psycopg2.extras import register_default_json, register_default_jsonb
from sql import Cast, Flavor, For, Table
from sql.conditionals import Coalesce
from sql.functions import Function
from sql.operators import BinaryOperator, Concat
from trytond.backend.database import DatabaseInterface, SQLType
from trytond.config import config, parse_uri
from trytond.tools.gevent import is_gevent_monkey_patched
__all__ = [
'Database',
'DatabaseIntegrityError', 'DatabaseDataError', 'DatabaseOperationalError',
'DatabaseTimeoutError']
logger = logging.getLogger(__name__)
os.environ['PGTZ'] = os.environ.get('TZ', '')
_timeout = config.getint('database', 'timeout')
_minconn = config.getint('database', 'minconn', default=1)
_maxconn = config.getint('database', 'maxconn', default=64)
_default_name = config.get('database', 'default_name', default='template1')
def unescape_quote(s):
if s.startswith('"') and s.endswith('"'):
return s.strip('"').replace('""', '"')
return s
def replace_special_values(s, **mapping):
for name, value in mapping.items():
s = s.replace('$' + name, value)
return s
class LoggingCursor(cursor):
def execute(self, sql, args=None):
if logger.isEnabledFor(logging.DEBUG):
logger.debug(self.mogrify(sql, args))
cursor.execute(self, sql, args)
class ForSkipLocked(For):
def __str__(self):
assert not self.nowait, "Can not use both NO WAIT and SKIP LOCKED"
return super().__str__() + (' SKIP LOCKED' if not self.nowait else '')
class Unaccent(Function):
__slots__ = ()
_function = config.get('database', 'unaccent_function', default='unaccent')
class Similarity(Function):
__slots__ = ()
_function = config.get(
'database', 'similarity_function', default='similarity')
class Match(BinaryOperator):
__slots__ = ()
_operator = '@@'
class ToTsvector(Function):
__slots__ = ()
_function = 'to_tsvector'
class Setweight(Function):
__slots__ = ()
_function = 'setweight'
class TsQuery(Function):
__slots__ = ()
class ToTsQuery(TsQuery):
__slots__ = ()
_function = 'to_tsquery'
class PlainToTsQuery(TsQuery):
__slots__ = ()
_function = 'plainto_tsquery'
class PhraseToTsQuery(TsQuery):
__slots__ = ()
_function = 'phraseto_tsquery'
class WebsearchToTsQuery(TsQuery):
__slots__ = ()
_function = 'websearch_to_tsquery'
class TsRank(Function):
__slots__ = ()
_function = 'ts_rank'
class AdvisoryLock(Function):
_function = 'pg_advisory_xact_lock'
class TryAdvisoryLock(Function):
_function = 'pg_try_advisory_xact_lock'
class JSONBExtractPath(Function):
__slots__ = ()
_function = 'jsonb_extract_path'
class JSONKeyExists(BinaryOperator):
__slots__ = ()
_operator = '?'
class _BinaryOperatorArray(BinaryOperator):
"Binary Operator that convert list into Array"
@property
def _operands(self):
if isinstance(self.right, list):
return (self.left, None)
return super()._operands
@property
def params(self):
params = super().params
if isinstance(self.right, list):
params = params[:-1] + (self.right,)
return params
class JSONAnyKeyExist(_BinaryOperatorArray):
__slots__ = ()
_operator = '?|'
class JSONAllKeyExist(_BinaryOperatorArray):
__slots__ = ()
_operator = '?&'
class JSONContains(BinaryOperator):
__slots__ = ()
_operator = '@>'
class Database(DatabaseInterface):
index_translators = []
_lock = RLock()
_databases = defaultdict(dict)
_connpool = None
_list_cache = {}
_list_cache_timestamp = {}
_search_path = None
_current_user = None
_has_returning = None
_has_select_for_skip_locked = None
_has_proc = defaultdict(lambda: defaultdict(dict))
_extensions = defaultdict(dict)
_search_full_text_languages = defaultdict(dict)
flavor = Flavor(ilike=True)
TYPES_MAPPING = {
'SMALLINT': SQLType('INT2', 'INT2'),
'BIGINT': SQLType('INT8', 'INT8'),
'BLOB': SQLType('BYTEA', 'BYTEA'),
'DATETIME': SQLType('TIMESTAMP', 'TIMESTAMP(0)'),
'REAL': SQLType('FLOAT4', 'FLOAT4'),
'FLOAT': SQLType('FLOAT8', 'FLOAT8'),
'FULLTEXT': SQLType('TSVECTOR', 'TSVECTOR'),
'INTEGER': SQLType('INT4', 'INT4'),
'JSON': SQLType('JSONB', 'JSONB'),
'TIMESTAMP': SQLType('TIMESTAMP', 'TIMESTAMP(6)'),
}
def __new__(cls, name=_default_name):
with cls._lock:
now = datetime.now()
databases = cls._databases[os.getpid()]
for database in list(databases.values()):
if ((now - database._last_use).total_seconds() > _timeout
and database.name != name
and not database._connpool._used):
database.close()
if name in databases:
inst = databases[name]
else:
inst = DatabaseInterface.__new__(cls, name=name)
try:
inst._connpool = ThreadedConnectionPool(
_minconn, _maxconn, **cls._connection_params(name),
cursor_factory=LoggingCursor)
except Exception:
logger.error(
'connection to "%s" failed', name, exc_info=True)
raise
else:
logger.info('connection to "%s" succeeded', name)
databases[name] = inst
inst._last_use = datetime.now()
return inst
def __init__(self, name=_default_name):
super(Database, self).__init__(name)
@classmethod
def _connection_params(cls, name):
uri = parse_uri(config.get('database', 'uri'))
if uri.path and uri.path != '/':
warnings.warn("The path specified in the URI will be overridden")
params = {
'dsn': uri._replace(path=name).geturl(),
'fallback_application_name': os.environ.get(
'TRYTOND_APPNAME', 'trytond'),
}
return params
def connect(self):
return self
def get_connection(
self, autocommit=False, readonly=False, statement_timeout=None):
retry = max(config.getint('database', 'retry'), _maxconn)
for count in range(retry, -1, -1):
try:
conn = self._connpool.getconn()
except (PoolError, DatabaseOperationalError):
if count and not self._connpool.closed:
logger.info('waiting a connection')
time.sleep(1)
continue
raise
except Exception:
logger.error(
'connection to "%s" failed', self.name, exc_info=True)
raise
try:
conn.set_session(
isolation_level=ISOLATION_LEVEL_REPEATABLE_READ,
readonly=readonly,
autocommit=autocommit)
with conn.cursor() as cur:
if statement_timeout:
cur.execute('SET statement_timeout=%s' %
(statement_timeout * 1000))
else:
# Detect disconnection
cur.execute('SELECT 1')
except DatabaseOperationalError:
self._connpool.putconn(conn, close=True)
continue
break
return conn
def put_connection(self, connection, close=False):
try:
connection.reset()
except InterfaceError:
pass
self._connpool.putconn(connection, close=close)
def close(self):
with self._lock:
logger.info('disconnection from "%s"', self.name)
self._connpool.closeall()
self._databases[os.getpid()].pop(self.name)
@classmethod
def create(cls, connection, database_name, template='template0'):
cursor = connection.cursor()
cursor.execute(
SQL(
"CREATE DATABASE {} TEMPLATE {} ENCODING 'unicode'")
.format(
Identifier(database_name),
Identifier(template)))
connection.commit()
cls._list_cache.clear()
@classmethod
def drop(cls, connection, database_name):
cursor = connection.cursor()
cursor.execute(SQL("DROP DATABASE {}")
.format(Identifier(database_name)))
cls._list_cache.clear()
cls._has_proc.pop(database_name, None)
cls._search_full_text_languages.pop(database_name, None)
def get_version(self, connection):
version = connection.server_version
major, rest = divmod(int(version), 10000)
minor, patch = divmod(rest, 100)
return (major, minor, patch)
def list(self, hostname=None):
now = time.time()
timeout = config.getint('session', 'timeout')
res = self.__class__._list_cache.get(hostname)
timestamp = self.__class__._list_cache_timestamp.get(hostname, now)
if res and abs(timestamp - now) < timeout:
return res
connection = self.get_connection()
try:
cursor = connection.cursor()
cursor.execute('SELECT datname FROM pg_database '
'WHERE datistemplate = false ORDER BY datname')
res = []
for db_name, in cursor:
try:
conn = connect(**self._connection_params(db_name))
try:
with conn:
if self._test(conn, hostname=hostname):
res.append(db_name)
finally:
conn.close()
except Exception:
logger.debug(
'Test failed for "%s"', db_name, exc_info=True)
continue
finally:
self.put_connection(connection, close=True)
self.__class__._list_cache[hostname] = res
self.__class__._list_cache_timestamp[hostname] = now
return res
def init(self):
from trytond.modules import get_module_info
connection = self.get_connection()
try:
cursor = connection.cursor()
sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
with open(sql_file) as fp:
for line in fp.read().split(';'):
if (len(line) > 0) and (not line.isspace()):
cursor.execute(line)
for module in ['ir', 'res']:
info = get_module_info(module)
cursor.execute('INSERT INTO ir_module '
'(create_uid, create_date, name, state) '
'VALUES (%s, now(), %s, %s) '
'RETURNING id',
(0, module, 'to activate'))
module_id = cursor.fetchone()[0]
for dependency in info.get('depends', []):
cursor.execute('INSERT INTO ir_module_dependency '
'(create_uid, create_date, module, name) '
'VALUES (%s, now(), %s, %s)',
(0, module_id, dependency))
connection.commit()
finally:
self.put_connection(connection)
def test(self, hostname=None):
try:
connection = self.get_connection()
except Exception:
logger.debug('Test failed for "%s"', self.name, exc_info=True)
return False
try:
return self._test(connection, hostname=hostname)
finally:
self.put_connection(connection, close=True)
@classmethod
def _test(cls, connection, hostname=None):
cursor = connection.cursor()
tables = ('ir_model', 'ir_model_field', 'ir_ui_view', 'ir_ui_menu',
'res_user', 'res_group', 'ir_module', 'ir_module_dependency',
'ir_translation', 'ir_lang', 'ir_configuration')
cursor.execute('SELECT table_name FROM information_schema.tables '
'WHERE table_name IN %s', (tables,))
if len(cursor.fetchall()) != len(tables):
return False
if hostname:
try:
cursor.execute(
'SELECT hostname FROM ir_configuration')
hostnames = {h for h, in cursor if h}
if hostnames and hostname not in hostnames:
return False
except ProgrammingError:
pass
return True
def nextid(self, connection, table, count=1):
cursor = connection.cursor()
cursor.execute(
"SELECT nextval(pg_get_serial_sequence(format(%s, %s), %s)) "
"FROM generate_series(1, %s)",
('%I', table, 'id', count))
if count == 1:
return cursor.fetchone()[0]
else:
return [id for id, in cursor]
def setnextid(self, connection, table, value):
cursor = connection.cursor()
cursor.execute(
"SELECT setval(pg_get_serial_sequence(format(%s, %s), %s), %s)",
('%I', table, 'id', value))
def currid(self, connection, table):
cursor = connection.cursor()
cursor.execute(
"SELECT pg_get_serial_sequence(format(%s, %s), %s)",
('%I', table, 'id'))
sequence_name, = cursor.fetchone()
cursor.execute(f"SELECT last_value FROM {sequence_name}")
return cursor.fetchone()[0]
def lock(self, connection, table):
cursor = connection.cursor()
cursor.execute(SQL('LOCK {} IN EXCLUSIVE MODE NOWAIT').format(
Identifier(table)))
def lock_id(self, id, timeout=None):
if not timeout:
return TryAdvisoryLock(id)
else:
return AdvisoryLock(id)
def has_constraint(self, constraint):
return True
def has_multirow_insert(self):
return True
def get_table_schema(self, connection, table_name):
cursor = connection.cursor()
for schema in self.search_path:
cursor.execute('SELECT 1 '
'FROM information_schema.tables '
'WHERE table_name = %s AND table_schema = %s',
(table_name, schema))
if cursor.rowcount:
return schema
@property
def current_user(self):
if self._current_user is None:
connection = self.get_connection()
try:
cursor = connection.cursor()
cursor.execute('SELECT current_user')
self._current_user = cursor.fetchone()[0]
finally:
self.put_connection(connection)
return self._current_user
@property
def search_path(self):
if self._search_path is None:
connection = self.get_connection()
try:
cursor = connection.cursor()
cursor.execute('SHOW search_path')
path, = cursor.fetchone()
special_values = {
'user': self.current_user,
}
self._search_path = [
unescape_quote(replace_special_values(
p.strip(), **special_values))
for p in path.split(',')]
finally:
self.put_connection(connection)
return self._search_path
def has_returning(self):
if self._has_returning is None:
connection = self.get_connection()
try:
# RETURNING clause is available since PostgreSQL 8.2
self._has_returning = self.get_version(connection) >= (8, 2)
finally:
self.put_connection(connection)
return self._has_returning
def has_select_for(self):
return True
def get_select_for_skip_locked(self):
if self._has_select_for_skip_locked is None:
connection = self.get_connection()
try:
# SKIP LOCKED clause is available since PostgreSQL 9.5
self._has_select_for_skip_locked = (
self.get_version(connection) >= (9, 5))
finally:
self.put_connection(connection)
if self._has_select_for_skip_locked:
return ForSkipLocked
else:
return For
def has_window_functions(self):
return True
@classmethod
def has_sequence(cls):
return True
def has_proc(self, name, property='oid'):
if (name in self._has_proc[self.name]
and property in self._has_proc[self.name][name]):
return self._has_proc[self.name][name][property]
connection = self.get_connection()
result = False
try:
cursor = connection.cursor()
cursor.execute(
SQL('SELECT {} FROM pg_proc WHERE proname=%s').format(
Identifier(property)), (name,))
result = cursor.fetchone()
if result:
result, = result
finally:
self.put_connection(connection)
self._has_proc[self.name][name][property] = result
return result
def has_unaccent(self):
return self.has_proc(Unaccent._function)
def has_unaccent_indexable(self):
return self.has_proc(Unaccent._function, 'provolatile') == 'i'
def has_similarity(self):
return self.has_proc(Similarity._function)
def similarity(self, column, value):
return Similarity(column, value)
def has_search_full_text(self):
return True
def _search_full_text_language(self, language):
languages = self._search_full_text_languages[self.name]
if language not in languages:
lang = Table('ir_lang')
connection = self.get_connection()
try:
cursor = connection.cursor()
cursor.execute(*lang.select(
Coalesce(lang.pg_text_search, 'simple'),
where=lang.code == language,
limit=1))
config_name, = cursor.fetchone()
finally:
self.put_connection(connection)
languages[language] = config_name
else:
config_name = languages[language]
return config_name
def format_full_text(self, *documents, language=None):
size = max(len(documents) // 4, 1)
if len(documents) > 1:
weights = chain(
['A'] * size, ['B'] * size, ['C'] * size, repeat('D'))
else:
weights = [None]
expression = None
if language:
config_name = self._search_full_text_language(language)
else:
config_name = None
for document, weight in zip(documents, weights):
if not document:
continue
if config_name:
ts_vector = ToTsvector(config_name, document)
else:
ts_vector = ToTsvector('simple', document)
if weight:
ts_vector = Setweight(ts_vector, weight)
if expression is None:
expression = ts_vector
else:
expression = Concat(expression, ts_vector)
return expression
def format_full_text_query(self, query, language=None):
connection = self.get_connection()
try:
version = self.get_version(connection)
finally:
self.put_connection(connection)
if not isinstance(query, TsQuery):
if version >= (11, 0):
ToTsQuery = WebsearchToTsQuery
else:
ToTsQuery = PlainToTsQuery
if language:
config_name = self._search_full_text_language(language)
else:
config_name = 'simple'
query = ToTsQuery(config_name, query)
return query
def search_full_text(self, document, query):
return Match(document, query)
def rank_full_text(self, document, query, normalize=None):
# TODO: weights and cover density
norm_int = 0
if normalize:
values = {
'document log': 1,
'document': 2,
'mean': 4,
'word': 8,
'word log': 16,
'rank': 32,
}
for norm in normalize:
norm_int |= values.get(norm, 0)
return TsRank(document, query, norm_int)
def sql_type(self, type_):
if type_ in self.TYPES_MAPPING:
return self.TYPES_MAPPING[type_]
if type_.startswith('VARCHAR'):
return SQLType('VARCHAR', type_)
return SQLType(type_, type_)
def sql_format(self, type_, value):
if type_ == 'BLOB':
if value is not None:
return Binary(value)
return value
def unaccent(self, value):
if self.has_unaccent():
return Unaccent(value)
return value
def sequence_exist(self, connection, name):
cursor = connection.cursor()
for schema in self.search_path:
cursor.execute('SELECT 1 '
'FROM information_schema.sequences '
'WHERE sequence_name = %s AND sequence_schema = %s',
(name, schema))
if cursor.rowcount:
return True
return False
def sequence_create(
self, connection, name, number_increment=1, start_value=1):
cursor = connection.cursor()
cursor.execute(
SQL("CREATE SEQUENCE {} INCREMENT BY %s START WITH %s").format(
Identifier(name)),
(number_increment, start_value))
def sequence_update(
self, connection, name, number_increment=1, start_value=1):
cursor = connection.cursor()
cursor.execute(
SQL("ALTER SEQUENCE {} INCREMENT BY %s RESTART WITH %s").format(
Identifier(name)),
(number_increment, start_value))
def sequence_rename(self, connection, old_name, new_name):
cursor = connection.cursor()
if (self.sequence_exist(connection, old_name)
and not self.sequence_exist(connection, new_name)):
cursor.execute(
SQL("ALTER TABLE {} RENAME TO {}").format(
Identifier(old_name),
Identifier(new_name)))
def sequence_delete(self, connection, name):
cursor = connection.cursor()
cursor.execute(SQL("DROP SEQUENCE {}").format(
Identifier(name)))
def sequence_next_number(self, connection, name):
cursor = connection.cursor()
version = self.get_version(connection)
if version >= (10, 0):
cursor.execute(
'SELECT increment_by '
'FROM pg_sequences '
'WHERE sequencename=%s',
(name,))
increment, = cursor.fetchone()
cursor.execute(
SQL(
'SELECT CASE WHEN NOT is_called THEN last_value '
'ELSE last_value + %s '
'END '
'FROM {}').format(Identifier(name)),
(increment,))
else:
cursor.execute(
SQL(
'SELECT CASE WHEN NOT is_called THEN last_value '
'ELSE last_value + increment_by '
'END '
'FROM {}').format(Identifier(name)))
return cursor.fetchone()[0]
def has_channel(self):
return True
def has_extension(self, extension_name):
if extension_name in self._extensions[self.name]:
return self._extensions[self.name][extension_name]
connection = self.get_connection()
result = False
try:
cursor = connection.cursor()
cursor.execute(
"SELECT 1 FROM pg_extension WHERE extname=%s",
(extension_name,))
result = bool(cursor.rowcount)
finally:
self.put_connection(connection)
self._extensions[self.name][extension_name] = result
return result
def json_get(self, column, key=None):
column = Cast(column, 'jsonb')
if key:
column = JSONBExtractPath(column, key)
return column
def json_key_exists(self, column, key):
return JSONKeyExists(Cast(column, 'jsonb'), key)
def json_any_keys_exist(self, column, keys):
return JSONAnyKeyExist(Cast(column, 'jsonb'), keys)
def json_all_keys_exist(self, column, keys):
return JSONAllKeyExist(Cast(column, 'jsonb'), keys)
def json_contains(self, column, json):
return JSONContains(Cast(column, 'jsonb'), Cast(json, 'jsonb'))
register_type(UNICODE)
if PYDATE:
register_type(PYDATE)
if PYDATETIME:
register_type(PYDATETIME)
if PYTIME:
register_type(PYTIME)
if PYINTERVAL:
register_type(PYINTERVAL)
register_adapter(float, lambda value: AsIs(repr(value)))
register_adapter(Decimal, lambda value: AsIs(str(value)))
def convert_json(value):
from trytond.protocols.jsonrpc import JSONDecoder
return json.loads(value, object_hook=JSONDecoder())
register_default_json(loads=convert_json)
register_default_jsonb(loads=convert_json)
if is_gevent_monkey_patched():
from psycopg2.extensions import set_wait_callback
from psycopg2.extras import wait_select
set_wait_callback(wait_select)

156
backend/postgresql/init.sql Executable file
View File

@@ -0,0 +1,156 @@
CREATE TABLE ir_configuration (
id INTEGER GENERATED BY DEFAULT AS IDENTITY NOT NULL
CONSTRAINT ir_configuration_id_positive CHECK(id >= 0),
language VARCHAR,
hostname VARCHAR,
PRIMARY KEY(id)
);
CREATE TABLE ir_model (
id INTEGER GENERATED BY DEFAULT AS IDENTITY NOT NULL
CONSTRAINT ir_model_id_positive CHECK(id >= 0),
model VARCHAR NOT NULL,
name VARCHAR,
info TEXT,
module VARCHAR,
PRIMARY KEY(id)
);
ALTER TABLE ir_model ADD CONSTRAINT ir_model_model_uniq UNIQUE (model);
CREATE TABLE ir_model_field (
id INTEGER GENERATED BY DEFAULT AS IDENTITY NOT NULL
CONSTRAINT ir_model_field_id_positive CHECK(id >= 0),
model VARCHAR NOT NULL,
name VARCHAR NOT NULL,
relation VARCHAR,
field_description VARCHAR,
ttype VARCHAR,
help TEXT,
module VARCHAR,
"access" BOOL,
PRIMARY KEY(id),
FOREIGN KEY (model) REFERENCES ir_model(model) ON DELETE CASCADE
);
ALTER TABLE ir_model_field ADD CONSTRAINT ir_model_field_name_model_uniq UNIQUE (name, model);
CREATE TABLE ir_ui_view (
id INTEGER GENERATED BY DEFAULT AS IDENTITY NOT NULL
CONSTRAINT ir_ui_view_id_positive CHECK(id >= 0),
model VARCHAR NOT NULL,
"type" VARCHAR,
data TEXT NOT NULL,
field_childs VARCHAR,
priority INTEGER NOT NULL,
PRIMARY KEY(id)
);
CREATE TABLE ir_ui_menu (
id INTEGER GENERATED BY DEFAULT AS IDENTITY NOT NULL
CONSTRAINT ir_ui_menu_id_positive CHECK(id >= 0),
parent INTEGER,
name VARCHAR NOT NULL,
icon VARCHAR,
PRIMARY KEY (id),
FOREIGN KEY (parent) REFERENCES ir_ui_menu (id) ON DELETE SET NULL
);
CREATE TABLE ir_translation (
id INTEGER GENERATED BY DEFAULT AS IDENTITY NOT NULL
CONSTRAINT ir_translation_id_positive CHECK(id >= 0),
lang VARCHAR,
src TEXT,
name VARCHAR NOT NULL,
res_id INTEGER,
value TEXT,
"type" VARCHAR,
module VARCHAR,
fuzzy BOOLEAN NOT NULL,
PRIMARY KEY(id)
);
CREATE TABLE ir_lang (
id INTEGER GENERATED BY DEFAULT AS IDENTITY NOT NULL
CONSTRAINT ir_lang_id_positive CHECK(id >= 0),
name VARCHAR NOT NULL,
code VARCHAR NOT NULL,
translatable BOOLEAN NOT NULL,
parent VARCHAR,
active BOOLEAN NOT NULL,
direction VARCHAR NOT NULL,
PRIMARY KEY(id)
);
CREATE TABLE res_user (
id INTEGER GENERATED BY DEFAULT AS IDENTITY NOT NULL
CONSTRAINT res_user_id_positive CHECK(id >= 0),
name VARCHAR NOT NULL,
active BOOLEAN NOT NULL,
login VARCHAR NOT NULL,
password VARCHAR,
PRIMARY KEY(id)
);
ALTER TABLE res_user ADD CONSTRAINT res_user_login_key UNIQUE (login);
INSERT INTO res_user (id, login, password, name, active) VALUES (0, 'root', NULL, 'Root', False);
CREATE TABLE res_group (
id INTEGER GENERATED BY DEFAULT AS IDENTITY NOT NULL
CONSTRAINT res_group_id_positive CHECK(id >= 0),
name VARCHAR NOT NULL,
PRIMARY KEY(id)
);
CREATE TABLE "res_user-res_group" (
id INTEGER GENERATED BY DEFAULT AS IDENTITY NOT NULL
CONSTRAINT "res_user-res_group_id_positive" CHECK(id >= 0),
"user" INTEGER NOT NULL,
"group" INTEGER NOT NULL,
FOREIGN KEY ("user") REFERENCES res_user (id) ON DELETE CASCADE,
FOREIGN KEY ("group") REFERENCES res_group (id) ON DELETE CASCADE,
PRIMARY KEY(id)
);
CREATE TABLE ir_module (
id INTEGER GENERATED BY DEFAULT AS IDENTITY NOT NULL
CONSTRAINT ir_module_id_positive CHECK(id >= 0),
create_uid INTEGER NOT NULL,
create_date TIMESTAMP WITHOUT TIME ZONE NOT NULL,
write_date TIMESTAMP WITHOUT TIME ZONE,
write_uid INTEGER,
name VARCHAR NOT NULL,
state VARCHAR,
PRIMARY KEY(id),
FOREIGN KEY (create_uid) REFERENCES res_user ON DELETE SET NULL,
FOREIGN KEY (write_uid) REFERENCES res_user ON DELETE SET NULL
);
ALTER TABLE ir_module ADD CONSTRAINT ir_module_name_uniq UNIQUE (name);
CREATE TABLE ir_module_dependency (
id INTEGER GENERATED BY DEFAULT AS IDENTITY NOT NULL
CONSTRAINT ir_module_dependency_id_positive CHECK(id >= 0),
create_uid INTEGER NOT NULL,
create_date TIMESTAMP WITHOUT TIME ZONE NOT NULL,
write_date TIMESTAMP WITHOUT TIME ZONE,
write_uid INTEGER,
name VARCHAR,
module INTEGER,
PRIMARY KEY(id),
FOREIGN KEY (create_uid) REFERENCES res_user ON DELETE SET NULL,
FOREIGN KEY (write_uid) REFERENCES res_user ON DELETE SET NULL,
FOREIGN KEY (module) REFERENCES ir_module ON DELETE CASCADE
);
CREATE TABLE ir_cache (
id INTEGER GENERATED BY DEFAULT AS IDENTITY NOT NULL
CONSTRAINT ir_cache_id_positive CHECK(id >= 0),
name VARCHAR NOT NULL,
"timestamp" TIMESTAMP WITHOUT TIME ZONE,
create_date TIMESTAMP WITHOUT TIME ZONE,
create_uid INTEGER,
write_date TIMESTAMP WITHOUT TIME ZONE,
write_uid INTEGER
);

731
backend/postgresql/table.py Executable file
View File

@@ -0,0 +1,731 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import logging
import re
from psycopg2.sql import SQL, Identifier
from trytond.backend.table import (
IndexTranslatorInterface, TableHandlerInterface)
from trytond.transaction import Transaction
__all__ = ['TableHandler']
logger = logging.getLogger(__name__)
VARCHAR_SIZE_RE = re.compile(r'VARCHAR\(([0-9]+)\)')
class TableHandler(TableHandlerInterface):
namedatalen = 64
index_translators = []
def _init(self, model, history=False):
super()._init(model, history=history)
self.__columns = None
self.__constraints = None
self.__fk_deltypes = None
self.__indexes = None
transaction = Transaction()
cursor = transaction.connection.cursor()
# Create new table if necessary
if not self.table_exist(self.table_name):
cursor.execute(SQL('CREATE TABLE {} ()').format(
Identifier(self.table_name)))
self.table_schema = transaction.database.get_table_schema(
transaction.connection, self.table_name)
cursor.execute('SELECT tableowner = current_user FROM pg_tables '
'WHERE tablename = %s AND schemaname = %s',
(self.table_name, self.table_schema))
self.is_owner, = cursor.fetchone()
if model.__doc__ and self.is_owner:
cursor.execute(SQL('COMMENT ON TABLE {} IS %s').format(
Identifier(self.table_name)),
(model.__doc__,))
def migrate_to_identity(table, column):
previous_seq_name = f"{table}_{column}_seq"
cursor.execute(
"SELECT nextval(format(%s, %s))", ('%I', previous_seq_name,))
next_val, = cursor.fetchone()
cursor.execute(
"SELECT seqincrement, seqmax, seqmin, seqcache "
"FROM pg_sequence WHERE seqrelid = %s::regclass",
(previous_seq_name,))
increment, s_max, s_min, cache = cursor.fetchone()
# Previously created sequences were setting bigint values for those
# identity column mimic the type of the underlying column
if (s_max > 2 ** 31 - 1
and self._columns[column]['typname'] != 'int8'):
s_max = 2 ** 31 - 1
if (s_min < -(2 ** 31)
and self._columns[column]['typname'] != 'int8'):
s_min = -(2 ** 31)
cursor.execute(
SQL("ALTER TABLE {} ALTER COLUMN {} DROP DEFAULT").format(
Identifier(table), Identifier(column)))
cursor.execute(
SQL("DROP SEQUENCE {}").format(
Identifier(previous_seq_name)))
cursor.execute(
SQL("ALTER TABLE {} ALTER COLUMN {} "
"ADD GENERATED BY DEFAULT AS IDENTITY").format(
Identifier(table), Identifier(column)))
cursor.execute(
"SELECT pg_get_serial_sequence(format(%s, %s), %s)",
('%I', table, column))
serial_seq_name, = cursor.fetchone()
cursor.execute(
(f"ALTER SEQUENCE {serial_seq_name} INCREMENT BY %s "
"MINVALUE %s MAXVALUE %s RESTART WITH %s CACHE %s"),
(increment, s_min, s_max, next_val, cache))
update_definitions = False
if 'id' not in self._columns:
update_definitions = True
if not self.history:
cursor.execute(
SQL(
"ALTER TABLE {} ADD COLUMN id INTEGER "
"GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY").format(
Identifier(self.table_name)))
else:
cursor.execute(
SQL('ALTER TABLE {} ADD COLUMN id INTEGER')
.format(Identifier(self.table_name)))
else:
if not self.history and not self.__columns['id']['identity']:
update_definitions = True
migrate_to_identity(self.table_name, 'id')
if self.history and '__id' not in self._columns:
update_definitions = True
cursor.execute(
SQL(
"ALTER TABLE {} ADD COLUMN __id INTEGER "
"GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY").format(
Identifier(self.table_name)))
elif self.history:
if not self.__columns['__id']['identity']:
update_definitions = True
cursor.execute(
SQL("ALTER TABLE {} ALTER COLUMN id DROP DEFAULT").format(
Identifier(self.table_name)))
migrate_to_identity(self.table_name, '__id')
if update_definitions:
self._update_definitions(columns=True)
@classmethod
def table_exist(cls, table_name):
transaction = Transaction()
return bool(transaction.database.get_table_schema(
transaction.connection, table_name))
@classmethod
def table_rename(cls, old_name, new_name):
transaction = Transaction()
cursor = transaction.connection.cursor()
# Rename table
if (cls.table_exist(old_name)
and not cls.table_exist(new_name)):
cursor.execute(SQL('ALTER TABLE {} RENAME TO {}').format(
Identifier(old_name), Identifier(new_name)))
# Migrate from 6.6: rename old sequence
old_sequence = old_name + '_id_seq'
new_sequence = new_name + '_id_seq'
transaction.database.sequence_rename(
transaction.connection, old_sequence, new_sequence)
# Rename history table
old_history = old_name + "__history"
new_history = new_name + "__history"
if (cls.table_exist(old_history)
and not cls.table_exist(new_history)):
cursor.execute('ALTER TABLE "%s" RENAME TO "%s"'
% (old_history, new_history))
def column_exist(self, column_name):
return column_name in self._columns
def column_rename(self, old_name, new_name):
cursor = Transaction().connection.cursor()
if self.column_exist(old_name):
if not self.column_exist(new_name):
cursor.execute(SQL(
'ALTER TABLE {} RENAME COLUMN {} TO {}').format(
Identifier(self.table_name),
Identifier(old_name),
Identifier(new_name)))
self._update_definitions(columns=True)
else:
logger.warning(
'Unable to rename column %s on table %s to %s.',
old_name, self.table_name, new_name)
@property
def _columns(self):
if self.__columns is None:
cursor = Transaction().connection.cursor()
self.__columns = {}
# Fetch columns definitions from the table
cursor.execute('SELECT '
'column_name, udt_name, is_nullable, '
'character_maximum_length, '
'column_default, is_identity '
'FROM information_schema.columns '
'WHERE table_name = %s AND table_schema = %s',
(self.table_name, self.table_schema))
for column, typname, nullable, size, default, identity in cursor:
self.__columns[column] = {
'typname': typname,
'notnull': True if nullable == 'NO' else False,
'size': size,
'default': default,
'identity': False if identity == 'NO' else True,
}
return self.__columns
@property
def _constraints(self):
if self.__constraints is None:
cursor = Transaction().connection.cursor()
# fetch constraints for the table
cursor.execute('SELECT constraint_name '
'FROM information_schema.table_constraints '
'WHERE table_name = %s AND table_schema = %s',
(self.table_name, self.table_schema))
self.__constraints = [c for c, in cursor]
# add nonstandard exclude constraint
cursor.execute('SELECT c.conname '
'FROM pg_namespace nc, '
'pg_namespace nr, '
'pg_constraint c, '
'pg_class r '
'WHERE nc.oid = c.connamespace AND nr.oid = r.relnamespace '
'AND c.conrelid = r.oid '
"AND c.contype = 'x' " # exclude type
"AND r.relkind IN ('r', 'p') "
'AND r.relname = %s AND nr.nspname = %s',
(self.table_name, self.table_schema))
self.__constraints.extend((c for c, in cursor))
return self.__constraints
@property
def _fk_deltypes(self):
if self.__fk_deltypes is None:
cursor = Transaction().connection.cursor()
cursor.execute('SELECT k.column_name, r.delete_rule '
'FROM information_schema.key_column_usage AS k '
'JOIN information_schema.referential_constraints AS r '
'ON r.constraint_schema = k.constraint_schema '
'AND r.constraint_name = k.constraint_name '
'WHERE k.table_name = %s AND k.table_schema = %s',
(self.table_name, self.table_schema))
self.__fk_deltypes = dict(cursor)
return self.__fk_deltypes
@property
def _indexes(self):
if self.__indexes is None:
cursor = Transaction().connection.cursor()
# Fetch indexes defined for the table
cursor.execute("SELECT cl2.relname "
"FROM pg_index ind "
"JOIN pg_class cl on (cl.oid = ind.indrelid) "
"JOIN pg_namespace n ON (cl.relnamespace = n.oid) "
"JOIN pg_class cl2 on (cl2.oid = ind.indexrelid) "
"WHERE cl.relname = %s AND n.nspname = %s "
"AND NOT ind.indisprimary AND NOT ind.indisunique",
(self.table_name, self.table_schema))
self.__indexes = [l[0] for l in cursor]
return self.__indexes
def _update_definitions(self, columns=None, constraints=None):
if columns is None and constraints is None:
columns = constraints = True
if columns:
self.__columns = None
if constraints:
self.__constraints = None
self.__fk_deltypes = None
def alter_size(self, column_name, column_type):
cursor = Transaction().connection.cursor()
cursor.execute(
SQL("ALTER TABLE {} ALTER COLUMN {} TYPE {}").format(
Identifier(self.table_name),
Identifier(column_name),
SQL(column_type)))
self._update_definitions(columns=True)
def alter_type(self, column_name, column_type):
cursor = Transaction().connection.cursor()
cursor.execute(SQL('ALTER TABLE {} ALTER {} TYPE {}').format(
Identifier(self.table_name),
Identifier(column_name),
SQL(column_type)))
self._update_definitions(columns=True)
def column_is_type(self, column_name, type_, *, size=-1):
db_type = self._columns[column_name]['typname'].upper()
database = Transaction().database
base_type = database.sql_type(type_).base.upper()
if base_type == 'VARCHAR' and (size is None or size >= 0):
same_size = self._columns[column_name]['size'] == size
else:
same_size = True
return base_type == db_type and same_size
def db_default(self, column_name, value):
if value in [True, False]:
test = str(value).lower()
else:
test = value
if self._columns[column_name]['default'] != test:
cursor = Transaction().connection.cursor()
cursor.execute(
SQL(
'ALTER TABLE {} ALTER COLUMN {} SET DEFAULT %s').format(
Identifier(self.table_name),
Identifier(column_name)),
(value,))
def add_column(self, column_name, sql_type, default=None, comment=''):
cursor = Transaction().connection.cursor()
database = Transaction().database
column_type = database.sql_type(sql_type)
match = VARCHAR_SIZE_RE.match(sql_type)
field_size = int(match.group(1)) if match else None
def add_comment():
if comment and self.is_owner:
cursor.execute(
SQL('COMMENT ON COLUMN {}.{} IS %s').format(
Identifier(self.table_name),
Identifier(column_name)),
(comment,))
if self.column_exist(column_name):
if (column_name in ('create_date', 'write_date')
and column_type[1].lower() != 'timestamp(6)'):
# Migrate dates from timestamp(0) to timestamp
cursor.execute(
SQL(
'ALTER TABLE {} ALTER COLUMN {} TYPE timestamp')
.format(
Identifier(self.table_name),
Identifier(column_name)))
add_comment()
base_type = column_type[0].lower()
typname = self._columns[column_name]['typname']
if base_type != typname:
if (typname, base_type) in [
('varchar', 'text'),
('text', 'varchar'),
('date', 'timestamp'),
('int2', 'int4'),
('int2', 'float4'),
('int2', 'int8'),
('int2', 'float8'),
('int2', 'numeric'),
('int4', 'int8'),
('int4', 'float8'),
('int4', 'numeric'),
('int8', 'float8'),
('int8', 'numeric'),
('float4', 'numeric'),
('float4', 'float8'),
('float8', 'numeric'),
]:
self.alter_type(column_name, base_type)
elif (typname, base_type) in [
('int8', 'int4'),
('int8', 'int2'),
('int4', 'int2'),
('float8', 'float4'),
]:
pass
else:
logger.warning(
'Unable to migrate column %s on table %s '
'from %s to %s.',
column_name, self.table_name, typname, base_type)
if base_type == typname == 'varchar':
# Migrate size
from_size = self._columns[column_name]['size']
if field_size is None:
if from_size:
self.alter_size(column_name, base_type)
elif from_size == field_size:
pass
elif from_size and from_size < field_size:
self.alter_size(column_name, column_type[1])
else:
logger.warning(
'Unable to migrate column %s on table %s '
'from varchar(%s) to varchar(%s).',
column_name, self.table_name,
from_size if from_size and from_size > 0 else "",
field_size)
return
column_type = column_type[1]
cursor.execute(
SQL('ALTER TABLE {} ADD COLUMN {} {}').format(
Identifier(self.table_name),
Identifier(column_name),
SQL(column_type)))
add_comment()
if default:
# check if table is non-empty:
cursor.execute('SELECT 1 FROM "%s" limit 1' % self.table_name)
if cursor.rowcount:
# Populate column with default values:
cursor.execute(
SQL('UPDATE {} SET {} = %s').format(
Identifier(self.table_name),
Identifier(column_name)),
(default(),))
self._update_definitions(columns=True)
def add_fk(self, columns, reference, ref_columns=None, on_delete=None):
if on_delete is not None:
on_delete = on_delete.upper()
else:
on_delete = 'SET NULL'
if isinstance(columns, str):
columns = [columns]
cursor = Transaction().connection.cursor()
if ref_columns:
ref_columns_name = '_' + '_'.join(ref_columns)
else:
ref_columns_name = ''
name = self.convert_name(
self.table_name + '_' + '_'.join(columns)
+ ref_columns_name + '_fkey')
if name in self._constraints:
for column_name in columns:
if self._fk_deltypes.get(column_name) != on_delete:
self.drop_fk(columns, ref_columns)
add = True
break
else:
add = False
else:
add = True
if add:
columns = SQL(', ').join(map(Identifier, columns))
if not ref_columns:
ref_columns = ['id']
ref_columns = SQL(', ').join(map(Identifier, ref_columns))
cursor.execute(
SQL(
"ALTER TABLE {table} "
"ADD CONSTRAINT {constraint} "
"FOREIGN KEY ({columns}) "
"REFERENCES {reference} ({ref_columns}) "
"ON DELETE {action}"
)
.format(
table=Identifier(self.table_name),
constraint=Identifier(name),
columns=columns,
reference=Identifier(reference),
ref_columns=ref_columns,
action=SQL(on_delete)))
self._update_definitions(constraints=True)
def drop_fk(self, columns, ref_columns=None, table=None):
if isinstance(columns, str):
columns = [columns]
if ref_columns:
ref_columns_name = '_' + '_'.join(ref_columns)
else:
ref_columns_name = ''
self.drop_constraint(
'_'.join(columns) + ref_columns_name + '_fkey', table=table)
def not_null_action(self, column_name, action='add'):
if not self.column_exist(column_name):
return
with Transaction().connection.cursor() as cursor:
if action == 'add':
if self._columns[column_name]['notnull']:
return
cursor.execute(SQL(
'SELECT id FROM {} WHERE {} IS NULL LIMIT 1').format(
Identifier(self.table_name),
Identifier(column_name)))
if not cursor.rowcount:
cursor.execute(
SQL(
'ALTER TABLE {} ALTER COLUMN {} SET NOT NULL')
.format(
Identifier(self.table_name),
Identifier(column_name)))
self._update_definitions(columns=True)
else:
logger.warning(
"Unable to set not null on column %s of table %s.\n"
"Try restarting one more time.\n"
"If that doesn't work update the records and restart "
"again.",
column_name, self.table_name)
elif action == 'remove':
if not self._columns[column_name]['notnull']:
return
cursor.execute(
SQL('ALTER TABLE {} ALTER COLUMN {} DROP NOT NULL')
.format(
Identifier(self.table_name),
Identifier(column_name)))
self._update_definitions(columns=True)
else:
raise Exception('Not null action not supported!')
def add_constraint(self, ident, constraint):
ident = self.convert_name(self.table_name + "_" + ident)
if ident in self._constraints:
# This constrain already exist
return
cursor = Transaction().connection.cursor()
cursor.execute(
SQL('ALTER TABLE {} ADD CONSTRAINT {} {}').format(
Identifier(self.table_name),
Identifier(ident),
SQL(str(constraint))),
constraint.params)
self._update_definitions(constraints=True)
def drop_constraint(self, ident, table=None):
ident = self.convert_name((table or self.table_name) + "_" + ident)
if ident not in self._constraints:
return
cursor = Transaction().connection.cursor()
cursor.execute(
SQL('ALTER TABLE {} DROP CONSTRAINT {}').format(
Identifier(self.table_name), Identifier(ident)))
self._update_definitions(constraints=True)
def set_indexes(self, indexes, concurrently=False):
cursor = Transaction().connection.cursor()
old = set(self._indexes)
for index in indexes:
translator = self.index_translator_for(index)
if translator:
name, query, params = translator.definition(index)
name = '_'.join([self.table_name, name])
name = 'idx_' + self.convert_name(name, reserved=len('idx_'))
cursor.execute(
'SELECT idx.indisvalid '
'FROM pg_index idx '
'JOIN pg_class cls ON cls.oid = idx.indexrelid '
'WHERE cls.relname = %s',
(name,))
if (idx_valid := cursor.fetchone()) and not idx_valid[0]:
cursor.execute(
SQL("DROP INDEX {}").format(Identifier(name)))
cursor.execute(
SQL('CREATE INDEX {} IF NOT EXISTS {} ON {} USING {}')
.format(
SQL('CONCURRENTLY' if concurrently else ''),
Identifier(name),
Identifier(self.table_name),
query),
params)
old.discard(name)
for name in old:
if name.startswith('idx_') or name.endswith('_index'):
cursor.execute(SQL('DROP INDEX {}').format(Identifier(name)))
self.__indexes = None
def drop_column(self, column_name):
if not self.column_exist(column_name):
return
cursor = Transaction().connection.cursor()
cursor.execute(SQL('ALTER TABLE {} DROP COLUMN {}').format(
Identifier(self.table_name),
Identifier(column_name)))
self._update_definitions(columns=True)
@classmethod
def drop_table(cls, model, table, cascade=False):
cursor = Transaction().connection.cursor()
cursor.execute('DELETE FROM ir_model_data WHERE model = %s', (model,))
query = 'DROP TABLE {}'
if cascade:
query = query + ' CASCADE'
cursor.execute(SQL(query).format(Identifier(table)))
class IndexMixin:
_type = None
def __init_subclass__(cls):
TableHandler.index_translators.append(cls)
@classmethod
def definition(cls, index):
expr_template = SQL('{expression} {collate} {opclass} {order}')
indexed_expressions = cls._get_indexed_expressions(index)
expressions = []
params = []
for expression, usage in indexed_expressions:
expressions.append(expr_template.format(
**cls._get_expression_variables(expression, usage)))
params.extend(expression.params)
include = SQL('')
if index.options.get('include'):
include = SQL('INCLUDE ({columns})').format(
columns=SQL(',').join(map(
lambda c: SQL(str(c)),
index.options.get('include'))))
where = SQL('')
if index.options.get('where'):
where = SQL('WHERE {where}').format(
where=SQL(str(index.options['where'])))
params.extend(index.options['where'].params)
query = SQL('{type} ({expressions}) {include} {where}').format(
type=SQL(cls._type),
expressions=SQL(',').join(expressions),
include=include,
where=where)
name = cls._get_name(query, params)
return name, query, params
@classmethod
def _get_indexed_expressions(cls, index):
return index.expressions
@classmethod
def _get_expression_variables(cls, expression, usage):
variables = {
'expression': SQL(str(expression)),
'collate': SQL(''),
'opclass': SQL(''),
'order': SQL(''),
}
if usage.options.get('collation'):
variables['collate'] = SQL('COLLATE {}').format(
usage.options['collation'])
if usage.options.get('order'):
order = usage.options['order'].upper()
variables['order'] = SQL(order)
return variables
class HashTranslator(IndexMixin, IndexTranslatorInterface):
_type = 'HASH'
@classmethod
def score(cls, index):
if (len(index.expressions) > 1
or index.expressions[0][1].__class__.__name__ != 'Equality'):
return 0
if index.options.get('include'):
return 0
return 100
@classmethod
def _get_indexed_expressions(cls, index):
return [
(e, u) for e, u in index.expressions
if u.__class__.__name__ == 'Equality'][:1]
class BTreeTranslator(IndexMixin, IndexTranslatorInterface):
_type = 'BTREE'
@classmethod
def score(cls, index):
score = 0
for _, usage in index.expressions:
if usage.__class__.__name__ == 'Range':
score += 100
elif usage.__class__.__name__ == 'Equality':
score += 50
elif usage.__class__.__name__ == 'Similarity':
score += 20
if usage.options.get('begin'):
score += 100
return score
@classmethod
def _get_indexed_expressions(cls, index):
return [
(e, u) for e, u in index.expressions
if u.__class__.__name__ in {'Equality', 'Range', 'Similarity'}]
@classmethod
def _get_expression_variables(cls, expression, usage):
params = super()._get_expression_variables(expression, usage)
if (usage.__class__.__name__ == 'Similarity'
and not usage.options.get('collation')):
# text_pattern_ops and varchar_pattern_ops are the same
params['opclass'] = SQL('varchar_pattern_ops')
return params
class TrigramTranslator(IndexMixin, IndexTranslatorInterface):
_type = 'GIN'
@classmethod
def score(cls, index):
database = Transaction().database
has_btree_gin = database.has_extension('btree_gin')
has_trigram = database.has_extension('pg_trgm')
if not has_btree_gin and not has_trigram:
return 0
score = 0
for _, usage in index.expressions:
if usage.__class__.__name__ == 'Similarity':
if has_trigram:
score += 100
else:
score += 50
elif has_btree_gin:
if usage.__class__.__name__ == 'Range':
score += 90
elif usage.__class__.__name__ == 'Equality':
score += 40
else:
return 0
return score
@classmethod
def _get_indexed_expressions(cls, index):
database = Transaction().database
has_btree_gin = database.has_extension('btree_gin')
has_trigram = database.has_extension('pg_trgm')
def filter(usage):
if usage.__class__.__name__ == 'Similarity':
return has_trigram
elif usage.__class__.__name__ in {'Range', 'Equality'}:
return has_btree_gin
else:
return False
return [(e, u) for e, u in index.expressions if filter(u)]
@classmethod
def _get_expression_variables(cls, expression, usage):
params = super()._get_expression_variables(expression, usage)
if usage.__class__.__name__ == 'Similarity':
params['opclass'] = SQL('gin_trgm_ops')
return params

12
backend/sqlite/__init__.py Executable file
View File

@@ -0,0 +1,12 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
from .database import (
Database, DatabaseDataError, DatabaseIntegrityError,
DatabaseOperationalError, DatabaseTimeoutError)
from .table import TableHandler
__all__ = [
Database, TableHandler,
DatabaseIntegrityError, DatabaseDataError, DatabaseOperationalError,
DatabaseTimeoutError]

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

648
backend/sqlite/database.py Executable file
View File

@@ -0,0 +1,648 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import datetime
import logging
import math
import os
import random
import sqlite3 as sqlite
import threading
import time
import urllib.parse
import warnings
from decimal import Decimal
from sqlite3 import DatabaseError
from sqlite3 import IntegrityError as DatabaseIntegrityError
from sqlite3 import OperationalError as DatabaseOperationalError
from weakref import WeakKeyDictionary
from sql import Expression, Flavor, Literal, Null, Query, Table
from sql.conditionals import NullIf
from sql.functions import (
CharLength, CurrentTimestamp, Extract, Function, Overlay, Position,
Substring, Trim)
from trytond.backend.database import DatabaseInterface, SQLType
from trytond.config import config, parse_uri
from trytond.tools import safe_join
from trytond.transaction import Transaction
__all__ = [
'Database',
'DatabaseIntegrityError', 'DatabaseDataError', 'DatabaseOperationalError',
'DatabaseTimeoutError']
logger = logging.getLogger(__name__)
_default_name = config.get('database', 'default_name', default=':memory:')
class DatabaseDataError(DatabaseError):
pass
class DatabaseTimeoutError(Exception):
pass
class SQLiteExtract(Function):
__slots__ = ()
_function = 'EXTRACT'
@staticmethod
def extract(lookup_type, date):
if date is None:
return None
if len(date) == 10:
year, month, day = map(int, date.split('-'))
date = datetime.date(year, month, day)
else:
datepart, timepart = date.split(" ")
year, month, day = map(int, datepart.split("-"))
timepart_full = timepart.split(".")
hours, minutes, seconds = map(int, timepart_full[0].split(":"))
if len(timepart_full) == 2:
microseconds = int(timepart_full[1])
else:
microseconds = 0
date = datetime.datetime(year, month, day, hours, minutes, seconds,
microseconds)
if lookup_type.lower() == 'century':
return date.year / 100 + (date.year % 100 and 1 or 0)
elif lookup_type.lower() == 'decade':
return date.year / 10
elif lookup_type.lower() == 'dow':
return (date.weekday() + 1) % 7
elif lookup_type.lower() == 'doy':
return date.timetuple().tm_yday
elif lookup_type.lower() == 'epoch':
return int(time.mktime(date.timetuple()))
elif lookup_type.lower() == 'microseconds':
return date.microsecond
elif lookup_type.lower() == 'millennium':
return date.year / 1000 + (date.year % 1000 and 1 or 0)
elif lookup_type.lower() == 'milliseconds':
return date.microsecond / 1000
elif lookup_type.lower() == 'quarter':
return date.month / 4 + 1
elif lookup_type.lower() == 'week':
return date.isocalendar()[1]
return getattr(date, lookup_type.lower())
def date_trunc(_type, date):
if not _type:
return date
if date is None:
return None
for format_ in [
'%Y-%m-%d %H:%M:%S.%f',
'%Y-%m-%d %H:%M:%S',
'%Y-%m-%d',
'%H:%M:%S',
]:
try:
value = datetime.datetime.strptime(date, format_)
except ValueError:
continue
else:
break
else:
return None
for attribute, replace in [
('microsecond', 0),
('second', 0),
('minute', 0),
('hour', 0),
('day', 1),
('month', 1)]:
if _type.lower().startswith(attribute):
break
value = value.replace(**{attribute: replace})
return str(value)
def split_part(text, delimiter, count):
if text is None:
return None
return (text.split(delimiter) + [''] * (count - 1))[count - 1]
class SQLitePosition(Function):
__slots__ = ()
_function = 'POSITION'
@staticmethod
def position(substring, string):
if string is None:
return
try:
return string.index(substring) + 1
except ValueError:
return 0
def replace(text, pattern, replacement):
return str(text).replace(pattern, replacement)
def now():
transaction = Transaction()
return _nows.setdefault(transaction, {}).setdefault(
transaction.started_at, datetime.datetime.now().isoformat(' '))
_nows = WeakKeyDictionary()
def to_char(value, format):
try:
value = datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S.%f')
except ValueError:
try:
value = datetime.datetime.strptime(value, '%Y-%m-%d').date()
except ValueError:
pass
if isinstance(value, datetime.date):
# Convert SQL pattern into compatible Python
return value.strftime(format
.replace('%', '%%')
.replace('HH12', '%I')
.replace('HH24', '%H')
.replace('HH', '%I')
.replace('MI', '%M')
.replace('SS', '%S')
.replace('US', '%f')
.replace('AM', '%p')
.replace('A.M.', '%p')
.replace('PM', '%p')
.replace('P.M.', '%p')
.replace('am', '%p')
.replace('a.m.', '%p')
.replace('pm', '%p')
.replace('p.m.', '%p')
.replace('YYYY', '%Y')
.replace('YY', '%y')
.replace('Month', '%B')
.replace('Mon', '%b')
.replace('MM', '%m')
.replace('Day', '%A')
.replace('Dy', '%a')
.replace('DDD', '%j')
.replace('DD', '%d')
.replace('D', '%w')
.replace('TZ', '%Z')
)
elif isinstance(value, datetime.timedelta):
raise NotImplementedError
else:
raise NotImplementedError
class SQLiteSubstring(Function):
__slots__ = ()
_function = 'SUBSTR'
class SQLiteOverlay(Function):
__slots__ = ()
_function = 'OVERLAY'
@staticmethod
def overlay(string, placing_string, from_, for_=None):
if for_ is None:
for_ = len(placing_string)
return string[:from_ - 1] + placing_string + string[from_ - 1 + for_:]
class SQLiteCharLength(Function):
__slots__ = ()
_function = 'LENGTH'
class SQLiteCurrentTimestamp(Function):
__slots__ = ()
_function = 'NOW' # More precise
class SQLiteTrim(Trim):
def __str__(self):
flavor = Flavor.get()
param = flavor.param
function = {
'BOTH': 'TRIM',
'LEADING': 'LTRIM',
'TRAILING': 'RTRIM',
}[self.position]
def format(arg):
if isinstance(arg, str):
return param
else:
return str(arg)
return function + '(%s, %s)' % (
format(self.string), format(self.characters))
@property
def params(self):
if isinstance(self.string, str):
params = [self.string]
else:
params = list(self.string.params)
params.append(self.characters)
return params
def sign(value):
if value > 0:
return 1
elif value < 0:
return -1
else:
return value
def greatest(*args):
args = [a for a in args if a is not None]
if args:
return max(args)
else:
return None
def least(*args):
args = [a for a in args if a is not None]
if args:
return min(args)
else:
return None
def bool_and(*args):
return all(args)
def bool_or(*args):
return any(args)
def cbrt(value):
return math.pow(value, 1 / 3)
def div(a, b):
return a // b
def trunc(value, digits):
return math.trunc(value * 10 ** digits) / 10 ** digits
MAPPING = {
Extract: SQLiteExtract,
Position: SQLitePosition,
Substring: SQLiteSubstring,
Overlay: SQLiteOverlay,
CharLength: SQLiteCharLength,
CurrentTimestamp: SQLiteCurrentTimestamp,
Trim: SQLiteTrim,
}
class JSONExtract(Function):
__slots__ = ()
_function = 'JSON_EXTRACT'
class JSONQuote(Function):
__slots__ = ()
_function = 'JSON_QUOTE'
class SQLiteCursor(sqlite.Cursor):
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
pass
class SQLiteConnection(sqlite.Connection):
def cursor(self):
return super(SQLiteConnection, self).cursor(SQLiteCursor)
class Database(DatabaseInterface):
_local = threading.local()
_conn = None
flavor = Flavor(
paramstyle='qmark', function_mapping=MAPPING, null_ordering=False,
max_limit=-1)
IN_MAX = 200
TYPES_MAPPING = {
'BIGINT': SQLType('INTEGER', 'INTEGER'),
'BOOL': SQLType('BOOLEAN', 'BOOLEAN'),
'DATETIME': SQLType('TIMESTAMP', 'TIMESTAMP'),
'FULLTEXT': SQLType('TEXT', 'TEXT'),
'JSON': SQLType('TEXT', 'TEXT'),
}
def __new__(cls, name=_default_name):
if (name == ':memory:'
and getattr(cls._local, 'memory_database', None)):
return cls._local.memory_database
return DatabaseInterface.__new__(cls, name=name)
def __init__(self, name=_default_name):
super(Database, self).__init__(name=name)
if name == ':memory:':
Database._local.memory_database = self
def connect(self):
if self._conn is not None:
return self
self._conn = sqlite.connect(
self._make_uri(), uri=True,
detect_types=sqlite.PARSE_DECLTYPES | sqlite.PARSE_COLNAMES,
factory=SQLiteConnection)
self._conn.create_function('extract', 2, SQLiteExtract.extract)
self._conn.create_function('date_trunc', 2, date_trunc)
self._conn.create_function('split_part', 3, split_part)
self._conn.create_function('to_char', 2, to_char)
if sqlite.sqlite_version_info < (3, 3, 14):
self._conn.create_function('replace', 3, replace)
self._conn.create_function('now', 0, now)
self._conn.create_function('greatest', -1, greatest)
self._conn.create_function('least', -1, least)
self._conn.create_function('bool_and', -1, bool_and)
self._conn.create_function('bool_or', -1, bool_or)
# Mathematical functions
self._conn.create_function('cbrt', 1, cbrt)
self._conn.create_function('ceil', 1, math.ceil)
self._conn.create_function('degrees', 1, math.degrees)
self._conn.create_function('div', 2, div)
self._conn.create_function('exp', 1, math.exp)
self._conn.create_function('floor', 1, math.floor)
self._conn.create_function('ln', 1, math.log)
self._conn.create_function('log', 1, math.log10)
self._conn.create_function('mod', 2, math.fmod)
self._conn.create_function('pi', 0, lambda: math.pi)
self._conn.create_function('power', 2, math.pow)
self._conn.create_function('radians', 1, math.radians)
self._conn.create_function('sign', 1, sign)
self._conn.create_function('sqrt', 1, math.sqrt)
self._conn.create_function('trunc', 1, math.trunc)
self._conn.create_function('trunc', 2, trunc)
# Trigonomentric functions
self._conn.create_function('acos', 1, math.acos)
self._conn.create_function('asin', 1, math.asin)
self._conn.create_function('atan', 1, math.atan)
self._conn.create_function('atan2', 2, math.atan2)
self._conn.create_function('cos', 1, math.cos)
self._conn.create_function(
'cot', 1, lambda x: 1 / math.tan(x) if x else math.inf)
self._conn.create_function('sin', 1, math.sin)
self._conn.create_function('tan', 1, math.tan)
# Random functions
self._conn.create_function('random', 0, random.random)
self._conn.create_function('setseed', 1, random.seed)
# String functions
self._conn.create_function('overlay', 3, SQLiteOverlay.overlay)
self._conn.create_function('overlay', 4, SQLiteOverlay.overlay)
self._conn.create_function('position', 2, SQLitePosition.position)
if (hasattr(self._conn, 'set_trace_callback')
and logger.isEnabledFor(logging.DEBUG)):
self._conn.set_trace_callback(logger.debug)
self._conn.execute('PRAGMA foreign_keys = ON')
return self
def _make_uri(self):
uri = config.get('database', 'uri')
base_uri = parse_uri(uri)
if base_uri.path and base_uri.path != '/':
warnings.warn("The path specified in the URI will be overridden")
if self.name == ':memory:':
query_string = urllib.parse.parse_qs(base_uri.query)
query_string['mode'] = 'memory'
query = urllib.parse.urlencode(query_string, doseq=True)
db_uri = base_uri._replace(netloc='', path='/', query=query)
else:
db_path = safe_join(
config.get('database', 'path'), self.name + '.sqlite')
if not os.path.isfile(db_path):
raise IOError("Database '%s' doesn't exist!" % db_path)
db_uri = base_uri._replace(path=db_path)
# Use unparse before replacing sqlite with file because SQLite accepts
# a relative path URI like file:db/test.sqlite which doesn't conform to
# RFC8089 which urllib follows and enforces when the scheme is 'file'
db_uri = urllib.parse.urlunparse(db_uri)
return db_uri.replace('sqlite', 'file', 1)
def get_connection(
self, autocommit=False, readonly=False, statement_timeout=None):
if self._conn is None:
self.connect()
if autocommit:
self._conn.isolation_level = None
else:
self._conn.isolation_level = 'IMMEDIATE'
return self._conn
def put_connection(self, connection=None, close=False):
pass
def close(self):
if self.name == ':memory:':
return
if self._conn is None:
return
self._conn = None
@classmethod
def create(cls, connection, database_name):
if database_name == ':memory:':
path = ':memory:'
else:
if os.sep in database_name:
return
path = os.path.join(config.get('database', 'path'),
database_name + '.sqlite')
with sqlite.connect(path) as conn:
cursor = conn.cursor()
cursor.close()
@classmethod
def drop(cls, connection, database_name):
if database_name == ':memory:':
cls._local.memory_database._conn = None
return
if os.sep in database_name:
return
os.remove(os.path.join(config.get('database', 'path'),
database_name + '.sqlite'))
def list(self, hostname=None):
res = []
listdir = [':memory:']
try:
listdir += os.listdir(config.get('database', 'path'))
except OSError:
pass
for db_file in listdir:
if db_file.endswith('.sqlite') or db_file == ':memory:':
if db_file == ':memory:':
db_name = ':memory:'
else:
db_name = db_file[:-7]
try:
database = Database(db_name).connect()
except Exception:
logger.debug(
'Test failed for "%s"', db_name, exc_info=True)
continue
if database.test(hostname=hostname):
res.append(db_name)
database.close()
return res
def init(self):
from trytond.modules import get_module_info
Flavor.set(self.flavor)
with self.get_connection() as conn:
cursor = conn.cursor()
sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
with open(sql_file) as fp:
for line in fp.read().split(';'):
if (len(line) > 0) and (not line.isspace()):
cursor.execute(line)
ir_module = Table('ir_module')
ir_module_dependency = Table('ir_module_dependency')
for module in ['ir', 'res']:
info = get_module_info(module)
insert = ir_module.insert(
[ir_module.create_uid, ir_module.create_date,
ir_module.name, ir_module.state],
[[0, CurrentTimestamp(), module, 'to activate']])
cursor.execute(*insert)
cursor.execute('SELECT last_insert_rowid()')
module_id, = cursor.fetchone()
for dependency in info.get('depends', []):
insert = ir_module_dependency.insert(
[ir_module_dependency.create_uid,
ir_module_dependency.create_date,
ir_module_dependency.module,
ir_module_dependency.name,
],
[[0, CurrentTimestamp(), module_id, dependency]])
cursor.execute(*insert)
conn.commit()
def test(self, hostname=None):
Flavor.set(self.flavor)
tables = ['ir_model', 'ir_model_field', 'ir_ui_view', 'ir_ui_menu',
'res_user', 'res_group', 'ir_module', 'ir_module_dependency',
'ir_translation', 'ir_lang', 'ir_configuration']
sqlite_master = Table('sqlite_master')
select = sqlite_master.select(sqlite_master.name)
select.where = sqlite_master.type == 'table'
select.where &= sqlite_master.name.in_(tables)
with self._conn as conn:
cursor = conn.cursor()
try:
cursor.execute(*select)
except Exception:
return False
if len(cursor.fetchall()) != len(tables):
return False
if hostname:
configuration = Table('ir_configuration')
try:
cursor.execute(*configuration.select(
configuration.hostname))
except Exception:
return False
hostnames = {h for h, in cursor if h}
if hostnames and hostname not in hostnames:
return False
return True
def lastid(self, cursor):
# This call is not thread safe
return cursor.lastrowid
def lock(self, connection, table):
pass
def lock_id(self, id, timeout=None):
return Literal(True)
def has_constraint(self, constraint):
return False
def has_multirow_insert(self):
return True
def has_window_functions(self):
return sqlite.sqlite_version_info >= (3, 25, 0)
def sql_type(self, type_):
if type_ in self.TYPES_MAPPING:
return self.TYPES_MAPPING[type_]
if type_.startswith('VARCHAR'):
return SQLType('VARCHAR', type_)
return SQLType(type_, type_)
def sql_format(self, type_, value):
if type_ in ('INTEGER', 'BIGINT'):
if (value is not None
and not isinstance(value, (Query, Expression))):
value = int(value)
return value
def json_get(self, column, key=None):
if key:
column = JSONExtract(column, '$.%s' % key)
return NullIf(JSONQuote(column), JSONQuote(Null))
sqlite.register_converter('NUMERIC', lambda val: Decimal(val.decode('utf-8')))
sqlite.register_adapter(Decimal, lambda val: str(val).encode('utf-8'))
def adapt_datetime(val):
return val.replace(tzinfo=None).isoformat(" ")
sqlite.register_adapter(datetime.datetime, adapt_datetime)
sqlite.register_adapter(datetime.time, lambda val: val.isoformat())
sqlite.register_converter('TIME',
lambda val: datetime.time(*map(int, val.decode('utf-8').split(':'))))
sqlite.register_adapter(datetime.timedelta, lambda val: val.total_seconds())
def convert_interval(value):
value = float(value)
# It is not allowed to instatiate timedelta with the min/max total seconds
if value >= _interval_max:
return datetime.timedelta.max
elif value <= _interval_min:
return datetime.timedelta.min
return datetime.timedelta(seconds=value)
_interval_max = datetime.timedelta.max.total_seconds()
_interval_min = datetime.timedelta.min.total_seconds()
sqlite.register_converter('INTERVAL', convert_interval)

185
backend/sqlite/init.sql Executable file
View File

@@ -0,0 +1,185 @@
CREATE TABLE ir_configuration (
id INTEGER PRIMARY KEY AUTOINCREMENT,
language VARCHAR,
hostname VARCHAR,
create_date TIMESTAMP,
create_uid INTEGER,
write_date TIMESTAMP,
write_uid INTEGER
);
CREATE TABLE ir_model (
id INTEGER PRIMARY KEY AUTOINCREMENT,
model VARCHAR,
name VARCHAR,
info TEXT,
module VARCHAR,
create_date TIMESTAMP,
create_uid INTEGER,
write_date TIMESTAMP,
write_uid INTEGER
);
CREATE TABLE ir_model_field (
id INTEGER PRIMARY KEY AUTOINCREMENT,
model VARCHAR,
name VARCHAR,
relation VARCHAR,
field_description VARCHAR,
ttype VARCHAR,
help TEXT,
module VARCHAR,
"access" BOOLEAN,
create_date TIMESTAMP,
create_uid INTEGER,
write_date TIMESTAMP,
write_uid INTEGER
);
CREATE TABLE ir_ui_view (
id INTEGER PRIMARY KEY AUTOINCREMENT,
model VARCHAR,
"type" VARCHAR,
data TEXT,
field_childs VARCHAR,
priority INTEGER,
domain VARCHAR,
inherit INTEGER,
module VARCHAR,
name VARCHAR,
create_date TIMESTAMP,
create_uid INTEGER,
write_date TIMESTAMP,
write_uid INTEGER
);
CREATE TABLE ir_ui_menu (
id INTEGER PRIMARY KEY AUTOINCREMENT,
parent INTEGER,
name VARCHAR,
icon VARCHAR,
active BOOLEAN,
sequence INTEGER,
create_date TIMESTAMP,
create_uid INTEGER,
write_date TIMESTAMP,
write_uid INTEGER
);
CREATE TABLE ir_translation (
id INTEGER PRIMARY KEY AUTOINCREMENT,
lang VARCHAR,
src TEXT,
name VARCHAR,
res_id INTEGER,
value TEXT,
"type" VARCHAR,
module VARCHAR,
fuzzy BOOLEAN,
overriding_module VARCHAR,
create_date TIMESTAMP,
create_uid INTEGER,
write_date TIMESTAMP,
write_uid INTEGER
);
CREATE TABLE ir_lang (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name VARCHAR,
code VARCHAR,
translatable BOOLEAN,
parent VARCHAR,
active BOOLEAN,
direction VARCHAR,
am VARCHAR,
pm VARCHAR,
"date" VARCHAR,
grouping VARCHAR,
decimal_point VARCHAR,
thousands_sep VARCHAR,
mon_grouping VARCHAR,
mon_decimal_point VARCHAR,
mon_thousands_sep VARCHAR,
p_sign_posn INTEGER,
n_sign_posn INTEGER,
positive_sign VARCHAR,
negative_sign VARCHAR,
p_cs_precedes BOOLEAN,
n_cs_precedes BOOLEAN,
p_sep_by_space BOOLEAN,
n_sep_by_space BOOLEAN,
pg_text_search VARCHAR,
create_date TIMESTAMP,
create_uid INTEGER,
write_date TIMESTAMP,
write_uid INTEGER
);
CREATE TABLE res_user (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name VARCHAR,
active BOOLEAN,
login VARCHAR,
password VARCHAR,
email VARCHAR,
language INTEGER,
menu INTEGER,
password_hash VARCHAR,
password_reset VARCHAR,
password_reset_expire TIMESTAMP,
signature TEXT,
create_date TIMESTAMP,
create_uid INTEGER,
write_date TIMESTAMP,
write_uid INTEGER
);
INSERT INTO res_user (id, login, password, name, active) VALUES (0, 'root', NULL, 'Root', 0);
CREATE TABLE res_group (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name VARCHAR
);
CREATE TABLE "res_user-res_group" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
"user" INTEGER,
"group" INTEGER,
active BOOLEAN,
parent INTEGER,
create_date TIMESTAMP,
create_uid INTEGER,
write_date TIMESTAMP,
write_uid INTEGER
);
CREATE TABLE ir_module (
id INTEGER PRIMARY KEY AUTOINCREMENT,
create_uid INTEGER,
create_date TIMESTAMP,
write_date TIMESTAMP,
write_uid INTEGER,
name VARCHAR,
state VARCHAR
);
CREATE TABLE ir_module_dependency (
id INTEGER PRIMARY KEY AUTOINCREMENT,
create_uid INTEGER,
create_date TIMESTAMP,
write_date TIMESTAMP,
write_uid INTEGER,
name VARCHAR,
module INTEGER
);
CREATE TABLE ir_cache (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name VARCHAR,
"timestamp" TIMESTAMP,
create_date TIMESTAMP,
create_uid INTEGER,
write_date TIMESTAMP,
write_uid INTEGER
);

383
backend/sqlite/table.py Executable file
View File

@@ -0,0 +1,383 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import logging
import re
import warnings
from weakref import WeakKeyDictionary
from trytond.backend.table import (
IndexTranslatorInterface, TableHandlerInterface)
from trytond.transaction import Transaction
from .database import sqlite
__all__ = ['TableHandler']
logger = logging.getLogger(__name__)
VARCHAR_SIZE_RE = re.compile(r'VARCHAR\(([0-9]+)\)')
def _escape_identifier(name):
return '"%s"' % name.replace('"', '""')
class TableHandler(TableHandlerInterface):
__handlers = WeakKeyDictionary()
index_translators = []
def _init(self, model, history=False):
super()._init(model, history=history)
self.__columns = None
self.__indexes = None
self._model = model
cursor = Transaction().connection.cursor()
# Create new table if necessary
if not self.table_exist(self.table_name):
if not self.history:
cursor.execute('CREATE TABLE %s '
'(id INTEGER PRIMARY KEY AUTOINCREMENT)'
% _escape_identifier(self.table_name))
else:
cursor.execute('CREATE TABLE %s '
'(__id INTEGER PRIMARY KEY AUTOINCREMENT, '
'id INTEGER)' % _escape_identifier(self.table_name))
self._update_definitions()
@classmethod
def table_exist(cls, table_name):
cursor = Transaction().connection.cursor()
cursor.execute("SELECT sql FROM sqlite_master "
"WHERE type = 'table' AND name = ?",
(table_name,))
res = cursor.fetchone()
if not res:
return False
return True
@classmethod
def table_rename(cls, old_name, new_name):
cursor = Transaction().connection.cursor()
if (cls.table_exist(old_name)
and not cls.table_exist(new_name)):
cursor.execute('ALTER TABLE %s RENAME TO %s'
% (_escape_identifier(old_name), _escape_identifier(new_name)))
# Rename history table
old_history = old_name + "__history"
new_history = new_name + "__history"
if (cls.table_exist(old_history)
and not cls.table_exist(new_history)):
cursor.execute('ALTER TABLE %s RENAME TO %s'
% (_escape_identifier(old_history),
_escape_identifier(new_history)))
def column_exist(self, column_name):
return column_name in self._columns
def _recreate_table(self, update_columns=None, drop_columns=None):
if update_columns is None:
update_columns = {}
if drop_columns is None:
drop_columns = []
transaction = Transaction()
database = transaction.database
cursor = transaction.connection.cursor()
temp_table = '__temp_%s' % self.table_name
temp_columns = dict(self._columns)
self.table_rename(self.table_name, temp_table)
self._init(self._model, history=self.history)
columns, old_columns = [], []
for name, values in temp_columns.items():
if name in drop_columns:
continue
typname = update_columns.get(name, {}).get(
'typname', values['typname'])
size = update_columns.get(name, {}).get('size', values['size'])
name = update_columns.get(name, {}).get('name', name)
self._add_raw_column(
name, database.sql_type(typname), field_size=size)
columns.append(name)
old_columns.append(name)
cursor.execute(('INSERT INTO %s ('
+ ','.join(_escape_identifier(x) for x in columns)
+ ') SELECT '
+ ','.join(_escape_identifier(x) for x in old_columns)
+ ' FROM %s') % (
_escape_identifier(self.table_name),
_escape_identifier(temp_table)))
cursor.execute('DROP TABLE %s' % _escape_identifier(temp_table))
self._update_definitions()
def column_rename(self, old_name, new_name):
cursor = Transaction().connection.cursor()
if self.column_exist(old_name):
if not self.column_exist(new_name):
if sqlite.sqlite_version_info >= (3, 25, 0):
cursor.execute('ALTER TABLE %s RENAME COLUMN %s TO %s' % (
_escape_identifier(self.table_name),
_escape_identifier(old_name),
_escape_identifier(new_name)))
self._update_definitions(columns=True)
else:
self._recreate_table({old_name: {'name': new_name}})
else:
logger.warning(
'Unable to rename column %s on table %s to %s.',
old_name, self.table_name, new_name)
@property
def _columns(self):
if self.__columns is None:
cursor = Transaction().connection.cursor()
cursor.execute('PRAGMA table_info("' + self.table_name + '")')
self.__columns = {}
for _, column, type_, notnull, hasdef, _ in cursor:
column = re.sub(r'^\"|\"$', '', column)
match = re.match(r'(\w+)(\((.*?)\))?', type_)
if match:
typname = match.group(1).upper()
size = match.group(3) and int(match.group(3)) or 0
else:
typname = type_.upper()
size = None
self.__columns[column] = {
'notnull': notnull,
'hasdef': hasdef,
'size': size,
'typname': typname,
}
return self.__columns
@property
def _indexes(self):
if self.__indexes is None:
cursor = Transaction().connection.cursor()
try:
cursor.execute('PRAGMA index_list("' + self.table_name + '")')
except IndexError: # There is sometimes IndexError
cursor.execute('PRAGMA index_list("' + self.table_name + '")')
self.__indexes = [l[1] for l in cursor]
return self.__indexes
def _update_definitions(self, columns=True):
if columns:
self.__columns = None
def alter_size(self, column_name, column_type):
self._recreate_table({column_name: {'size': column_type}})
def alter_type(self, column_name, column_type):
self._recreate_table({column_name: {'typname': column_type}})
def column_is_type(self, column_name, type_, *, size=-1):
db_type = self._columns[column_name]['typname'].upper()
database = Transaction().database
base_type = database.sql_type(type_).base.upper()
if base_type == 'VARCHAR' and (size is None or size >= 0):
same_size = self._columns[column_name]['size'] == size
else:
same_size = True
return base_type == db_type and same_size
def db_default(self, column_name, value):
warnings.warn('Unable to set default on column with SQLite backend')
def add_column(self, column_name, sql_type, default=None, comment=''):
database = Transaction().database
column_type = database.sql_type(sql_type)
match = VARCHAR_SIZE_RE.match(sql_type)
field_size = int(match.group(1)) if match else None
self._add_raw_column(column_name, column_type, default, field_size,
comment)
def _add_raw_column(self, column_name, column_type, default=None,
field_size=None, string=''):
if self.column_exist(column_name):
base_type = column_type[0].upper()
if base_type != self._columns[column_name]['typname']:
if (self._columns[column_name]['typname'], base_type) in [
('VARCHAR', 'TEXT'),
('TEXT', 'VARCHAR'),
('DATE', 'TIMESTAMP'),
('INTEGER', 'FLOAT'),
('INTEGER', 'NUMERIC'),
('FLOAT', 'NUMERIC'),
]:
self.alter_type(column_name, base_type)
else:
logger.warning(
'Unable to migrate column %s on table %s '
'from %s to %s.',
column_name, self.table_name,
self._columns[column_name]['typname'], base_type)
if (base_type == 'VARCHAR'
and self._columns[column_name]['typname'] == 'VARCHAR'):
# Migrate size
from_size = self._columns[column_name]['size']
if field_size is None:
if from_size > 0:
self.alter_size(column_name, base_type)
elif from_size == field_size:
pass
elif from_size and from_size < field_size:
self.alter_size(column_name, column_type[1])
else:
logger.warning(
'Unable to migrate column %s on table %s '
'from varchar(%s) to varchar(%s).',
column_name, self.table_name,
from_size if from_size and from_size > 0 else "",
field_size)
return
cursor = Transaction().connection.cursor()
column_type = column_type[1]
cursor.execute(('ALTER TABLE %s ADD COLUMN %s %s') % (
_escape_identifier(self.table_name),
_escape_identifier(column_name),
column_type))
if default:
# check if table is non-empty:
cursor.execute('SELECT 1 FROM %s limit 1'
% _escape_identifier(self.table_name))
if cursor.fetchone():
# Populate column with default values:
cursor.execute('UPDATE ' + _escape_identifier(self.table_name)
+ ' SET ' + _escape_identifier(column_name) + ' = ?',
(default(),))
self._update_definitions(columns=True)
def add_fk(self, columns, reference, ref_columns=None, on_delete=None):
warnings.warn('Unable to add foreign key with SQLite backend')
def drop_fk(self, columns=None, ref_columns=None, table=None):
warnings.warn('Unable to drop foreign key with SQLite backend')
def not_null_action(self, column_name, action='add'):
if not self.column_exist(column_name):
return
if action == 'add':
warnings.warn('Unable to set not null with SQLite backend')
elif action == 'remove':
warnings.warn('Unable to remove not null with SQLite backend')
else:
raise Exception('Not null action not supported!')
def add_constraint(self, ident, constraint):
warnings.warn('Unable to add constraint with SQLite backend')
def drop_constraint(self, ident, table=None):
warnings.warn('Unable to drop constraint with SQLite backend')
def set_indexes(self, indexes, concurrently=False):
cursor = Transaction().connection.cursor()
old = set(self._indexes)
for index in indexes:
translator = self.index_translator_for(index)
if translator:
name, query, params = translator.definition(index)
name = '_'.join([self.table_name, name])
name = 'idx_' + self.convert_name(name, reserved=len('idx_'))
# SQLite does not support parameters for index creation
if not params:
cursor.execute(
'CREATE INDEX IF NOT EXISTS %s ON %s %s' % (
_escape_identifier(name),
_escape_identifier(self.table_name),
query),
params)
else:
warnings.warn("Can not create index with parameters")
old.discard(name)
for name in old:
if name.startswith('idx_') or name.endswith('_index'):
cursor.execute('DROP INDEX %s' % _escape_identifier(name))
self.__indexes = None
def drop_column(self, column_name):
if not self.column_exist(column_name):
return
transaction = Transaction()
cursor = transaction.connection.cursor()
if sqlite.sqlite_version_info >= (3, 35, 0):
cursor.execute('ALTER TABLE %s DROP COLUMN %s' % (
_escape_identifier(self.table_name),
_escape_identifier(column_name)))
self._update_definitions(columns=True)
else:
self._recreate_table(drop_columns=[column_name])
@classmethod
def drop_table(cls, model, table, cascade=False):
cursor = Transaction().connection.cursor()
cursor.execute('DELETE from ir_model_data where model = ?',
(model,))
query = 'DROP TABLE %s' % _escape_identifier(table)
if cascade:
query = query + ' CASCADE'
cursor.execute(query)
class IndexMixin:
def __init_subclass__(cls):
TableHandler.index_translators.append(cls)
@classmethod
def definition(cls, index):
expr_template = '%(expression)s %(collate)s %(order)s'
params = []
expressions = []
for expression, usage in index.expressions:
expressions.append(expr_template %
cls._get_expression_variables(expression, usage))
params.extend(expression.params)
where = ''
if index.options.get('where'):
where = 'WHERE %s' % index.options['where']
params.extend(index.options['where'].params)
query = '(%(expressions)s) %(where)s' % {
'expressions': ','.join(expressions),
'where': where,
}
name = cls._get_name(query, params)
return name, query, params
@classmethod
def _get_expression_variables(cls, expression, usage):
variables = {
'expression': str(expression),
'collate': '',
'order': '',
}
if usage.options.get('collation'):
variables['collate'] = 'COLLATE %s' % usage.options['collation']
if usage.options.get('order'):
order = usage.options['order'].upper()
for predicate in ['NULLS FIRST', 'NULLS LAST']:
if order.endswith(predicate):
order = order[:-len(predicate)]
variables['order'] = order
return variables
class IndexTranslator(IndexMixin, IndexTranslatorInterface):
@classmethod
def score(cls, index):
supported_indexes_count = sum(
int(u.__class__.__name__ in {'Equality', 'Range'})
for _, u in index.expressions)
return supported_indexes_count * 100

140
backend/table.py Executable file
View File

@@ -0,0 +1,140 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import hashlib
from weakref import WeakKeyDictionary
from trytond.transaction import Transaction
class TableHandlerInterface(object):
'''
Define generic interface to handle database table
'''
namedatalen = None
index_translators = None
__handlers = WeakKeyDictionary()
def __new__(cls, model, history=False):
transaction = Transaction()
handlers = cls.__handlers.setdefault(transaction, {})
key = (model.__name__, history)
if key not in handlers:
instance = handlers[key] = super().__new__(cls)
instance._init(model, history=history)
return handlers[key]
def _init(self, model, history=False):
'''
:param model: the Model linked to the table
:param module_name: the module name
:param history: a boolean to define if it is a history table
'''
super(TableHandlerInterface, self).__init__()
if history:
self.table_name = model._table + '__history'
else:
self.table_name = model._table
self.object_name = model.__name__
if history:
self.sequence_name = self.table_name + '___id_seq'
else:
self.sequence_name = self.table_name + '_id_seq'
self.history = history
@classmethod
def table_exist(cls, table_name):
raise NotImplementedError
@classmethod
def table_rename(cls, old_name, new_name):
raise NotImplementedError
def column_exist(self, column_name):
raise NotImplementedError
def column_rename(self, old_name, new_name):
raise NotImplementedError
def alter_size(self, column_name, column_type):
raise NotImplementedError
def alter_type(self, column_name, column_type):
raise NotImplementedError
def column_is_type(self, column_name, type_, *, size=-1):
raise NotImplementedError
def db_default(self, column_name, value):
raise NotImplementedError
def add_column(self, column_name, abstract_type, default=None, comment=''):
raise NotImplementedError
def add_fk(self, columns, reference, ref_columns=None, on_delete=None):
raise NotImplementedError
def drop_fk(self, columns, ref_columns=None, table=None):
raise NotImplementedError
def not_null_action(self, column_name, action='add'):
raise NotImplementedError
def add_constraint(self, ident, constraint):
raise NotImplementedError
def drop_constraint(self, ident, table=None):
raise NotImplementedError
def create_index(self, index):
raise NotImplementedError
def drop_column(self, column_name):
raise NotImplementedError
@classmethod
def drop_table(cls, model, table, cascade=False):
raise NotImplementedError
@classmethod
def convert_name(cls, name, reserved=0):
if cls.namedatalen:
length = cls.namedatalen - reserved
if length <= 0:
raise ValueError
if len(name) >= length:
if isinstance(name, str):
name = name.encode('utf-8')
name = hashlib.sha256(name).hexdigest()[:length - 1]
return name
def set_indexes(self, indexes, concurrently=False):
raise NotImplementedError
def index_translator_for(self, index):
return next(
filter(
lambda t: t.score(index) > 0,
sorted(
self.index_translators, key=lambda t: t.score(index),
reverse=True)),
None)
class IndexTranslatorInterface:
@classmethod
def _get_name(cls, query, params):
def hash_(s):
return hashlib.shake_128(s.encode('utf-8')).hexdigest(16)
names = [str(query)]
if params:
names.append(str(params))
return '_'.join(map(hash_, names))
@classmethod
def definition(cls, index):
raise NotImplementedError
@classmethod
def score(cls, index):
raise NotImplementedError

269
bus.py Executable file
View File

@@ -0,0 +1,269 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import collections
import json
import logging
import os
import selectors
import threading
import time
import uuid
from urllib.parse import urljoin
from trytond import backend
from trytond.config import config
from trytond.protocols.jsonrpc import JSONDecoder, JSONEncoder
from trytond.protocols.wrappers import (
HTTPStatus, Response, exceptions, redirect)
from trytond.tools import resolve
from trytond.transaction import Transaction
from trytond.wsgi import app
logger = logging.getLogger(__name__)
_db_timeout = config.getint('database', 'timeout')
_cache_timeout = config.getint('bus', 'cache_timeout')
_select_timeout = config.getint('bus', 'select_timeout')
_long_polling_timeout = config.getint('bus', 'long_polling_timeout')
_allow_subscribe = config.getboolean('bus', 'allow_subscribe')
_url_host = config.get('bus', 'url_host')
_web_cache_timeout = config.getint('web', 'cache_timeout')
class _MessageQueue:
Message = collections.namedtuple('Message', 'channel content timestamp')
def __init__(self, timeout):
super().__init__()
self._lock = collections.defaultdict(threading.Lock)
self._timeout = timeout
self._messages = []
def append(self, channel, element):
self._messages.append(
self.Message(channel, element, time.time()))
def get_next(self, channels, from_id=None):
oldest = time.time() - self._timeout
to_delete_index = 0
found = False
first_message = None
message = self.Message(None, None, None)
for idx, item in enumerate(self._messages):
if item.timestamp < oldest:
to_delete_index = idx
continue
if item.channel not in channels:
continue
if not first_message:
first_message = item
if from_id is None or found:
message = item
break
found = item.content['message_id'] == from_id
else:
if first_message and not found:
message = first_message
with self._lock[os.getpid()]:
del self._messages[:to_delete_index]
return message.channel, message.content
class LongPollingBus:
_channel = 'bus'
_queues_lock = collections.defaultdict(threading.Lock)
_queues = collections.defaultdict(
lambda: {'timeout': None, 'events': collections.defaultdict(list)})
_messages = {}
@classmethod
def subscribe(cls, database, channels, last_message=None):
pid = os.getpid()
with cls._queues_lock[pid]:
start_listener = (pid, database) not in cls._queues
cls._queues[pid, database]['timeout'] = time.time() + _db_timeout
if start_listener:
listener = threading.Thread(
target=cls._listen, args=(database,), daemon=True)
cls._queues[pid, database]['listener'] = listener
listener.start()
messages = cls._messages.get(database)
if messages:
channel, content = messages.get_next(channels, last_message)
if content:
return cls.create_response(channel, content)
event = threading.Event()
for channel in channels:
if channel in cls._queues[pid, database]['events']:
event_channel = cls._queues[pid, database]['events'][channel]
else:
with cls._queues_lock[pid]:
event_channel = cls._queues[pid, database][
'events'][channel]
event_channel.append(event)
triggered = event.wait(_long_polling_timeout)
if not triggered:
response = cls.create_response(None, None)
else:
response = cls.create_response(
*cls._messages[database].get_next(channels, last_message))
with cls._queues_lock[pid]:
for channel in channels:
events = cls._queues[pid, database]['events'][channel]
for e in events[:]:
if e.is_set():
events.remove(e)
return response
@classmethod
def create_response(cls, channel, message):
response_data = {
'message': message,
'channel': channel,
}
logger.debug('Bus: %s', response_data)
return response_data
@classmethod
def _listen(cls, database):
db = backend.Database(database)
if not db.has_channel():
raise exceptions.NotImplemented
logger.info("listening on channel '%s'", cls._channel)
conn = db.get_connection(autocommit=True)
pid = os.getpid()
selector = selectors.DefaultSelector()
try:
cursor = conn.cursor()
cursor.execute('LISTEN "%s"' % cls._channel)
cls._messages[database] = messages = _MessageQueue(_cache_timeout)
now = time.time()
selector.register(conn, selectors.EVENT_READ)
while cls._queues[pid, database]['timeout'] > now:
selector.select(timeout=_select_timeout)
conn.poll()
while conn.notifies:
notification = conn.notifies.pop()
payload = json.loads(
notification.payload,
object_hook=JSONDecoder())
channel = payload['channel']
message = payload['message']
messages.append(channel, message)
with cls._queues_lock[pid]:
events = cls._queues[pid, database][
'events'][channel].copy()
cls._queues[pid, database]['events'][channel].clear()
for event in events:
event.set()
now = time.time()
except Exception:
logger.error('bus listener on "%s" crashed', database,
exc_info=True)
with cls._queues_lock[pid]:
del cls._queues[pid, database]
raise
finally:
selector.close()
db.put_connection(conn)
with cls._queues_lock[pid]:
if cls._queues[pid, database]['timeout'] <= now:
del cls._queues[pid, database]
else:
# A query arrived between the end of the while and here
listener = threading.Thread(
target=cls._listen, args=(database,), daemon=True)
cls._queues[pid, database]['listener'] = listener
listener.start()
@classmethod
def publish(cls, channel, message):
transaction = Transaction()
if not transaction.database.has_channel():
logger.debug('Database backend do not support channels')
return
cursor = transaction.connection.cursor()
message['message_id'] = str(uuid.uuid4())
payload = json.dumps({
'channel': channel,
'message': message,
}, cls=JSONEncoder, separators=(',', ':'))
cursor.execute('NOTIFY "%s", %%s' % cls._channel, (payload,))
if config.get('bus', 'class'):
Bus = resolve(config.get('bus', 'class'))
else:
Bus = LongPollingBus
@app.route('/<string:database_name>/bus', methods=['POST'])
@app.auth_required
def subscribe(request, database_name):
if not _allow_subscribe:
raise exceptions.NotImplemented
if _url_host and _url_host != request.host_url:
response = redirect(
urljoin(_url_host, request.path), HTTPStatus.PERMANENT_REDIRECT)
# Allow to change the redirection after some time
response.headers['Cache-Control'] = (
'private, max-age=%s' % _web_cache_timeout)
return response
user = request.authorization.get('userid')
channels = request.parsed_data.get('channels', [])
if user is None:
raise exceptions.BadRequest
channels = set(filter(lambda c: not c.startswith('user:'), channels))
channels.add('user:%s' % user)
last_message = request.parsed_data.get('last_message')
logger.debug(
"getting bus messages from %s@%s%s for %s since %s",
request.authorization.username, request.remote_addr, request.path,
channels, last_message)
bus_response = Bus.subscribe(database_name, channels, last_message)
return Response(
json.dumps(bus_response, cls=JSONEncoder, separators=(',', ':')),
content_type='application/json')
def notify(title, body=None, priority=1, user=None, client=None):
if user is None:
if client is None:
context_client = Transaction().context.get('client')
if context_client:
channel = 'client:%s' % context_client
else:
return
else:
channel = 'client:%s' % client
elif client is None:
channel = 'user:%s' % user
else:
channel = 'client:%s' % client
return Bus.publish(channel, {
'type': 'notification',
'title': title,
'body': body,
'priority': priority,
})

492
cache.py Executable file
View File

@@ -0,0 +1,492 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import datetime as dt
import json
import logging
import os
import selectors
import threading
from collections import OrderedDict, defaultdict
from copy import deepcopy
from weakref import WeakKeyDictionary
from sql import Table
from sql.aggregate import Max
from sql.functions import CurrentTimestamp, Function
from trytond import backend
from trytond.config import config
from trytond.pool import Pool
from trytond.tools import grouped_slice, resolve
from trytond.transaction import Transaction
__all__ = ['BaseCache', 'Cache', 'LRUDict', 'LRUDictTransaction']
_clear_timeout = config.getint('cache', 'clean_timeout', default=5 * 60)
_default_size_limit = config.getint('cache', 'default')
logger = logging.getLogger(__name__)
def _cast(column):
class SQLite_DateTime(Function):
__slots__ = ()
_function = 'DATETIME'
if backend.name == 'sqlite':
column = SQLite_DateTime(column)
return column
def freeze(o):
if isinstance(o, (set, tuple, list)):
return tuple(freeze(x) for x in o)
elif isinstance(o, dict):
return frozenset((x, freeze(y)) for x, y in o.items())
else:
return o
def unfreeze(o):
if isinstance(o, tuple):
return [unfreeze(x) for x in o]
elif isinstance(o, frozenset):
return dict((x, unfreeze(y)) for x, y in o)
else:
return o
def _get_modules(cursor):
ir_module = Table('ir_module')
cursor.execute(*ir_module.select(
ir_module.name,
where=ir_module.state.in_(
['activated', 'to upgrade', 'to remove'])))
return {m for m, in cursor}
class BaseCache(object):
_instances = {}
context_ignored_keys = {
'client', '_request', '_check_access', '_skip_warnings',
}
def __init__(
self, name, duration=None, context=True,
context_ignored_keys=None):
assert ((context_ignored_keys is not None and context)
or (context_ignored_keys is None)), (
f"context_ignored_keys ({context_ignored_keys}) is not valid"
f" in regards to context ({context}).")
self._name = name
self.size_limit = config.getint(
'cache', name, default=_default_size_limit)
self.context = context
self.context_ignored_keys = set()
if context and context_ignored_keys:
self.context_ignored_keys.update(context_ignored_keys)
self.hit = self.miss = 0
if isinstance(duration, dt.timedelta):
self.duration = duration
elif isinstance(duration, (int, float)):
self.duration = dt.timedelta(seconds=duration)
elif duration:
self.duration = dt.timedelta(**duration)
else:
self.duration = None
assert self._name not in self._instances
self._instances[self._name] = self
@classmethod
def stats(cls):
for name, inst in cls._instances.items():
yield {
'name': name,
'hit': inst.hit,
'miss': inst.miss,
}
def _key(self, key):
if self.context:
context = Transaction().context.copy()
for k in (self.__class__.context_ignored_keys
| self.context_ignored_keys):
context.pop(k, None)
return (key, freeze(context))
return key
def get(self, key, default=None):
raise NotImplementedError
def set(self, key, value):
raise NotImplementedError
def clear(self):
raise NotImplementedError
@classmethod
def clear_all(cls):
for inst in cls._instances.values():
inst.clear()
@classmethod
def sync(cls, transaction):
raise NotImplementedError
def sync_since(self, value):
raise NotImplementedError
@classmethod
def commit(cls, transaction):
raise NotImplementedError
@classmethod
def rollback(cls, transaction):
raise NotImplementedError
@classmethod
def drop(cls, dbname):
raise NotImplementedError
class MemoryCache(BaseCache):
"""
A key value LRU cache with size limit.
"""
_reset = WeakKeyDictionary()
_clean_last = None
_default_lower = Transaction.monotonic_time()
_listener = {}
_listener_lock = defaultdict(threading.Lock)
_table = 'ir_cache'
_channel = _table
def __init__(self, *args, **kwargs):
super(MemoryCache, self).__init__(*args, **kwargs)
self._database_cache = defaultdict(lambda: LRUDict(self.size_limit))
self._transaction_cache = WeakKeyDictionary()
self._transaction_lower = {}
self._timestamp = {}
def _get_cache(self):
transaction = Transaction()
dbname = transaction.database.name
lower = self._transaction_lower.get(dbname, self._default_lower)
if (self._name in self._reset.get(transaction, set())
or transaction.started_at < lower):
try:
return self._transaction_cache[transaction]
except KeyError:
cache = self._database_cache.default_factory()
self._transaction_cache[transaction] = cache
return cache
else:
return self._database_cache[dbname]
def get(self, key, default=None):
key = self._key(key)
cache = self._get_cache()
try:
(expire, result) = cache.pop(key)
if expire and expire < dt.datetime.now():
self.miss += 1
return default
cache[key] = (expire, result)
self.hit += 1
return deepcopy(result)
except (KeyError, TypeError):
self.miss += 1
return default
def set(self, key, value):
key = self._key(key)
cache = self._get_cache()
if self.duration:
expire = dt.datetime.now() + self.duration
else:
expire = None
try:
cache[key] = (expire, deepcopy(value))
except TypeError:
pass
return value
def clear(self):
transaction = Transaction()
self._reset.setdefault(transaction, set()).add(self._name)
self._transaction_cache.pop(transaction, None)
def _clear(self, dbname, timestamp=None):
logger.debug("clearing cache '%s' of '%s'", self._name, dbname)
self._timestamp[dbname] = timestamp
self._database_cache[dbname] = self._database_cache.default_factory()
self._transaction_lower[dbname] = max(
Transaction.monotonic_time(),
self._transaction_lower.get(dbname, self._default_lower))
@classmethod
def _clear_all(cls, dbname):
for inst in cls._instances.values():
inst._clear(dbname)
@classmethod
def sync(cls, transaction):
if cls._clean_last is None:
cls._clean_last = dt.datetime.now()
return
database = transaction.database
dbname = database.name
if not _clear_timeout and database.has_channel():
pid = os.getpid()
with cls._listener_lock[pid]:
if (pid, dbname) not in cls._listener:
cls._listener[pid, dbname] = listener = threading.Thread(
target=cls._listen, args=(dbname,), daemon=True)
listener.start()
return
last_clean = (dt.datetime.now() - cls._clean_last).total_seconds()
if last_clean < _clear_timeout:
return
connection = database.get_connection(readonly=True, autocommit=True)
try:
with connection.cursor() as cursor:
table = Table(cls._table)
cursor.execute(*table.select(
_cast(table.timestamp), table.name))
timestamps = {}
for timestamp, name in cursor:
timestamps[name] = timestamp
modules = _get_modules(cursor)
finally:
database.put_connection(connection)
for name, timestamp in timestamps.items():
try:
inst = cls._instances[name]
except KeyError:
continue
inst_timestamp = inst._timestamp.get(dbname)
if not inst_timestamp or timestamp > inst_timestamp:
inst._clear(dbname, timestamp)
Pool.refresh(dbname, modules)
cls._clean_last = dt.datetime.now()
def sync_since(self, value):
return self._clean_last > value
@classmethod
def commit(cls, transaction):
table = Table(cls._table)
reset = cls._reset.pop(transaction, None)
if not reset:
return
database = transaction.database
dbname = database.name
if not _clear_timeout and transaction.database.has_channel():
with transaction.connection.cursor() as cursor:
# The count computed as
# 8000 (max notify size) / 64 (max name data len)
for sub_reset in grouped_slice(reset, 125):
cursor.execute(
'NOTIFY "%s", %%s' % cls._channel,
(json.dumps(list(sub_reset), separators=(',', ':')),))
else:
connection = database.get_connection(
readonly=False, autocommit=True)
try:
with connection.cursor() as cursor:
for name in reset:
cursor.execute(*table.select(table.name, table.id,
table.timestamp,
where=table.name == name,
limit=1))
if cursor.fetchone():
# It would be better to insert only
cursor.execute(*table.update([table.timestamp],
[CurrentTimestamp()],
where=table.name == name))
else:
cursor.execute(*table.insert(
[table.timestamp, table.name],
[[CurrentTimestamp(), name]]))
cursor.execute(*table.select(
Max(table.timestamp),
where=table.name == name))
timestamp, = cursor.fetchone()
cursor.execute(*table.select(
_cast(Max(table.timestamp)),
where=table.name == name))
timestamp, = cursor.fetchone()
inst = cls._instances[name]
inst._clear(dbname, timestamp)
connection.commit()
finally:
database.put_connection(connection)
cls._clean_last = dt.datetime.now()
reset.clear()
@classmethod
def rollback(cls, transaction):
cls._reset.pop(transaction, None)
@classmethod
def drop(cls, dbname):
pid = os.getpid()
with cls._listener_lock[pid]:
listener = cls._listener.pop((pid, dbname), None)
if listener:
database = backend.Database(dbname)
conn = database.get_connection()
try:
cursor = conn.cursor()
cursor.execute('NOTIFY "%s"' % cls._channel)
conn.commit()
finally:
database.put_connection(conn)
listener.join()
for inst in cls._instances.values():
inst._timestamp.pop(dbname, None)
inst._database_cache.pop(dbname, None)
inst._transaction_lower.pop(dbname, None)
@classmethod
def refresh_pool(cls, transaction):
database = transaction.database
dbname = database.name
if not _clear_timeout and database.has_channel():
database = backend.Database(dbname)
conn = database.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
'NOTIFY "%s", %%s' % cls._channel, ('refresh pool',))
conn.commit()
finally:
database.put_connection(conn)
@classmethod
def _listen(cls, dbname):
current_thread = threading.current_thread()
pid = os.getpid()
conn, selector = None, None
try:
database = backend.Database(dbname)
if not database.has_channel():
raise NotImplementedError
logger.info(
"listening on channel '%s' of '%s'", cls._channel, dbname)
conn = database.get_connection(autocommit=True)
selector = selectors.DefaultSelector()
cursor = conn.cursor()
cursor.execute('LISTEN "%s"' % cls._channel)
# Clear everything in case we missed a payload
Pool.refresh(dbname, _get_modules(cursor))
cls._clear_all(dbname)
current_thread.listening = True
selector.register(conn, selectors.EVENT_READ)
while cls._listener.get((pid, dbname)) == current_thread:
selector.select(timeout=60)
conn.poll()
while conn.notifies:
notification = conn.notifies.pop()
if notification.payload == 'refresh pool':
Pool.refresh(dbname, _get_modules(cursor))
elif notification.payload:
reset = json.loads(notification.payload)
for name in reset:
inst = cls._instances[name]
inst._clear(dbname)
cls._clean_last = dt.datetime.now()
except Exception:
logger.error(
"cache listener on '%s' crashed", dbname, exc_info=True)
raise
finally:
if selector:
selector.close()
if conn:
database.put_connection(conn)
with cls._listener_lock[pid]:
if cls._listener.get((pid, dbname)) == current_thread:
del cls._listener[pid, dbname]
if config.get('cache', 'class'):
Cache = resolve(config.get('cache', 'class'))
else:
Cache = MemoryCache
class LRUDict(OrderedDict):
"""
Dictionary with a size limit.
If size limit is reached, it will remove the first added items.
The default_factory provides the same behavior as in standard
collections.defaultdict.
If default_factory_with_key is set, the default_factory is called with the
missing key.
"""
__slots__ = ('size_limit',)
def __init__(self, size_limit,
default_factory=None, default_factory_with_key=False,
*args, **kwargs):
assert size_limit > 0
self.size_limit = size_limit
super(LRUDict, self).__init__(*args, **kwargs)
self.default_factory = default_factory
self.default_factory_with_key = default_factory_with_key
self._check_size_limit()
def __setitem__(self, key, value):
super(LRUDict, self).__setitem__(key, value)
self._check_size_limit()
def __missing__(self, key):
if self.default_factory is None:
raise KeyError(key)
if self.default_factory_with_key:
value = self.default_factory(key)
else:
value = self.default_factory()
self[key] = value
return value
def update(self, *args, **kwargs):
super(LRUDict, self).update(*args, **kwargs)
self._check_size_limit()
def setdefault(self, key, default=None):
default = super(LRUDict, self).setdefault(key, default=default)
self._check_size_limit()
return default
def _check_size_limit(self):
while len(self) > self.size_limit:
self.popitem(last=False)
class LRUDictTransaction(LRUDict):
"""
Dictionary with a size limit and default_factory. (see LRUDict)
It is refreshed when transaction counter is changed.
"""
__slots__ = ('transaction', 'counter')
def __init__(self, *args, **kwargs):
super(LRUDictTransaction, self).__init__(*args, **kwargs)
self.transaction = Transaction()
self.counter = self.transaction.counter
def clear(self):
super(LRUDictTransaction, self).clear()
self.counter = self.transaction.counter
def refresh(self):
if self.counter != self.transaction.counter:
self.clear()

201
commandline.py Executable file
View File

@@ -0,0 +1,201 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import argparse
import csv
import logging
import logging.config
import logging.handlers
import os
import os.path
from contextlib import contextmanager
from io import StringIO
from trytond import __version__
logger = logging.getLogger(__name__)
def database_completer(parsed_args, **kwargs):
from trytond.config import config
from trytond.transaction import Transaction
config.update_etc(parsed_args.configfile)
with Transaction().start(
None, 0, readonly=True, close=True) as transaction:
return transaction.database.list()
def module_completer(**kwargs):
from trytond.modules import get_modules
return get_modules()
def language_completer(**kwargs):
files = os.listdir(os.path.join(os.path.dirname(__file__), 'ir', 'locale'))
return [os.path.splitext(f)[0] for f in files]
def get_base_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--version', action='version',
version='%(prog)s ' + __version__)
parser.add_argument("-c", "--config", dest="configfile", metavar='FILE',
nargs='+', default=[os.environ.get('TRYTOND_CONFIG')],
help="specify configuration files")
return parser
def get_parser():
parser = get_base_parser()
parser.add_argument("-v", "--verbose", action='count',
dest="verbose", default=0, help="increase verbosity")
parser.add_argument('--dev', dest='dev', action='store_true',
help='enable development mode')
logging_config = os.environ.get('TRYTOND_LOGGING_CONFIG')
db_names = os.environ.get('TRYTOND_DATABASE_NAMES')
if db_names:
db_names = list(next(csv.reader(StringIO(db_names))))
else:
db_names = []
parser.add_argument(
"-d", "--database", dest="database_names", nargs='+',
default=db_names, metavar='DATABASE',
help="specify the database names").completer = database_completer
parser.add_argument(
"--logconf", dest="logconf", default=logging_config, metavar='FILE',
help="set logging configuration file (ConfigParser format)")
return parser
def get_parser_daemon():
parser = get_parser()
parser.add_argument("--pidfile", dest="pidfile", metavar='FILE',
help="set file to store the process id")
parser.add_argument(
"--coroutine", action="store_true", dest="coroutine",
default=bool(os.environ.get('TRYTOND_COROUTINE', False)),
help="use coroutine for concurrency")
return parser
def get_parser_worker():
parser = get_parser_daemon()
parser.add_argument("--name", dest='name',
help="work only on the named queue")
parser.add_argument("-n", dest='processes', type=int,
help="set number of processes to use")
parser.add_argument("--max", dest='maxtasksperchild', type=int,
help="set number of tasks a worker process before being replaced")
parser.add_argument("-t", "--timeout", dest='timeout', default=60,
type=int, help="set maximum timeout when waiting notification")
return parser
def get_parser_cron():
parser = get_parser_daemon()
parser.add_argument("-1", "--once", dest='once', action='store_true',
help="run pending tasks and halt")
return parser
def get_parser_admin():
parser = get_parser()
parser.add_argument(
"-u", "--update", dest="update", nargs='+', default=[],
metavar='MODULE',
help="activate or update modules").completer = module_completer
parser.add_argument(
"--indexes", dest="indexes",
action=getattr(argparse, 'BooleanOptionalAction', 'store_true'),
default=None, help="update indexes")
parser.add_argument("--all", dest="update", action="append_const",
const="ir", help="update all activated modules")
parser.add_argument("--activate-dependencies", dest="activatedeps",
action="store_true",
help="activate missing dependencies of updated modules")
parser.add_argument("--email", dest="email", help="set the admin email")
parser.add_argument("-p", "--password", dest="password",
action='store_true', help="set the admin password")
parser.add_argument("--reset-password", dest='reset_password',
action='store_true', help="reset the admin password")
parser.add_argument("--test-email", dest='test_email',
help="send a test email to the specified address")
parser.add_argument("-m", "--update-modules-list", action="store_true",
dest="update_modules_list", help="update the list of tryton modules")
parser.add_argument(
"-l", "--language", dest="languages", nargs='+',
default=[], metavar='CODE',
help="load language translations").completer = language_completer
parser.add_argument("--hostname", dest="hostname", default=None,
help="limit database listing to the hostname")
parser.add_argument("--validate", dest="validate", nargs='*',
metavar='MODEL', help="validate records of models")
parser.add_argument("--validate-percentage", dest="validate_percentage",
type=float, default=100, metavar="PERCENTAGE",
help="percentage of records to validate (default: 100)")
parser.epilog = ('The first time a database is initialized '
'or when the password is set, the admin password is read '
'from file defined by TRYTONPASSFILE environment variable '
'or interactively asked from the user.\n'
'The config file can be specified in the TRYTOND_CONFIG '
'environment variable.\n'
'The database URI can be specified in the TRYTOND_DATABASE_URI '
'environment variable.')
return parser
def get_parser_console():
parser = get_base_parser()
parser.add_argument(
"-d", "--database", dest="database_name",
required=True, metavar='DATABASE',
help="specify the database name").completer = database_completer
parser.add_argument("--histsize", dest="histsize", type=int, default=500,
help="set the number of commands to remember in the command history")
parser.add_argument("--readonly", dest="readonly", action='store_true',
help="start a readonly transaction")
parser.add_argument(
"--lock-table", dest="lock_tables", nargs='+', default=[],
metavar='TABLE', help="lock tables")
parser.epilog = "To store changes, `transaction.commit()` must be called."
return parser
def get_parser_stat():
parser = get_base_parser()
parser.epilog = "To exit press 'q', to inverse sort order press 'r'."
return parser
def config_log(options):
if options.logconf:
logging.config.fileConfig(
options.logconf, disable_existing_loggers=False)
logging.getLogger('server').info('using %s as logging '
'configuration file', options.logconf)
else:
logformat = ('%(process)s %(thread)s [%(asctime)s] '
'%(levelname)s %(name)s %(message)s')
if not options.verbose and 'TRYTOND_LOGGING_LEVEL' in os.environ:
logging_level = int(os.environ['TRYTOND_LOGGING_LEVEL'])
level = max(logging_level, logging.NOTSET)
else:
level = max(logging.ERROR - options.verbose * 10, logging.NOTSET)
logging.basicConfig(level=level, format=logformat)
logging.captureWarnings(True)
@contextmanager
def pidfile(options):
path = options.pidfile
if not path:
yield
else:
with open(path, 'w') as fd:
fd.write('%d' % os.getpid())
yield
os.unlink(path)

167
config.py Executable file
View File

@@ -0,0 +1,167 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import configparser
import logging
import os
import urllib.parse
from getpass import getuser
__all__ = ['config', 'get_hostname', 'get_port', 'split_netloc',
'parse_listen', 'parse_uri']
logger = logging.getLogger(__name__)
# Needed so urlunsplit to always set netloc
for backend_name in ['postgresql', 'sqlite']:
if backend_name not in urllib.parse.uses_netloc:
urllib.parse.uses_netloc.append(backend_name)
def get_hostname(netloc):
if '[' in netloc and ']' in netloc:
return netloc.split(']')[0][1:]
elif ':' in netloc:
return netloc.split(':')[0]
else:
return netloc
def get_port(netloc):
netloc = netloc.split(']')[-1]
return int(netloc.split(':')[1])
def split_netloc(netloc):
return get_hostname(netloc).replace('*', ''), get_port(netloc)
def parse_listen(value):
for netloc in value.split(','):
yield split_netloc(netloc)
def parse_uri(uri):
return urllib.parse.urlparse(uri)
class TrytonConfigParser(configparser.ConfigParser):
def __init__(self):
super().__init__(interpolation=None)
self.add_section('web')
self.set('web', 'listen', 'localhost:8000')
self.set('web', 'root', os.path.join(os.path.expanduser('~'), 'www'))
self.set('web', 'num_proxies', '0')
self.set('web', 'cache_timeout', str(60 * 60 * 12))
self.add_section('database')
self.set('database', 'uri',
os.environ.get('TRYTOND_DATABASE_URI', 'sqlite://'))
self.set('database', 'path', os.path.join(
os.path.expanduser('~'), 'db'))
self.set('database', 'list', 'True')
self.set('database', 'retry', '5')
self.set('database', 'language', 'en')
self.set('database', 'timeout', str(30 * 60))
self.set('database', 'subquery_threshold', str(1_000))
self.add_section('request')
self.set('request', 'max_size', str(2 * 1024 * 1024))
self.set('request', 'max_size_authenticated',
str(2 * 1024 * 1024 * 1024))
self.set('request', 'timeout', str(60))
self.add_section('cache')
self.set('cache', 'transaction', '10')
self.set('cache', 'model', '200')
self.set('cache', 'record', '2000')
self.set('cache', 'field', '100')
self.set('cache', 'default', '1024')
self.set('cache', 'ir.message', '10240')
self.set('cache', 'ir.translation', '10240')
self.add_section('queue')
self.set('queue', 'worker', 'False')
self.add_section('ssl')
self.add_section('email')
self.set('email', 'uri', 'smtp://localhost:25')
self.set('email', 'from', getuser())
self.add_section('session')
self.set('session', 'authentications', 'password')
self.set('session', 'max_age', str(60 * 60 * 24 * 30))
self.set('session', 'timeout', str(60 * 5))
self.set('session', 'max_attempt', '5')
self.set('session', 'max_attempt_ip_network', '300')
self.set('session', 'ip_network_4', '32')
self.set('session', 'ip_network_6', '56')
self.add_section('password')
self.set('password', 'length', '8')
self.set('password', 'reset_timeout', str(24 * 60 * 60))
self.add_section('bus')
self.set('bus', 'allow_subscribe', 'False')
self.set('bus', 'long_polling_timeout', str(5 * 60))
self.set('bus', 'cache_timeout', '5')
self.set('bus', 'select_timeout', '5')
self.add_section('html')
self.update_environ()
self.update_etc()
def update_environ(self):
for key, value in os.environ.items():
if not key.startswith('TRYTOND_'):
continue
try:
section, option = key[len('TRYTOND_'):].lower().split('__', 1)
except ValueError:
continue
if section.startswith('wsgi_'):
section = section.replace('wsgi_', 'wsgi ')
if not self.has_section(section):
self.add_section(section)
self.set(section, option, value)
def update_etc(self, configfile=os.environ.get('TRYTOND_CONFIG')):
if isinstance(configfile, str):
configfile = [configfile]
if not configfile or not [_f for _f in configfile if _f]:
return []
configfile = [os.path.expanduser(filename) for filename in configfile]
read_files = self.read(configfile)
logger.info('using %s as configuration files', ', '.join(read_files))
if configfile != read_files:
logger.error('could not load %s',
','.join(set(configfile) - set(read_files)))
return configfile
def get(self, section, option, *args, **kwargs):
default = kwargs.pop('default', None)
try:
return configparser.RawConfigParser.get(self, section, option,
*args, **kwargs)
except (configparser.NoOptionError, configparser.NoSectionError):
return default
def getint(self, section, option, *args, **kwargs):
default = kwargs.pop('default', None)
try:
return configparser.RawConfigParser.getint(self, section, option,
*args, **kwargs)
except (configparser.NoOptionError, configparser.NoSectionError,
TypeError):
return default
def getfloat(self, section, option, *args, **kwargs):
default = kwargs.pop('default', None)
try:
return configparser.RawConfigParser.getfloat(self, section, option,
*args, **kwargs)
except (configparser.NoOptionError, configparser.NoSectionError,
TypeError):
return default
def getboolean(self, section, option, *args, **kwargs):
default = kwargs.pop('default', None)
try:
return configparser.RawConfigParser.getboolean(
self, section, option, *args, **kwargs)
except (configparser.NoOptionError, configparser.NoSectionError,
AttributeError):
return default
config = TrytonConfigParser()

66
console.py Executable file
View File

@@ -0,0 +1,66 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import atexit
import os
import readline
import sys
from code import InteractiveConsole
from rlcompleter import Completer
from trytond import __version__
from trytond.pool import Pool
from trytond.transaction import Transaction
from trytond.worker import run_task
class Console(InteractiveConsole):
def __init__(self, locals=None, filename="<console>", histsize=-1,
histfile=os.path.expanduser("~/.trytond_console_history")):
super().__init__(locals, filename)
self.init_completer(locals)
self.init_history(histfile, histsize)
def init_completer(selfi, locals):
completer = Completer(locals)
readline.set_completer(completer.complete)
readline.parse_and_bind("tab: complete")
def init_history(self, histfile, histsize):
readline.parse_and_bind("tab: complete")
if hasattr(readline, 'read_history_file'):
try:
readline.read_history_file(histfile)
except FileNotFoundError:
pass
atexit.register(self.save_history, histfile, histsize)
def save_history(self, histfile, histsize):
readline.set_history_length(histsize)
readline.write_history_file(histfile)
def run(options):
db_name = options.database_name
pool = Pool(db_name)
with Transaction().start(db_name, 0, readonly=True):
pool.init()
with Transaction().start(
db_name, 0, readonly=options.readonly,
_lock_tables=options.lock_tables) as transaction:
local = {
'pool': pool,
'transaction': transaction,
}
if sys.stdin.isatty():
console = Console(local, histsize=options.histsize)
banner = "Tryton %s, Python %s on %s" % (
__version__, sys.version, sys.platform)
console.interact(banner=banner, exitmsg='')
else:
console = InteractiveConsole(local)
console.runcode(sys.stdin.read())
transaction.rollback()
while transaction.tasks:
task_id = transaction.tasks.pop()
run_task(pool, task_id)

23
const.py Executable file
View File

@@ -0,0 +1,23 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
OPERATORS = (
'where',
'not where',
'child_of',
'not child_of',
'parent_of',
'not parent_of',
'=',
'!=',
'like',
'not like',
'ilike',
'not ilike',
'in',
'not in',
'<=',
'>=',
'<',
'>',
)
MODULES_GROUP = 'trytond.modules'

844
convert.py Executable file
View File

@@ -0,0 +1,844 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import datetime
import logging
import re
import time
from collections import defaultdict
from decimal import Decimal
from xml import sax
from trytond import __version__
from trytond.pyson import CONTEXT, PYSONEncoder
from trytond.tools import grouped_slice
from trytond.transaction import Transaction, inactive_records
logger = logging.getLogger(__name__)
CDATA_START = re.compile(r'^\s*\<\!\[cdata\[', re.IGNORECASE)
CDATA_END = re.compile(r'\]\]\>\s*$', re.IGNORECASE)
class ParsingError(Exception):
pass
class DummyTagHandler:
"""Dubhandler implementing empty methods. Will be used when whe
want to ignore the xml content"""
def __init__(self):
pass
def startElement(self, name, attributes):
pass
def characters(self, data):
pass
def endElement(self, name):
pass
class MenuitemTagHandler:
"""Taghandler for the tag <record> """
def __init__(self, master_handler):
self.mh = master_handler
self.xml_id = None
def startElement(self, name, attributes):
cursor = Transaction().connection.cursor()
values = {}
try:
self.xml_id = attributes['id']
except KeyError:
self.xml_id = None
raise ParsingError("missing 'id' attribute")
for attr in ('name', 'sequence', 'parent', 'action', 'groups'):
if attr in attributes:
values[attr] = attributes.get(attr)
values['icon'] = attributes.get('icon', 'tryton-folder')
if attributes.get('active'):
values['active'] = bool(eval(attributes['active']))
if values.get('parent'):
model, id_ = self.mh.get_id(values['parent'])
if model != 'ir.ui.menu':
raise ParsingError(
"invalid 'ir.ui.menu' parent: %s" % model)
values['parent'] = id_
action_name = None
if values.get('action'):
model, action_id = self.mh.get_id(values['action'])
if not model.startswith('ir.action'):
raise ParsingError(
"invalid model for action: %s" % model)
# TODO maybe use a prefetch for this:
action = self.mh.pool.get('ir.action').__table__()
report = self.mh.pool.get('ir.action.report').__table__()
act_window = self.mh.pool.get('ir.action.act_window').__table__()
wizard = self.mh.pool.get('ir.action.wizard').__table__()
url = self.mh.pool.get('ir.action.url').__table__()
act_window_view = self.mh.pool.get(
'ir.action.act_window.view').__table__()
view = self.mh.pool.get('ir.ui.view').__table__()
icon = self.mh.pool.get('ir.ui.icon').__table__()
cursor.execute(*action.join(
report, 'LEFT',
condition=action.id == report.action
).join(act_window, 'LEFT',
condition=action.id == act_window.action
).join(wizard, 'LEFT',
condition=action.id == wizard.action
).join(url, 'LEFT',
condition=action.id == url.action
).join(act_window_view, 'LEFT',
condition=act_window.id == act_window_view.act_window
).join(view, 'LEFT',
condition=view.id == act_window_view.view
).join(icon, 'LEFT',
condition=action.icon == icon.id).select(
action.name.as_('action_name'),
action.type.as_('action_type'),
view.type.as_('view_type'),
view.field_childs.as_('field_childs'),
icon.name.as_('icon_name'),
where=(report.id == action_id)
| (act_window.id == action_id)
| (wizard.id == action_id)
| (url.id == action_id),
order_by=act_window_view.sequence, limit=1))
action_name, action_type, view_type, field_childs, icon_name = \
cursor.fetchone()
values['action'] = '%s,%s' % (action_type, action_id)
icon = attributes.get('icon', '')
if icon:
values['icon'] = icon
elif icon_name:
values['icon'] = icon_name
elif action_type == 'ir.action.wizard':
values['icon'] = 'tryton-launch'
elif action_type == 'ir.action.report':
values['icon'] = 'tryton-print'
elif action_type == 'ir.action.act_window':
if view_type == 'tree':
if field_childs:
values['icon'] = 'tryton-tree'
else:
values['icon'] = 'tryton-list'
elif view_type == 'form':
values['icon'] = 'tryton-form'
elif view_type == 'graph':
values['icon'] = 'tryton-graph'
elif view_type == 'calendar':
values['icon'] = 'tryton-calendar'
elif action_type == 'ir.action.url':
values['icon'] = 'tryton-public'
else:
values['icon'] = None
if values.get('groups'):
raise ParsingError("forbidden 'groups' attribute")
if not values.get('name'):
if not action_name:
raise ParsingError("missing 'name' or 'action' attribute")
else:
values['name'] = action_name
if values.get('sequence'):
values['sequence'] = int(values['sequence'])
self.values = values
def characters(self, data):
pass
def endElement(self, name):
"""Must return the object to use for the next call """
if name != "menuitem":
return self
else:
self.mh.import_record('ir.ui.menu', self.values, self.xml_id)
return None
def current_state(self):
return "menuitem '%s.%s'" % (self.mh.module, self.xml_id)
class RecordTagHandler:
"""Taghandler for the tag <record> and all the tags inside it"""
def __init__(self, master_handler):
# Remind reference of parent handler
self.mh = master_handler
# stock xml_id parsed in one module
self.xml_ids = []
self.model = None
self.xml_id = None
self.update = None
self.values = None
self.current_field = None
self.cdata = None
self.start_cdata = None
def startElement(self, name, attributes):
# Manage the top level tag
if name == "record":
try:
self.xml_id = attributes["id"]
except KeyError:
self.xml_id = None
raise ParsingError("missing 'id' attribute")
self.model = self.mh.pool.get(attributes["model"])
self.update = bool(int(attributes.get('update', '0')))
# create/update a dict containing fields values
self.values = {}
self.current_field = None
self.cdata = False
return self.xml_id
# Manage included tags:
elif name == "field":
field_name = attributes['name']
field_type = attributes.get('type', '')
# Remind the current name and if we have to load (see characters)
self.current_field = field_name
depends = attributes.get('depends', '').split(',')
depends = {m.strip() for m in depends if m}
if not depends.issubset(self.mh.modules):
self.current_field = None
return
# Create a new entry in the values
self.values[field_name] = ""
# Put a flag to escape cdata tags
if field_type == "xml":
self.cdata = "start"
# Catch the known attributes
search_attr = attributes.get('search', '')
ref_attr = attributes.get('ref', '')
eval_attr = attributes.get('eval', '')
pyson_attr = bool(int(attributes.get('pyson', '0')))
context = {}
context['time'] = time
context['version'] = __version__.rsplit('.', 1)[0]
context['ref'] = lambda xml_id: ','.join(self.mh.get_id(xml_id))
context['Decimal'] = Decimal
context['datetime'] = datetime
if pyson_attr:
context.update(CONTEXT)
field = self.model._fields[field_name]
if search_attr:
search_model = field.model_name
SearchModel = self.mh.pool.get(search_model)
with inactive_records():
found, = SearchModel.search(eval(search_attr, context))
self.values[field_name] = found.id
elif ref_attr:
model, id_ = self.mh.get_id(ref_attr)
if field._type == 'reference':
self.values[field_name] = '%s,%s' % (model, id_)
else:
if (field.model_name == 'ir.action'
and model.startswith('ir.action')):
pass
elif model != field.model_name:
raise ParsingError(
"invalid model for %s: %s" % (field_name, model))
self.values[field_name] = id_
elif eval_attr:
value = eval(eval_attr, context)
if pyson_attr:
value = PYSONEncoder(sort_keys=True).encode(value)
self.values[field_name] = value
else:
raise ParsingError(
"forbidden '%s' tag inside record tag" % name)
def characters(self, data):
"""If we are in a field tag, consume all the content"""
if not self.current_field:
return
# Escape start cdata tag if necessary
if self.cdata == "start":
data = CDATA_START.sub('', data)
self.start_cdata = "inside"
self.values[self.current_field] += data
def endElement(self, name):
"""Must return the object to use for the next call, if name is
not 'record' we return self to keep our hand on the
process. If name is 'record' we return None to end the
delegation"""
if name == "field":
if not self.current_field:
return self
# Escape end cdata tag :
if self.cdata in ('inside', 'start'):
self.values[self.current_field] = \
CDATA_END.sub('', self.values[self.current_field])
self.cdata = 'done'
self.current_field = None
return self
elif name == "record":
if self.xml_id in self.xml_ids and not self.update:
raise ParsingError("duplicate id: %s" % self.xml_id)
self.mh.import_record(
self.model.__name__, self.values, self.xml_id)
self.xml_ids.append(self.xml_id)
return None
else:
raise ParsingError("unexpected closing tag '%s'" % name)
def current_state(self):
return "record '%s.%s'" % (self.mh.module, self.xml_id)
class Fs2bdAccessor:
"""
Used in TrytondXmlHandler.
Provide some helper function to ease cache access and management.
"""
def __init__(self, ModelData, pool):
self.fs2db = {}
self.fetched_modules = []
self.ModelData = ModelData
self.browserecord = {}
self.pool = pool
def get(self, module, fs_id):
if module not in self.fetched_modules:
self.fetch_new_module(module)
return self.fs2db[module].get(fs_id, None)
def exists(self, module, fs_id):
if module not in self.fetched_modules:
self.fetch_new_module(module)
return fs_id in self.fs2db[module]
def get_browserecord(self, module, model_name, db_id):
if module not in self.fetched_modules:
self.fetch_new_module(module)
if model_name in self.browserecord[module] \
and db_id in self.browserecord[module][model_name]:
return self.browserecord[module][model_name][db_id]
return None
def set(self, module, fs_id, values):
"""
Whe call the prefetch function here to. Like that whe are sure
not to erase data when get is called.
"""
if module not in self.fetched_modules:
self.fetch_new_module(module)
if fs_id not in self.fs2db[module]:
self.fs2db[module][fs_id] = {}
fs2db_val = self.fs2db[module][fs_id]
for key, val in values.items():
fs2db_val[key] = val
def reset_browsercord(self, module, model_name, ids=None):
if module not in self.fetched_modules:
return
self.browserecord[module].setdefault(model_name, {})
Model = self.pool.get(model_name)
if not ids:
ids = list(self.browserecord[module][model_name].keys())
with Transaction().set_context(language='en'):
models = Model.browse(ids)
for model in models:
if model.id in self.browserecord[module][model_name]:
for cache in Transaction().cache.values():
if model_name in cache:
cache[model_name].pop(model.id, None)
self.browserecord[module][model_name][model.id] = model
def fetch_new_module(self, module):
self.fs2db[module] = {}
module_data_ids = self.ModelData.search([
('module', '=', module),
], order=[('db_id', 'ASC')])
record_ids = {}
for rec in self.ModelData.browse(module_data_ids):
self.fs2db[rec.module][rec.fs_id] = {
"db_id": rec.db_id, "model": rec.model,
"id": rec.id, "values": rec.values
}
record_ids.setdefault(rec.model, [])
record_ids[rec.model].append(rec.db_id)
self.browserecord[module] = {}
for model_name in record_ids.keys():
try:
Model = self.pool.get(model_name)
except KeyError:
continue
self.browserecord[module][model_name] = {}
for sub_record_ids in grouped_slice(record_ids[model_name]):
with inactive_records():
records = Model.search([
('id', 'in', list(sub_record_ids)),
], order=[('id', 'ASC')])
with Transaction().set_context(language='en'):
models = Model.browse(list(map(int, records)))
for model in models:
self.browserecord[module][model_name][model.id] = model
self.fetched_modules.append(module)
class TrytondXmlHandler(sax.handler.ContentHandler):
def __init__(self, pool, module, module_state, modules, languages):
"Register known taghandlers, and managed tags."
sax.handler.ContentHandler.__init__(self)
self.pool = pool
self.module = module
self.ModelData = pool.get('ir.model.data')
self.fs2db = Fs2bdAccessor(self.ModelData, pool)
self.to_delete = self.populate_to_delete()
self.noupdate = None
self.module_state = module_state
self.grouped = None
self.grouped_creations = defaultdict(dict)
self.grouped_write = defaultdict(list)
self.grouped_model_data = []
self.skip_data = False
self.modules = modules
self.languages = languages
# Tag handlders are used to delegate the processing
self.taghandlerlist = {
'record': RecordTagHandler(self),
'menuitem': MenuitemTagHandler(self),
}
self.taghandler = None
# Managed tags are handled by the current class
self.managedtags = ["data", "tryton"]
# Connect to the sax api:
self.sax_parser = sax.make_parser()
# Tell the parser we are not interested in XML namespaces
self.sax_parser.setFeature(sax.handler.feature_namespaces, 0)
self.sax_parser.setContentHandler(self)
def parse_xmlstream(self, stream):
"""
Take a byte stream has input and parse the xml content.
"""
source = sax.InputSource()
source.setByteStream(stream)
try:
self.sax_parser.parse(source)
except Exception as e:
raise ParsingError("in %s" % self.current_state()) from e
return self.to_delete
def startElement(self, name, attributes):
"""Rebind the current handler if necessary and call
startElement on it"""
if not self.taghandler:
if name in self.taghandlerlist:
self.taghandler = self.taghandlerlist[name]
elif name == "data":
self.noupdate = bool(int(attributes.get("noupdate", '0')))
self.grouped = bool(int(attributes.get('grouped', 0)))
self.skip_data = False
depends = attributes.get('depends', '').split(',')
depends = {m.strip() for m in depends if m}
if not depends.issubset(self.modules):
self.skip_data = True
if (attributes.get('language')
and attributes.get('language') not in self.languages):
self.skip_data = True
elif name == "tryton":
pass
else:
logger.info("Tag %s not supported", (name,))
return
if self.taghandler and not self.skip_data:
self.taghandler.startElement(name, attributes)
def characters(self, data):
if self.taghandler:
self.taghandler.characters(data)
def endElement(self, name):
if name == 'data' and self.grouped:
for model, values in self.grouped_creations.items():
self.create_records(model, values.values(), values.keys())
self.grouped_creations.clear()
for key, actions in self.grouped_write.items():
module, model = key
self.write_records(module, model, *actions)
self.grouped_write.clear()
if name == 'data' and self.grouped_model_data:
self.ModelData.write(*self.grouped_model_data)
del self.grouped_model_data[:]
# Closing tag found, if we are in a delegation the handler
# know what to do:
if self.taghandler and not self.skip_data:
self.taghandler = self.taghandler.endElement(name)
if self.taghandler == self.taghandlerlist.get(name):
self.taghandler = None
def current_state(self):
if self.taghandler:
return self.taghandler.current_state()
else:
return '?'
def get_id(self, xml_id):
if '.' in xml_id:
module, xml_id = xml_id.split('.')
else:
module = self.module
if self.fs2db.get(module, xml_id) is None:
raise ParsingError("%s.%s not found" % (module, xml_id))
value = self.fs2db.get(module, xml_id)
return value['model'], value["db_id"]
@staticmethod
def _clean_value(key, record):
"""
Take a field name, a browse_record, and a reference to the
corresponding object. Return a raw value has it must look on the
db.
"""
Model = record.__class__
# search the field type in the object or in a parent
field_type = Model._fields[key]._type
# handle the value regarding to the type
if field_type == 'many2one':
return getattr(record, key).id if getattr(record, key) else None
elif field_type == 'reference':
if not getattr(record, key):
return None
return str(getattr(record, key))
elif field_type in ['one2many', 'many2many']:
raise ParsingError(
"unsupported field %s of type %s" % (key, field_type))
else:
return getattr(record, key)
def populate_to_delete(self):
"""Create a list of all the records that whe should met in the update
process. The records that are not encountered are deleted from the
database in post_import."""
# Fetch the data in id descending order to avoid depedendcy
# problem when the corresponding recordds will be deleted:
module_data = self.ModelData.search([
('module', '=', self.module),
], order=[('id', 'DESC')])
return set(rec.fs_id for rec in module_data)
def import_record(self, model, values, fs_id):
module = self.module
if not fs_id:
raise ValueError("missing fs_id")
if '.' in fs_id:
assert len(fs_id.split('.')) == 2, ('"%s" contains too many dots. '
'file system ids should contain ot most one dot ! '
'These are used to refer to other modules data, '
'as in module.reference_id' % (fs_id))
module, fs_id = fs_id.split('.')
if not self.fs2db.get(module, fs_id):
raise ParsingError("%s.%s not found" % (module, fs_id))
Model = self.pool.get(model)
if self.fs2db.exists(module, fs_id):
# Remove this record from the to_delete list. This means that
# the corresponding record have been found.
if module == self.module and fs_id in self.to_delete:
self.to_delete.remove(fs_id)
if self.noupdate and self.module_state != 'to activate':
return
# this record is already in the db:
db_value = self.fs2db.get(module, fs_id)
db_id = db_value['db_id']
db_model = db_value['model']
mdata_id = db_value['id']
old_values = db_value['values']
# Check if record has not been deleted
if db_id is None:
return
if not old_values:
old_values = {}
else:
old_values = self.ModelData.load_values(old_values)
for key in old_values:
if isinstance(old_values[key], bytes):
# Fix for migration to unicode
old_values[key] = old_values[key].decode('utf-8')
if model != db_model:
raise ParsingError(
"wrong model '%s': %s.%s" % (model, module, fs_id))
record = self.fs2db.get_browserecord(module, Model.__name__, db_id)
# Re-create record if it was deleted
if not record:
with Transaction().set_context(
module=module, language='en'):
record, = Model.create([values])
# reset_browsercord
self.fs2db.reset_browsercord(
module, Model.__name__, [record.id])
record = self.fs2db.get_browserecord(
module, Model.__name__, record.id)
data = self.ModelData.search([
('fs_id', '=', fs_id),
('module', '=', module),
('model', '=', Model.__name__),
], limit=1)
self.ModelData.write(data, {
'db_id': record.id,
})
self.fs2db.get(module, fs_id)["db_id"] = record.id
to_update = {}
for key in values:
db_field = self._clean_value(key, record)
# if the fs value is the same as in the db, we ignore it
if db_field == values[key]:
continue
# we cannot update a field if it was changed by a user...
if key not in old_values:
expected_value = Model._defaults.get(key,
lambda *a: None)()
else:
expected_value = old_values[key]
# ... and we consider that there is an update if the
# expected value differs from the actual value, _and_
# if they are not false in a boolean context (ie None,
# False, {} or [])
if db_field != expected_value and (db_field or expected_value):
logger.warning(
"Field %s of %s@%s not updated (id: %s), because "
"it has changed since the last update",
key, record.id, model, fs_id)
continue
# so, the field in the fs and in the db are different,
# and no user changed the value in the db:
to_update[key] = values[key]
if self.grouped:
self.grouped_write[(module, model)].extend(
(record, to_update, old_values, values, fs_id, mdata_id))
else:
self.write_records(module, model,
record, to_update, old_values, values, fs_id, mdata_id)
else:
if self.grouped:
self.grouped_creations[model][fs_id] = values
else:
self.create_records(model, [values], [fs_id])
def create_records(self, model, vlist, fs_ids):
Model = self.pool.get(model)
with Transaction().set_context(module=self.module, language='en'):
records = Model.create(vlist)
mdata_values = []
for record, values, fs_id in zip(records, vlist, fs_ids):
for key in values:
values[key] = self._clean_value(key, record)
mdata_values.append({
'fs_id': fs_id,
'model': model,
'module': self.module,
'db_id': record.id,
'values': self.ModelData.dump_values(values),
'fs_values': self.ModelData.dump_values(values),
'noupdate': self.noupdate,
})
models_data = self.ModelData.create(mdata_values)
for record, values, fs_id, mdata in zip(
records, vlist, fs_ids, models_data):
self.fs2db.set(self.module, fs_id, {
'db_id': record.id,
'model': model,
'id': mdata.id,
'values': self.ModelData.dump_values(values),
})
self.fs2db.reset_browsercord(self.module, model,
[r.id for r in records])
def write_records(self, module, model,
record, values, old_values, new_values, fs_id, mdata_id, *args):
args = (record, values, old_values, new_values, fs_id, mdata_id) + args
Model = self.pool.get(model)
actions = iter(args)
to_update = []
for record, values, _, _, _, _ in zip(*((actions,) * 6)):
if values:
to_update += [[record], values]
# if there is values to update:
if to_update:
# write the values in the db:
with Transaction().set_context(
module=module, language='en'):
Model.write(*to_update)
self.fs2db.reset_browsercord(
module, Model.__name__, sum(to_update[::2], []))
actions = iter(to_update)
for records, values in zip(actions, actions):
record, = records
# re-read it: this ensure that we store the real value
# in the model_data table:
record = self.fs2db.get_browserecord(
module, Model.__name__, record.id)
if not record:
with Transaction().set_context(language='en'):
record = Model(record.id)
for key in values:
values[key] = self._clean_value(key, record)
actions = iter(args)
for record, values, old_values, new_values, fs_id, mdata_id in zip(
*((actions,) * 6)):
temp_values = old_values.copy()
temp_values.update(values)
values = temp_values
fs_values = old_values.copy()
fs_values.update(new_values)
if old_values != values or values != fs_values:
self.grouped_model_data.extend(([self.ModelData(mdata_id)], {
'fs_id': fs_id,
'model': model,
'module': module,
'db_id': record.id,
'values': self.ModelData.dump_values(values),
'fs_values': self.ModelData.dump_values(fs_values),
'noupdate': self.noupdate,
}))
# reset_browsercord to keep cache memory low
self.fs2db.reset_browsercord(module, Model.__name__, args[::6])
def post_import(pool, module, to_delete):
"""
Remove the records that are given in to_delete.
"""
transaction = Transaction()
mdata_delete = []
ModelData = pool.get("ir.model.data")
with inactive_records():
mdata = ModelData.search([
('fs_id', 'in', to_delete),
('module', '=', module),
], order=[('id', 'DESC')])
for mrec in mdata:
model, db_id, fs_id = mrec.model, mrec.db_id, mrec.fs_id
try:
# Deletion of the record
try:
Model = pool.get(model)
except KeyError:
Model = None
if Model:
Model.delete([Model(db_id)])
mdata_delete.append(mrec)
else:
logger.warning(
"could not delete %d@%s from %s.%s "
"because model no longer exists",
db_id, model, module, fs_id)
except Exception as e:
transaction.rollback()
logger.warning(
"could not delete %d@%s from %s.%s (%s).",
db_id, model, module, fs_id, e)
if 'active' in Model._fields:
try:
Model.write([Model(db_id)], {
'active': False,
})
except Exception as e:
transaction.rollback()
logger.error(
"could not deactivate %d@%s from %s.%s (%s)",
db_id, model, module, fs_id, e)
else:
logger.info(
"deleted %s@%s from %s.%s", db_id, model, module, fs_id)
transaction.commit()
# Clean model_data:
if mdata_delete:
ModelData.delete(mdata_delete)
transaction.commit()
return True

47
cron.py Executable file
View File

@@ -0,0 +1,47 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import logging
import threading
import time
from trytond.pool import Pool
from trytond.transaction import Transaction
__all__ = ['run']
logger = logging.getLogger(__name__)
def run(options):
threads = []
for name in options.database_names:
thread = threading.Thread(target=lambda: Pool(name).init())
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
threads = {}
while True:
for db_name in options.database_names:
thread = threads.get(db_name)
if thread and thread.is_alive():
logger.info(
'skip "%s" as previous cron still running', db_name)
continue
database_list = Pool.database_list()
pool = Pool(db_name)
if db_name not in database_list:
with Transaction().start(db_name, 0, readonly=True):
pool.init()
Cron = pool.get('ir.cron')
thread = threading.Thread(
target=Cron.run,
args=(db_name,), kwargs={})
logger.info('start thread for "%s"', db_name)
thread.start()
threads[db_name] = thread
if options.once:
break
time.sleep(60)
for thread in threads.values():
thread.join()

73
exceptions.py Executable file
View File

@@ -0,0 +1,73 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
class TrytonException(Exception):
pass
class UserError(TrytonException):
def __init__(self, message, description='', domain=None):
super().__init__('UserError', (message, description, domain))
self.message = message
self.description = description
self.domain = domain
self.code = 1
def __str__(self):
return '%s - %s' % (self.message, self.description)
class UserWarning(TrytonException):
"Exception that will be displayed as a warning message in the client."
def __init__(self, name, message, description=''):
super(UserWarning, self).__init__('UserWarning', (name, message,
description))
self.name = name
self.message = message
self.description = description
self.code = 2
def __str__(self):
return '%s - %s' % (self.message, self.description)
class LoginException(TrytonException):
"""Request the named parameter for the login process.
The type can be 'password' or 'char'.
"""
def __init__(self, name, message, type='password'):
super(LoginException, self).__init__(
'LoginException', (name, message, type))
self.name = name
self.message = message
self.type = type
self.code = 3
class ConcurrencyException(TrytonException):
def __init__(self, message):
super(ConcurrencyException, self).__init__('ConcurrencyException',
message)
self.message = message
self.code = 4
def __str__(self):
return self.message
class RateLimitException(TrytonException):
"""User has sent too many requests in a given amount of time."""
class MissingDependenciesException(TrytonException):
def __init__(self, missings):
self.missings = missings
def __str__(self):
return 'Missing dependencies: %s' % ' '.join(self.missings)

67
filestore.py Executable file
View File

@@ -0,0 +1,67 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import hashlib
import os
from trytond.config import config
from trytond.tools import resolve
__all__ = ['filestore']
class FileStore(object):
def get(self, id, prefix=''):
filename = self._filename(id, prefix)
with open(filename, 'rb') as fp:
return fp.read()
def getmany(self, ids, prefix=''):
return [self.get(id, prefix) for id in ids]
def size(self, id, prefix=''):
filename = self._filename(id, prefix)
statinfo = os.stat(filename)
return statinfo.st_size
def sizemany(self, ids, prefix=''):
return [self.size(id, prefix) for id in ids]
def set(self, data, prefix=''):
id = self._id(data)
filename = self._filename(id, prefix)
dirname = os.path.dirname(filename)
os.makedirs(dirname, mode=0o770, exist_ok=True)
collision = 0
while True:
basename = os.path.basename(filename)
if os.path.exists(filename):
if data != self.get(basename, prefix):
collision += 1
filename = self._filename(
'%s-%s' % (id, collision), prefix)
continue
else:
with open(filename, 'wb')as fp:
fp.write(data)
return basename
def setmany(self, data, prefix=''):
return [self.set(d, prefix) for d in data]
def _filename(self, id, prefix):
path = os.path.normpath(config.get('database', 'path'))
filename = os.path.join(path, prefix, id[0:2], id[2:4], id)
filename = os.path.normpath(filename)
if not filename.startswith(path):
raise ValueError('Bad prefix')
return filename
def _id(self, data):
return hashlib.md5(data).hexdigest()
if config.get('database', 'class'):
FileStore = resolve(config.get('database', 'class')) # noqa: F811
filestore = FileStore()

37
i18n.py Executable file
View File

@@ -0,0 +1,37 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
from trytond.pool import Pool
from trytond.tools.string_ import LazyString
from trytond.transaction import Transaction
def gettext(message_id, *args, **variables):
"Returns the message translated into language"
if not Transaction().database:
return message_id
pool = Pool()
try:
Message = pool.get('ir.message')
except KeyError:
return message_id
if not args:
language = Transaction().language
else:
language, = args
try:
module, id_ = message_id.split('.')
except ValueError:
if pool.test:
raise
return message_id
try:
return Message.gettext(module, id_, language, **variables)
except (KeyError, ValueError):
if pool.test:
raise
return message_id
def lazy_gettext(message_id, *args, **variables):
"Like gettext but the string returned is lazy"
return LazyString(gettext, message_id, *args, **variables)

108
ir/__init__.py Executable file
View File

@@ -0,0 +1,108 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
from trytond.pool import Pool
from . import (
action, attachment, avatar, cache, calendar_, configuration, cron, date,
email_, error, export, lang, message, model, module, note, queue_, routes,
rule, sequence, session, translation, trigger, ui)
__all__ = ['register', 'routes']
def register():
Pool.register(
model.ModelField, # register first for model char migration
configuration.Configuration,
translation.Translation,
translation.TranslationSetStart,
translation.TranslationSetSucceed,
translation.TranslationCleanStart,
translation.TranslationCleanSucceed,
translation.TranslationUpdateStart,
translation.TranslationExportStart,
translation.TranslationExportResult,
sequence.SequenceType,
sequence.Sequence,
sequence.SequenceStrict,
ui.menu.UIMenu,
ui.menu.UIMenuFavorite,
ui.view.View,
ui.view.ShowViewStart,
ui.view.ViewTreeWidth,
ui.view.ViewTreeOptional,
ui.view.ViewTreeState,
ui.view.ViewSearch,
ui.icon.Icon,
action.Action,
action.ActionKeyword,
action.ActionReport,
action.ActionActWindow,
action.ActionActWindowView,
action.ActionActWindowDomain,
action.ActionWizard,
action.ActionURL,
model.Model,
model.ModelAccess,
model.ModelFieldAccess,
model.ModelButton,
model.ModelButtonRule,
model.ModelButtonClick,
model.ModelButtonReset,
model.ModelData,
model.Log,
model.PrintModelGraphStart,
attachment.Attachment,
note.Note,
note.NoteRead,
avatar.Avatar,
avatar.AvatarCache,
cron.Cron,
lang.Lang,
lang.LangConfigStart,
export.Export,
export.ExportLine,
rule.RuleGroup,
rule.Rule,
module.Module,
module.ModuleDependency,
module.ModuleConfigWizardItem,
module.ModuleConfigWizardFirst,
module.ModuleConfigWizardOther,
module.ModuleConfigWizardDone,
module.ModuleActivateUpgradeStart,
module.ModuleActivateUpgradeDone,
module.ModuleConfigStart,
cache.Cache,
date.Date,
trigger.Trigger,
trigger.TriggerLog,
session.Session,
session.SessionWizard,
queue_.Queue,
calendar_.Month,
calendar_.Day,
message.Message,
email_.Email,
email_.EmailAddress,
email_.EmailTemplate,
email_.EmailTemplate_Report,
error.Error,
module='ir', type_='model')
Pool.register(
translation.TranslationSet,
translation.TranslationClean,
translation.TranslationUpdate,
translation.TranslationExport,
translation.TranslationReport,
ui.view.ShowView,
model.PrintModelGraph,
module.ModuleConfigWizard,
module.ModuleActivateUpgrade,
module.ModuleConfig,
lang.LangConfig,
module='ir', type_='wizard')
Pool.register(
model.ModelGraph,
model.ModelWorkflowGraph,
module='ir', type_='report')

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Some files were not shown because too many files have changed in this diff Show More