# This file is part of Tryton. The COPYRIGHT file at the top level of # this repository contains the full copyright notices and license terms. import json import logging import os import time import warnings from collections import defaultdict from datetime import datetime from decimal import Decimal from itertools import chain, repeat from threading import RLock from psycopg2 import Binary, connect from psycopg2.extensions import ( ISOLATION_LEVEL_REPEATABLE_READ, UNICODE, AsIs, cursor, register_adapter, register_type) from psycopg2.pool import PoolError, ThreadedConnectionPool from psycopg2.sql import SQL, Identifier try: from psycopg2.extensions import PYDATE, PYDATETIME, PYINTERVAL, PYTIME except ImportError: PYDATE, PYDATETIME, PYTIME, PYINTERVAL = None, None, None, None from psycopg2 import DataError as DatabaseDataError from psycopg2 import IntegrityError as DatabaseIntegrityError from psycopg2 import InterfaceError from psycopg2 import OperationalError as DatabaseOperationalError from psycopg2 import ProgrammingError from psycopg2.errors import QueryCanceled as DatabaseTimeoutError from psycopg2.extras import register_default_json, register_default_jsonb from sql import Cast, Flavor, For, Table from sql.conditionals import Coalesce from sql.functions import Function from sql.operators import BinaryOperator, Concat from trytond.backend.database import DatabaseInterface, SQLType from trytond.config import config, parse_uri from trytond.tools.gevent import is_gevent_monkey_patched __all__ = [ 'Database', 'DatabaseIntegrityError', 'DatabaseDataError', 'DatabaseOperationalError', 'DatabaseTimeoutError'] logger = logging.getLogger(__name__) os.environ['PGTZ'] = os.environ.get('TZ', '') _timeout = config.getint('database', 'timeout') _minconn = config.getint('database', 'minconn', default=1) _maxconn = config.getint('database', 'maxconn', default=64) _default_name = config.get('database', 'default_name', default='template1') def unescape_quote(s): if s.startswith('"') and s.endswith('"'): return s.strip('"').replace('""', '"') return s def replace_special_values(s, **mapping): for name, value in mapping.items(): s = s.replace('$' + name, value) return s class LoggingCursor(cursor): def execute(self, sql, args=None): if logger.isEnabledFor(logging.DEBUG): logger.debug(self.mogrify(sql, args)) cursor.execute(self, sql, args) class ForSkipLocked(For): def __str__(self): assert not self.nowait, "Can not use both NO WAIT and SKIP LOCKED" return super().__str__() + (' SKIP LOCKED' if not self.nowait else '') class Unaccent(Function): __slots__ = () _function = config.get('database', 'unaccent_function', default='unaccent') class Similarity(Function): __slots__ = () _function = config.get( 'database', 'similarity_function', default='similarity') class Match(BinaryOperator): __slots__ = () _operator = '@@' class ToTsvector(Function): __slots__ = () _function = 'to_tsvector' class Setweight(Function): __slots__ = () _function = 'setweight' class TsQuery(Function): __slots__ = () class ToTsQuery(TsQuery): __slots__ = () _function = 'to_tsquery' class PlainToTsQuery(TsQuery): __slots__ = () _function = 'plainto_tsquery' class PhraseToTsQuery(TsQuery): __slots__ = () _function = 'phraseto_tsquery' class WebsearchToTsQuery(TsQuery): __slots__ = () _function = 'websearch_to_tsquery' class TsRank(Function): __slots__ = () _function = 'ts_rank' class AdvisoryLock(Function): _function = 'pg_advisory_xact_lock' class TryAdvisoryLock(Function): _function = 'pg_try_advisory_xact_lock' class JSONBExtractPath(Function): __slots__ = () _function = 'jsonb_extract_path' class JSONKeyExists(BinaryOperator): __slots__ = () _operator = '?' class _BinaryOperatorArray(BinaryOperator): "Binary Operator that convert list into Array" @property def _operands(self): if isinstance(self.right, list): return (self.left, None) return super()._operands @property def params(self): params = super().params if isinstance(self.right, list): params = params[:-1] + (self.right,) return params class JSONAnyKeyExist(_BinaryOperatorArray): __slots__ = () _operator = '?|' class JSONAllKeyExist(_BinaryOperatorArray): __slots__ = () _operator = '?&' class JSONContains(BinaryOperator): __slots__ = () _operator = '@>' class Database(DatabaseInterface): index_translators = [] _lock = RLock() _databases = defaultdict(dict) _connpool = None _list_cache = {} _list_cache_timestamp = {} _search_path = None _current_user = None _has_returning = None _has_select_for_skip_locked = None _has_proc = defaultdict(lambda: defaultdict(dict)) _extensions = defaultdict(dict) _search_full_text_languages = defaultdict(dict) flavor = Flavor(ilike=True) TYPES_MAPPING = { 'SMALLINT': SQLType('INT2', 'INT2'), 'BIGINT': SQLType('INT8', 'INT8'), 'BLOB': SQLType('BYTEA', 'BYTEA'), 'DATETIME': SQLType('TIMESTAMP', 'TIMESTAMP(0)'), 'REAL': SQLType('FLOAT4', 'FLOAT4'), 'FLOAT': SQLType('FLOAT8', 'FLOAT8'), 'FULLTEXT': SQLType('TSVECTOR', 'TSVECTOR'), 'INTEGER': SQLType('INT4', 'INT4'), 'JSON': SQLType('JSONB', 'JSONB'), 'TIMESTAMP': SQLType('TIMESTAMP', 'TIMESTAMP(6)'), } def __new__(cls, name=_default_name): with cls._lock: now = datetime.now() databases = cls._databases[os.getpid()] for database in list(databases.values()): if ((now - database._last_use).total_seconds() > _timeout and database.name != name and not database._connpool._used): database.close() if name in databases: inst = databases[name] else: inst = DatabaseInterface.__new__(cls, name=name) try: inst._connpool = ThreadedConnectionPool( _minconn, _maxconn, **cls._connection_params(name), cursor_factory=LoggingCursor) except Exception: logger.error( 'connection to "%s" failed', name, exc_info=True) raise else: logger.info('connection to "%s" succeeded', name) databases[name] = inst inst._last_use = datetime.now() return inst def __init__(self, name=_default_name): super(Database, self).__init__(name) @classmethod def _connection_params(cls, name): uri = parse_uri(config.get('database', 'uri')) if uri.path and uri.path != '/': warnings.warn("The path specified in the URI will be overridden") params = { 'dsn': uri._replace(path=name).geturl(), 'fallback_application_name': os.environ.get( 'TRYTOND_APPNAME', 'trytond'), } return params def connect(self): return self def get_connection( self, autocommit=False, readonly=False, statement_timeout=None): retry = max(config.getint('database', 'retry'), _maxconn) for count in range(retry, -1, -1): try: conn = self._connpool.getconn() except (PoolError, DatabaseOperationalError): if count and not self._connpool.closed: logger.info('waiting a connection') time.sleep(1) continue raise except Exception: logger.error( 'connection to "%s" failed', self.name, exc_info=True) raise try: conn.set_session( isolation_level=ISOLATION_LEVEL_REPEATABLE_READ, readonly=readonly, autocommit=autocommit) with conn.cursor() as cur: if statement_timeout: cur.execute('SET statement_timeout=%s' % (statement_timeout * 1000)) else: # Detect disconnection cur.execute('SELECT 1') except DatabaseOperationalError: self._connpool.putconn(conn, close=True) continue break return conn def put_connection(self, connection, close=False): try: connection.reset() except InterfaceError: pass self._connpool.putconn(connection, close=close) def close(self): with self._lock: logger.info('disconnection from "%s"', self.name) self._connpool.closeall() self._databases[os.getpid()].pop(self.name) @classmethod def create(cls, connection, database_name, template='template0'): cursor = connection.cursor() cursor.execute( SQL( "CREATE DATABASE {} TEMPLATE {} ENCODING 'unicode'") .format( Identifier(database_name), Identifier(template))) connection.commit() cls._list_cache.clear() @classmethod def drop(cls, connection, database_name): cursor = connection.cursor() cursor.execute(SQL("DROP DATABASE {}") .format(Identifier(database_name))) cls._list_cache.clear() cls._has_proc.pop(database_name, None) cls._search_full_text_languages.pop(database_name, None) def get_version(self, connection): version = connection.server_version major, rest = divmod(int(version), 10000) minor, patch = divmod(rest, 100) return (major, minor, patch) def list(self, hostname=None): now = time.time() timeout = config.getint('session', 'timeout') res = self.__class__._list_cache.get(hostname) timestamp = self.__class__._list_cache_timestamp.get(hostname, now) if res and abs(timestamp - now) < timeout: return res connection = self.get_connection() try: cursor = connection.cursor() cursor.execute('SELECT datname FROM pg_database ' 'WHERE datistemplate = false ORDER BY datname') res = [] for db_name, in cursor: try: conn = connect(**self._connection_params(db_name)) try: with conn: if self._test(conn, hostname=hostname): res.append(db_name) finally: conn.close() except Exception: logger.debug( 'Test failed for "%s"', db_name, exc_info=True) continue finally: self.put_connection(connection, close=True) self.__class__._list_cache[hostname] = res self.__class__._list_cache_timestamp[hostname] = now return res def init(self): from trytond.modules import get_module_info connection = self.get_connection() try: cursor = connection.cursor() sql_file = os.path.join(os.path.dirname(__file__), 'init.sql') with open(sql_file) as fp: for line in fp.read().split(';'): if (len(line) > 0) and (not line.isspace()): cursor.execute(line) for module in ['ir', 'res']: info = get_module_info(module) cursor.execute('INSERT INTO ir_module ' '(create_uid, create_date, name, state) ' 'VALUES (%s, now(), %s, %s) ' 'RETURNING id', (0, module, 'to activate')) module_id = cursor.fetchone()[0] for dependency in info.get('depends', []): cursor.execute('INSERT INTO ir_module_dependency ' '(create_uid, create_date, module, name) ' 'VALUES (%s, now(), %s, %s)', (0, module_id, dependency)) connection.commit() finally: self.put_connection(connection) def test(self, hostname=None): try: connection = self.get_connection() except Exception: logger.debug('Test failed for "%s"', self.name, exc_info=True) return False try: return self._test(connection, hostname=hostname) finally: self.put_connection(connection, close=True) @classmethod def _test(cls, connection, hostname=None): cursor = connection.cursor() tables = ('ir_model', 'ir_model_field', 'ir_ui_view', 'ir_ui_menu', 'res_user', 'res_group', 'ir_module', 'ir_module_dependency', 'ir_translation', 'ir_lang', 'ir_configuration') cursor.execute('SELECT table_name FROM information_schema.tables ' 'WHERE table_name IN %s', (tables,)) if len(cursor.fetchall()) != len(tables): return False if hostname: try: cursor.execute( 'SELECT hostname FROM ir_configuration') hostnames = {h for h, in cursor if h} if hostnames and hostname not in hostnames: return False except ProgrammingError: pass return True def nextid(self, connection, table, count=1): cursor = connection.cursor() cursor.execute( "SELECT nextval(pg_get_serial_sequence(format(%s, %s), %s)) " "FROM generate_series(1, %s)", ('%I', table, 'id', count)) if count == 1: return cursor.fetchone()[0] else: return [id for id, in cursor] def setnextid(self, connection, table, value): cursor = connection.cursor() cursor.execute( "SELECT setval(pg_get_serial_sequence(format(%s, %s), %s), %s)", ('%I', table, 'id', value)) def currid(self, connection, table): cursor = connection.cursor() cursor.execute( "SELECT pg_get_serial_sequence(format(%s, %s), %s)", ('%I', table, 'id')) sequence_name, = cursor.fetchone() cursor.execute(f"SELECT last_value FROM {sequence_name}") return cursor.fetchone()[0] def lock(self, connection, table): cursor = connection.cursor() cursor.execute(SQL('LOCK {} IN EXCLUSIVE MODE NOWAIT').format( Identifier(table))) def lock_id(self, id, timeout=None): if not timeout: return TryAdvisoryLock(id) else: return AdvisoryLock(id) def has_constraint(self, constraint): return True def has_multirow_insert(self): return True def get_table_schema(self, connection, table_name): cursor = connection.cursor() for schema in self.search_path: cursor.execute('SELECT 1 ' 'FROM information_schema.tables ' 'WHERE table_name = %s AND table_schema = %s', (table_name, schema)) if cursor.rowcount: return schema @property def current_user(self): if self._current_user is None: connection = self.get_connection() try: cursor = connection.cursor() cursor.execute('SELECT current_user') self._current_user = cursor.fetchone()[0] finally: self.put_connection(connection) return self._current_user @property def search_path(self): if self._search_path is None: connection = self.get_connection() try: cursor = connection.cursor() cursor.execute('SHOW search_path') path, = cursor.fetchone() special_values = { 'user': self.current_user, } self._search_path = [ unescape_quote(replace_special_values( p.strip(), **special_values)) for p in path.split(',')] finally: self.put_connection(connection) return self._search_path def has_returning(self): if self._has_returning is None: connection = self.get_connection() try: # RETURNING clause is available since PostgreSQL 8.2 self._has_returning = self.get_version(connection) >= (8, 2) finally: self.put_connection(connection) return self._has_returning def has_select_for(self): return True def get_select_for_skip_locked(self): if self._has_select_for_skip_locked is None: connection = self.get_connection() try: # SKIP LOCKED clause is available since PostgreSQL 9.5 self._has_select_for_skip_locked = ( self.get_version(connection) >= (9, 5)) finally: self.put_connection(connection) if self._has_select_for_skip_locked: return ForSkipLocked else: return For def has_window_functions(self): return True @classmethod def has_sequence(cls): return True def has_proc(self, name, property='oid'): if (name in self._has_proc[self.name] and property in self._has_proc[self.name][name]): return self._has_proc[self.name][name][property] connection = self.get_connection() result = False try: cursor = connection.cursor() cursor.execute( SQL('SELECT {} FROM pg_proc WHERE proname=%s').format( Identifier(property)), (name,)) result = cursor.fetchone() if result: result, = result finally: self.put_connection(connection) self._has_proc[self.name][name][property] = result return result def has_unaccent(self): return self.has_proc(Unaccent._function) def has_unaccent_indexable(self): return self.has_proc(Unaccent._function, 'provolatile') == 'i' def has_similarity(self): return self.has_proc(Similarity._function) def similarity(self, column, value): return Similarity(column, value) def has_search_full_text(self): return True def _search_full_text_language(self, language): languages = self._search_full_text_languages[self.name] if language not in languages: lang = Table('ir_lang') connection = self.get_connection() try: cursor = connection.cursor() cursor.execute(*lang.select( Coalesce(lang.pg_text_search, 'simple'), where=lang.code == language, limit=1)) config_name, = cursor.fetchone() finally: self.put_connection(connection) languages[language] = config_name else: config_name = languages[language] return config_name def format_full_text(self, *documents, language=None): size = max(len(documents) // 4, 1) if len(documents) > 1: weights = chain( ['A'] * size, ['B'] * size, ['C'] * size, repeat('D')) else: weights = [None] expression = None if language: config_name = self._search_full_text_language(language) else: config_name = None for document, weight in zip(documents, weights): if not document: continue if config_name: ts_vector = ToTsvector(config_name, document) else: ts_vector = ToTsvector('simple', document) if weight: ts_vector = Setweight(ts_vector, weight) if expression is None: expression = ts_vector else: expression = Concat(expression, ts_vector) return expression def format_full_text_query(self, query, language=None): connection = self.get_connection() try: version = self.get_version(connection) finally: self.put_connection(connection) if not isinstance(query, TsQuery): if version >= (11, 0): ToTsQuery = WebsearchToTsQuery else: ToTsQuery = PlainToTsQuery if language: config_name = self._search_full_text_language(language) else: config_name = 'simple' query = ToTsQuery(config_name, query) return query def search_full_text(self, document, query): return Match(document, query) def rank_full_text(self, document, query, normalize=None): # TODO: weights and cover density norm_int = 0 if normalize: values = { 'document log': 1, 'document': 2, 'mean': 4, 'word': 8, 'word log': 16, 'rank': 32, } for norm in normalize: norm_int |= values.get(norm, 0) return TsRank(document, query, norm_int) def sql_type(self, type_): if type_ in self.TYPES_MAPPING: return self.TYPES_MAPPING[type_] if type_.startswith('VARCHAR'): return SQLType('VARCHAR', type_) return SQLType(type_, type_) def sql_format(self, type_, value): if type_ == 'BLOB': if value is not None: return Binary(value) return value def unaccent(self, value): if self.has_unaccent(): return Unaccent(value) return value def sequence_exist(self, connection, name): cursor = connection.cursor() for schema in self.search_path: cursor.execute('SELECT 1 ' 'FROM information_schema.sequences ' 'WHERE sequence_name = %s AND sequence_schema = %s', (name, schema)) if cursor.rowcount: return True return False def sequence_create( self, connection, name, number_increment=1, start_value=1): cursor = connection.cursor() cursor.execute( SQL("CREATE SEQUENCE {} INCREMENT BY %s START WITH %s").format( Identifier(name)), (number_increment, start_value)) def sequence_update( self, connection, name, number_increment=1, start_value=1): cursor = connection.cursor() cursor.execute( SQL("ALTER SEQUENCE {} INCREMENT BY %s RESTART WITH %s").format( Identifier(name)), (number_increment, start_value)) def sequence_rename(self, connection, old_name, new_name): cursor = connection.cursor() if (self.sequence_exist(connection, old_name) and not self.sequence_exist(connection, new_name)): cursor.execute( SQL("ALTER TABLE {} RENAME TO {}").format( Identifier(old_name), Identifier(new_name))) def sequence_delete(self, connection, name): cursor = connection.cursor() cursor.execute(SQL("DROP SEQUENCE {}").format( Identifier(name))) def sequence_next_number(self, connection, name): cursor = connection.cursor() version = self.get_version(connection) if version >= (10, 0): cursor.execute( 'SELECT increment_by ' 'FROM pg_sequences ' 'WHERE sequencename=%s', (name,)) increment, = cursor.fetchone() cursor.execute( SQL( 'SELECT CASE WHEN NOT is_called THEN last_value ' 'ELSE last_value + %s ' 'END ' 'FROM {}').format(Identifier(name)), (increment,)) else: cursor.execute( SQL( 'SELECT CASE WHEN NOT is_called THEN last_value ' 'ELSE last_value + increment_by ' 'END ' 'FROM {}').format(Identifier(name))) return cursor.fetchone()[0] def has_channel(self): return True def has_extension(self, extension_name): if extension_name in self._extensions[self.name]: return self._extensions[self.name][extension_name] connection = self.get_connection() result = False try: cursor = connection.cursor() cursor.execute( "SELECT 1 FROM pg_extension WHERE extname=%s", (extension_name,)) result = bool(cursor.rowcount) finally: self.put_connection(connection) self._extensions[self.name][extension_name] = result return result def json_get(self, column, key=None): column = Cast(column, 'jsonb') if key: column = JSONBExtractPath(column, key) return column def json_key_exists(self, column, key): return JSONKeyExists(Cast(column, 'jsonb'), key) def json_any_keys_exist(self, column, keys): return JSONAnyKeyExist(Cast(column, 'jsonb'), keys) def json_all_keys_exist(self, column, keys): return JSONAllKeyExist(Cast(column, 'jsonb'), keys) def json_contains(self, column, json): return JSONContains(Cast(column, 'jsonb'), Cast(json, 'jsonb')) register_type(UNICODE) if PYDATE: register_type(PYDATE) if PYDATETIME: register_type(PYDATETIME) if PYTIME: register_type(PYTIME) if PYINTERVAL: register_type(PYINTERVAL) register_adapter(float, lambda value: AsIs(repr(value))) register_adapter(Decimal, lambda value: AsIs(str(value))) def convert_json(value): from trytond.protocols.jsonrpc import JSONDecoder return json.loads(value, object_hook=JSONDecoder()) register_default_json(loads=convert_json) register_default_jsonb(loads=convert_json) if is_gevent_monkey_patched(): from psycopg2.extensions import set_wait_callback from psycopg2.extras import wait_select set_wait_callback(wait_select)