Files
tradon/backend/postgresql/database.py
2025-12-26 13:11:43 +00:00

815 lines
26 KiB
Python
Executable File

# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import json
import logging
import os
import time
import warnings
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from itertools import chain, repeat
from threading import RLock
from psycopg2 import Binary, connect
from psycopg2.extensions import (
ISOLATION_LEVEL_REPEATABLE_READ, UNICODE, AsIs, cursor, register_adapter,
register_type)
from psycopg2.pool import PoolError, ThreadedConnectionPool
from psycopg2.sql import SQL, Identifier
try:
from psycopg2.extensions import PYDATE, PYDATETIME, PYINTERVAL, PYTIME
except ImportError:
PYDATE, PYDATETIME, PYTIME, PYINTERVAL = None, None, None, None
from psycopg2 import DataError as DatabaseDataError
from psycopg2 import IntegrityError as DatabaseIntegrityError
from psycopg2 import InterfaceError
from psycopg2 import OperationalError as DatabaseOperationalError
from psycopg2 import ProgrammingError
from psycopg2.errors import QueryCanceled as DatabaseTimeoutError
from psycopg2.extras import register_default_json, register_default_jsonb
from sql import Cast, Flavor, For, Table
from sql.conditionals import Coalesce
from sql.functions import Function
from sql.operators import BinaryOperator, Concat
from trytond.backend.database import DatabaseInterface, SQLType
from trytond.config import config, parse_uri
from trytond.tools.gevent import is_gevent_monkey_patched
__all__ = [
    'Database',
    'DatabaseIntegrityError',
    'DatabaseDataError',
    'DatabaseOperationalError',
    'DatabaseTimeoutError',
    ]

logger = logging.getLogger(__name__)

# libpq reads PGTZ for the session time zone; mirror the process TZ
# (an empty value is used when TZ is not set).
os.environ['PGTZ'] = os.getenv('TZ', '')

# Settings of the [database] configuration section
_timeout = config.getint('database', 'timeout')
_minconn = config.getint('database', 'minconn', default=1)
_maxconn = config.getint('database', 'maxconn', default=64)
_default_name = config.get('database', 'default_name', default='template1')
def unescape_quote(s):
if s.startswith('"') and s.endswith('"'):
return s.strip('"').replace('""', '"')
return s
def replace_special_values(s, **mapping):
    """Substitute each ``$name`` placeholder of *s* by ``mapping[name]``.

    Substitutions are applied one after the other, in the mapping order.
    """
    result = s
    for key in mapping:
        result = result.replace(f'${key}', mapping[key])
    return result
class LoggingCursor(cursor):
    "psycopg2 cursor logging each query, with its arguments bound, at DEBUG"

    def execute(self, sql, args=None):
        # mogrify() is only worth its cost when the record will be emitted
        if logger.isEnabledFor(logging.DEBUG):
            query = self.mogrify(sql, args)
            logger.debug(query)
        super().execute(sql, args)
class ForSkipLocked(For):
    "Locking clause that appends SKIP LOCKED (incompatible with NOWAIT)"

    def __str__(self):
        assert not self.nowait, "Can not use both NO WAIT and SKIP LOCKED"
        clause = super().__str__()
        if self.nowait:
            return clause
        return clause + ' SKIP LOCKED'
class Unaccent(Function):
    "Accent removal function, named by the unaccent_function setting"
    _function = config.get(
        'database', 'unaccent_function', default='unaccent')
    __slots__ = ()
class Similarity(Function):
    "Similarity function, named by the similarity_function setting"
    _function = config.get(
        'database', 'similarity_function',
        default='similarity')
    __slots__ = ()
class Match(BinaryOperator):
    "Text search match operator: tsvector @@ tsquery"
    _operator = '@@'
    __slots__ = ()
class ToTsvector(Function):
    "to_tsvector([config,] document)"
    _function = 'to_tsvector'
    __slots__ = ()
class Setweight(Function):
    "setweight(tsvector, weight)"
    _function = 'setweight'
    __slots__ = ()
class TsQuery(Function):
    "Common base of the functions that build a tsquery"
    __slots__ = ()
class ToTsQuery(TsQuery):
    "to_tsquery([config,] query)"
    _function = 'to_tsquery'
    __slots__ = ()
class PlainToTsQuery(TsQuery):
    "plainto_tsquery([config,] query)"
    _function = 'plainto_tsquery'
    __slots__ = ()
class PhraseToTsQuery(TsQuery):
    "phraseto_tsquery([config,] query)"
    _function = 'phraseto_tsquery'
    __slots__ = ()
class WebsearchToTsQuery(TsQuery):
    "websearch_to_tsquery([config,] query)"
    _function = 'websearch_to_tsquery'
    __slots__ = ()
class TsRank(Function):
    "ts_rank(tsvector, tsquery [, normalization])"
    _function = 'ts_rank'
    __slots__ = ()
class AdvisoryLock(Function):
    "pg_advisory_xact_lock(key): wait for a transaction-level advisory lock"
    # Declared like the other Function subclasses of this module so that
    # instances do not carry an unneeded __dict__
    __slots__ = ()
    _function = 'pg_advisory_xact_lock'
class TryAdvisoryLock(Function):
    """pg_try_advisory_xact_lock(key): take a transaction-level advisory
    lock if it is available and return whether it was obtained"""
    # Declared like the other Function subclasses of this module so that
    # instances do not carry an unneeded __dict__
    __slots__ = ()
    _function = 'pg_try_advisory_xact_lock'
class JSONBExtractPath(Function):
    "jsonb_extract_path(jsonb, key...)"
    _function = 'jsonb_extract_path'
    __slots__ = ()
class JSONKeyExists(BinaryOperator):
    "jsonb ? key: test whether the key exists"
    _operator = '?'
    __slots__ = ()
class _BinaryOperatorArray(BinaryOperator):
    """Binary operator whose list right operand is sent as a single
    parameter, which psycopg2 adapts to an ARRAY."""
    # Without this, the slotted subclasses still get a per-instance __dict__
    __slots__ = ()

    @property
    def _operands(self):
        if isinstance(self.right, list):
            # None is presumably rendered as one parameter placeholder by
            # python-sql instead of expanding the list item by item
            return (self.left, None)
        return super()._operands

    @property
    def params(self):
        params = super().params
        if isinstance(self.right, list):
            # The placeholder parameter is the last one: put the list there
            params = params[:-1] + (self.right,)
        return params
class JSONAnyKeyExist(_BinaryOperatorArray):
    "jsonb ?| array: test whether any of the keys exists"
    _operator = '?|'
    __slots__ = ()
class JSONAllKeyExist(_BinaryOperatorArray):
    "jsonb ?& array: test whether all the keys exist"
    _operator = '?&'
    __slots__ = ()
class JSONContains(BinaryOperator):
    "jsonb @> jsonb: test whether the left value contains the right one"
    _operator = '@>'
    __slots__ = ()
class Database(DatabaseInterface):
    """PostgreSQL backend of the Tryton database interface.

    ``__new__`` returns the instance already registered for a database name
    in the current process, so each database has a single instance per
    process, owning a psycopg2 ``ThreadedConnectionPool``.
    """

    index_translators = []
    _lock = RLock()
    # {process id: {database name: Database instance}}
    _databases = defaultdict(dict)
    _connpool = None
    # Result of list() and the time it was computed, keyed by hostname
    _list_cache = {}
    _list_cache_timestamp = {}
    # Values computed lazily and then cached on the instance
    _search_path = None
    _current_user = None
    _has_returning = None
    _has_select_for_skip_locked = None
    # {database name: {function name: {pg_proc column: value}}}
    _has_proc = defaultdict(lambda: defaultdict(dict))
    # {database name: {extension name: installed}}
    _extensions = defaultdict(dict)
    # {database name: {language code: text search configuration}}
    _search_full_text_languages = defaultdict(dict)
    flavor = Flavor(ilike=True)

    TYPES_MAPPING = {
        'SMALLINT': SQLType('INT2', 'INT2'),
        'BIGINT': SQLType('INT8', 'INT8'),
        'BLOB': SQLType('BYTEA', 'BYTEA'),
        'DATETIME': SQLType('TIMESTAMP', 'TIMESTAMP(0)'),
        'REAL': SQLType('FLOAT4', 'FLOAT4'),
        'FLOAT': SQLType('FLOAT8', 'FLOAT8'),
        'FULLTEXT': SQLType('TSVECTOR', 'TSVECTOR'),
        'INTEGER': SQLType('INT4', 'INT4'),
        'JSON': SQLType('JSONB', 'JSONB'),
        'TIMESTAMP': SQLType('TIMESTAMP', 'TIMESTAMP(6)'),
        }

    def __new__(cls, name=_default_name):
        """Return the instance of database *name* for this process.

        A connection pool is created on first use. Other databases of the
        process idle for more than the configured timeout, and with no
        connection in use, are closed.
        """
        with cls._lock:
            now = datetime.now()
            databases = cls._databases[os.getpid()]
            for database in list(databases.values()):
                if ((now - database._last_use).total_seconds() > _timeout
                        and database.name != name
                        # _used is psycopg2's map of checked-out connections
                        and not database._connpool._used):
                    database.close()
            if name in databases:
                inst = databases[name]
            else:
                inst = DatabaseInterface.__new__(cls, name=name)
                try:
                    inst._connpool = ThreadedConnectionPool(
                        _minconn, _maxconn, **cls._connection_params(name),
                        cursor_factory=LoggingCursor)
                except Exception:
                    logger.error(
                        'connection to "%s" failed', name, exc_info=True)
                    raise
                else:
                    logger.info('connection to "%s" succeeded', name)
                databases[name] = inst
            inst._last_use = datetime.now()
            return inst

    def __init__(self, name=_default_name):
        super().__init__(name)

    @classmethod
    def _connection_params(cls, name):
        """Return the connect() keyword arguments for database *name*.

        The database name replaces the path of the configured URI.
        """
        uri = parse_uri(config.get('database', 'uri'))
        if uri.path and uri.path != '/':
            warnings.warn("The path specified in the URI will be overridden")
        params = {
            'dsn': uri._replace(path=name).geturl(),
            'fallback_application_name': os.environ.get(
                'TRYTOND_APPNAME', 'trytond'),
            }
        return params

    def connect(self):
        return self

    def get_connection(
            self, autocommit=False, readonly=False, statement_timeout=None):
        """Return a pooled connection set up for a REPEATABLE READ session.

        *statement_timeout* is multiplied by 1000 to get PostgreSQL's
        milliseconds, so it is expected in seconds. Waits (1 second per
        attempt) while the pool is exhausted.

        Raises PoolError or DatabaseOperationalError when no working
        connection could be obtained once the retries are exhausted.
        """
        retry = max(config.getint('database', 'retry'), _maxconn)
        for count in range(retry, -1, -1):
            try:
                conn = self._connpool.getconn()
            except (PoolError, DatabaseOperationalError):
                if count and not self._connpool.closed:
                    logger.info('waiting a connection')
                    time.sleep(1)
                    continue
                raise
            except Exception:
                logger.error(
                    'connection to "%s" failed', self.name, exc_info=True)
                raise
            try:
                conn.set_session(
                    isolation_level=ISOLATION_LEVEL_REPEATABLE_READ,
                    readonly=readonly,
                    autocommit=autocommit)
                with conn.cursor() as cur:
                    if statement_timeout:
                        cur.execute('SET statement_timeout=%s' %
                            (statement_timeout * 1000))
                    else:
                        # Detect disconnection
                        cur.execute('SELECT 1')
            except DatabaseOperationalError:
                # Broken connection: discard it from the pool and retry.
                # On the last attempt, raise instead of returning it.
                self._connpool.putconn(conn, close=True)
                if count:
                    continue
                raise
            return conn

    def put_connection(self, connection, close=False):
        """Give *connection* back to the pool, closing it if *close*.

        The session is reset first. A connection that cannot be reset is
        closed instead of being kept out of the pool.
        """
        try:
            connection.reset()
        except InterfaceError:
            # psycopg2 raises it when the connection is already closed
            pass
        except DatabaseOperationalError:
            logger.debug('reset of connection failed', exc_info=True)
            close = True
        self._connpool.putconn(connection, close=close)

    def close(self):
        "Close all the pooled connections and unregister the instance"
        with self._lock:
            logger.info('disconnection from "%s"', self.name)
            self._connpool.closeall()
            # The instance may not be registered for this process (e.g.
            # after a fork)
            self._databases[os.getpid()].pop(self.name, None)

    @classmethod
    def create(cls, connection, database_name, template='template0'):
        "Create database *database_name* from *template*"
        cursor = connection.cursor()
        cursor.execute(
            SQL(
                "CREATE DATABASE {} TEMPLATE {} ENCODING 'unicode'")
            .format(
                Identifier(database_name),
                Identifier(template)))
        connection.commit()
        cls._list_cache.clear()

    @classmethod
    def drop(cls, connection, database_name):
        "Drop database *database_name* and forget what was cached about it"
        cursor = connection.cursor()
        cursor.execute(SQL("DROP DATABASE {}")
            .format(Identifier(database_name)))
        cls._list_cache.clear()
        cls._has_proc.pop(database_name, None)
        cls._extensions.pop(database_name, None)
        cls._search_full_text_languages.pop(database_name, None)

    def get_version(self, connection):
        "Return the server version as a (major, minor, patch) tuple"
        # server_version is encoded as an integer like 90605 or 120003
        version = connection.server_version
        major, rest = divmod(int(version), 10000)
        minor, patch = divmod(rest, 100)
        return (major, minor, patch)

    def list(self, hostname=None):
        """Return the names of the Tryton databases of the server.

        Only databases passing _test() for *hostname* are listed. The
        result is cached per hostname for the session timeout.
        """
        now = time.time()
        timeout = config.getint('session', 'timeout')
        res = self.__class__._list_cache.get(hostname)
        timestamp = self.__class__._list_cache_timestamp.get(hostname, now)
        if res and abs(timestamp - now) < timeout:
            return res

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            cursor.execute('SELECT datname FROM pg_database '
                'WHERE datistemplate = false ORDER BY datname')
            res = []
            for db_name, in cursor:
                try:
                    conn = connect(**self._connection_params(db_name))
                    try:
                        with conn:
                            if self._test(conn, hostname=hostname):
                                res.append(db_name)
                    finally:
                        conn.close()
                except Exception:
                    logger.debug(
                        'Test failed for "%s"', db_name, exc_info=True)
                    continue
        finally:
            self.put_connection(connection, close=True)

        self.__class__._list_cache[hostname] = res
        self.__class__._list_cache_timestamp[hostname] = now
        return res

    def init(self):
        "Create the base schema from init.sql and register ir and res"
        from trytond.modules import get_module_info

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
            with open(sql_file, encoding='utf-8') as fp:
                # One statement per ';'-separated chunk, skipping blanks
                for line in fp.read().split(';'):
                    if (len(line) > 0) and (not line.isspace()):
                        cursor.execute(line)

            for module in ['ir', 'res']:
                info = get_module_info(module)
                cursor.execute('INSERT INTO ir_module '
                    '(create_uid, create_date, name, state) '
                    'VALUES (%s, now(), %s, %s) '
                    'RETURNING id',
                    (0, module, 'to activate'))
                module_id = cursor.fetchone()[0]
                for dependency in info.get('depends', []):
                    cursor.execute('INSERT INTO ir_module_dependency '
                        '(create_uid, create_date, module, name) '
                        'VALUES (%s, now(), %s, %s)',
                        (0, module_id, dependency))

            connection.commit()
        finally:
            self.put_connection(connection)

    def test(self, hostname=None):
        "Return whether this database is a Tryton database for *hostname*"
        try:
            connection = self.get_connection()
        except Exception:
            logger.debug('Test failed for "%s"', self.name, exc_info=True)
            return False
        try:
            return self._test(connection, hostname=hostname)
        finally:
            self.put_connection(connection, close=True)

    @classmethod
    def _test(cls, connection, hostname=None):
        """Return whether *connection* points to a Tryton database.

        All the core tables must exist and, when *hostname* is given and
        ir_configuration defines hostnames, it must be one of them.
        """
        cursor = connection.cursor()
        tables = ('ir_model', 'ir_model_field', 'ir_ui_view', 'ir_ui_menu',
            'res_user', 'res_group', 'ir_module', 'ir_module_dependency',
            'ir_translation', 'ir_lang', 'ir_configuration')
        # DISTINCT: a table present in several schemas must count once
        cursor.execute('SELECT DISTINCT table_name '
            'FROM information_schema.tables '
            'WHERE table_name IN %s', (tables,))
        if len(cursor.fetchall()) != len(tables):
            return False
        if hostname:
            try:
                cursor.execute(
                    'SELECT hostname FROM ir_configuration')
                hostnames = {h for h, in cursor if h}
                if hostnames and hostname not in hostnames:
                    return False
            except ProgrammingError:
                # Presumably a schema without the hostname column
                pass
        return True

    def nextid(self, connection, table, count=1):
        "Return the next id of *table*, or a list of *count* ids if > 1"
        cursor = connection.cursor()
        cursor.execute(
            "SELECT nextval(pg_get_serial_sequence(format(%s, %s), %s)) "
            "FROM generate_series(1, %s)",
            ('%I', table, 'id', count))
        if count == 1:
            return cursor.fetchone()[0]
        else:
            return [id for id, in cursor]

    def setnextid(self, connection, table, value):
        "Set the id sequence of *table* to *value*"
        cursor = connection.cursor()
        cursor.execute(
            "SELECT setval(pg_get_serial_sequence(format(%s, %s), %s), %s)",
            ('%I', table, 'id', value))

    def currid(self, connection, table):
        "Return the last value of the id sequence of *table*"
        cursor = connection.cursor()
        cursor.execute(
            "SELECT pg_get_serial_sequence(format(%s, %s), %s)",
            ('%I', table, 'id'))
        sequence_name, = cursor.fetchone()
        # pg_get_serial_sequence returns an already quoted relation name
        cursor.execute(f"SELECT last_value FROM {sequence_name}")
        return cursor.fetchone()[0]

    def lock(self, connection, table):
        "Lock *table* in EXCLUSIVE mode, failing at once if not possible"
        cursor = connection.cursor()
        cursor.execute(SQL('LOCK {} IN EXCLUSIVE MODE NOWAIT').format(
            Identifier(table)))

    def lock_id(self, id, timeout=None):
        "Return the advisory lock expression for *id* (waiting if timeout)"
        if not timeout:
            return TryAdvisoryLock(id)
        else:
            return AdvisoryLock(id)

    def has_constraint(self, constraint):
        return True

    def has_multirow_insert(self):
        return True

    def get_table_schema(self, connection, table_name):
        "Return the first schema of the search path containing table_name"
        cursor = connection.cursor()
        for schema in self.search_path:
            cursor.execute('SELECT 1 '
                'FROM information_schema.tables '
                'WHERE table_name = %s AND table_schema = %s',
                (table_name, schema))
            if cursor.rowcount:
                return schema

    @property
    def current_user(self):
        "Name of the PostgreSQL user, queried once then cached"
        if self._current_user is None:
            connection = self.get_connection()
            try:
                cursor = connection.cursor()
                cursor.execute('SELECT current_user')
                self._current_user = cursor.fetchone()[0]
            finally:
                self.put_connection(connection)
        return self._current_user

    @property
    def search_path(self):
        "Schemas of the search path with $user substituted, cached"
        if self._search_path is None:
            connection = self.get_connection()
            try:
                cursor = connection.cursor()
                cursor.execute('SHOW search_path')
                path, = cursor.fetchone()
                special_values = {
                    'user': self.current_user,
                }
                self._search_path = [
                    unescape_quote(replace_special_values(
                        p.strip(), **special_values))
                    for p in path.split(',')]
            finally:
                self.put_connection(connection)
        return self._search_path

    def has_returning(self):
        if self._has_returning is None:
            connection = self.get_connection()
            try:
                # RETURNING clause is available since PostgreSQL 8.2
                self._has_returning = self.get_version(connection) >= (8, 2)
            finally:
                self.put_connection(connection)
        return self._has_returning

    def has_select_for(self):
        return True

    def get_select_for_skip_locked(self):
        "Return ForSkipLocked if the server supports SKIP LOCKED else For"
        if self._has_select_for_skip_locked is None:
            connection = self.get_connection()
            try:
                # SKIP LOCKED clause is available since PostgreSQL 9.5
                self._has_select_for_skip_locked = (
                    self.get_version(connection) >= (9, 5))
            finally:
                self.put_connection(connection)
        if self._has_select_for_skip_locked:
            return ForSkipLocked
        else:
            return For

    def has_window_functions(self):
        return True

    @classmethod
    def has_sequence(cls):
        return True

    def has_proc(self, name, property='oid'):
        """Return column *property* of pg_proc for function *name*.

        Returns None when the function does not exist. Results are cached
        per database.
        """
        if (name in self._has_proc[self.name]
                and property in self._has_proc[self.name][name]):
            return self._has_proc[self.name][name][property]
        connection = self.get_connection()
        result = False
        try:
            cursor = connection.cursor()
            cursor.execute(
                SQL('SELECT {} FROM pg_proc WHERE proname=%s').format(
                    Identifier(property)), (name,))
            result = cursor.fetchone()
            if result:
                result, = result
        finally:
            self.put_connection(connection)
        self._has_proc[self.name][name][property] = result
        return result

    def has_unaccent(self):
        return self.has_proc(Unaccent._function)

    def has_unaccent_indexable(self):
        # provolatile 'i' means the function is IMMUTABLE
        return self.has_proc(Unaccent._function, 'provolatile') == 'i'

    def has_similarity(self):
        return self.has_proc(Similarity._function)

    def similarity(self, column, value):
        return Similarity(column, value)

    def has_search_full_text(self):
        return True

    def _search_full_text_language(self, language):
        """Return the text search configuration for the *language* code.

        It is read from ir_lang.pg_text_search and defaults to 'simple',
        also when the language code is not found (that fallback is not
        cached so a language created later is picked up).
        """
        languages = self._search_full_text_languages[self.name]
        if language in languages:
            return languages[language]
        lang = Table('ir_lang')
        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            cursor.execute(*lang.select(
                    Coalesce(lang.pg_text_search, 'simple'),
                    where=lang.code == language,
                    limit=1))
            row = cursor.fetchone()
        finally:
            self.put_connection(connection)
        if row is None:
            return 'simple'
        config_name, = row
        languages[language] = config_name
        return config_name

    def format_full_text(self, *documents, language=None):
        """Return a tsvector expression concatenating *documents*.

        With several documents, the first quarter is weighted A, the next
        quarters B and C, and the rest D. Empty documents are skipped.
        """
        size = max(len(documents) // 4, 1)
        if len(documents) > 1:
            weights = chain(
                ['A'] * size, ['B'] * size, ['C'] * size, repeat('D'))
        else:
            weights = [None]
        expression = None
        if language:
            config_name = self._search_full_text_language(language)
        else:
            config_name = None
        for document, weight in zip(documents, weights):
            if not document:
                continue
            if config_name:
                ts_vector = ToTsvector(config_name, document)
            else:
                ts_vector = ToTsvector('simple', document)
            if weight:
                ts_vector = Setweight(ts_vector, weight)
            if expression is None:
                expression = ts_vector
            else:
                expression = Concat(expression, ts_vector)
        return expression

    def format_full_text_query(self, query, language=None):
        """Return *query* converted into a tsquery expression.

        A TsQuery is returned unchanged; otherwise websearch_to_tsquery is
        used on PostgreSQL 11+ and plainto_tsquery before.
        """
        if isinstance(query, TsQuery):
            return query
        connection = self.get_connection()
        try:
            version = self.get_version(connection)
        finally:
            self.put_connection(connection)
        if version >= (11, 0):
            query_function = WebsearchToTsQuery
        else:
            query_function = PlainToTsQuery
        if language:
            config_name = self._search_full_text_language(language)
        else:
            config_name = 'simple'
        return query_function(config_name, query)

    def search_full_text(self, document, query):
        return Match(document, query)

    def rank_full_text(self, document, query, normalize=None):
        "Return a ts_rank expression with the named normalization options"
        # TODO: weights and cover density
        norm_int = 0
        if normalize:
            # Bit flags of the ts_rank normalization argument
            values = {
                'document log': 1,
                'document': 2,
                'mean': 4,
                'word': 8,
                'word log': 16,
                'rank': 32,
                }
            for norm in normalize:
                norm_int |= values.get(norm, 0)
        return TsRank(document, query, norm_int)

    def sql_type(self, type_):
        "Return the SQLType of the generic type name *type_*"
        if type_ in self.TYPES_MAPPING:
            return self.TYPES_MAPPING[type_]
        if type_.startswith('VARCHAR'):
            return SQLType('VARCHAR', type_)
        return SQLType(type_, type_)

    def sql_format(self, type_, value):
        if type_ == 'BLOB':
            if value is not None:
                return Binary(value)
        return value

    def unaccent(self, value):
        if self.has_unaccent():
            return Unaccent(value)
        return value

    def sequence_exist(self, connection, name):
        "Return whether sequence *name* exists in the search path"
        cursor = connection.cursor()
        for schema in self.search_path:
            cursor.execute('SELECT 1 '
                'FROM information_schema.sequences '
                'WHERE sequence_name = %s AND sequence_schema = %s',
                (name, schema))
            if cursor.rowcount:
                return True
        return False

    def sequence_create(
            self, connection, name, number_increment=1, start_value=1):
        cursor = connection.cursor()
        cursor.execute(
            SQL("CREATE SEQUENCE {} INCREMENT BY %s START WITH %s").format(
                Identifier(name)),
            (number_increment, start_value))

    def sequence_update(
            self, connection, name, number_increment=1, start_value=1):
        cursor = connection.cursor()
        cursor.execute(
            SQL("ALTER SEQUENCE {} INCREMENT BY %s RESTART WITH %s").format(
                Identifier(name)),
            (number_increment, start_value))

    def sequence_rename(self, connection, old_name, new_name):
        "Rename sequence *old_name* if it exists and *new_name* does not"
        cursor = connection.cursor()
        if (self.sequence_exist(connection, old_name)
                and not self.sequence_exist(connection, new_name)):
            cursor.execute(
                SQL("ALTER SEQUENCE {} RENAME TO {}").format(
                    Identifier(old_name),
                    Identifier(new_name)))

    def sequence_delete(self, connection, name):
        cursor = connection.cursor()
        cursor.execute(SQL("DROP SEQUENCE {}").format(
            Identifier(name)))

    def sequence_next_number(self, connection, name):
        "Return the value the next nextval() of sequence *name* will give"
        cursor = connection.cursor()
        version = self.get_version(connection)
        if version >= (10, 0):
            # Since PostgreSQL 10 increment_by is only in pg_sequences
            cursor.execute(
                'SELECT increment_by '
                'FROM pg_sequences '
                'WHERE sequencename=%s',
                (name,))
            increment, = cursor.fetchone()
            cursor.execute(
                SQL(
                    'SELECT CASE WHEN NOT is_called THEN last_value '
                    'ELSE last_value + %s '
                    'END '
                    'FROM {}').format(Identifier(name)),
                (increment,))
        else:
            cursor.execute(
                SQL(
                    'SELECT CASE WHEN NOT is_called THEN last_value '
                    'ELSE last_value + increment_by '
                    'END '
                    'FROM {}').format(Identifier(name)))
        return cursor.fetchone()[0]

    def has_channel(self):
        return True

    def has_extension(self, extension_name):
        "Return whether *extension_name* is installed (cached per database)"
        if extension_name in self._extensions[self.name]:
            return self._extensions[self.name][extension_name]

        connection = self.get_connection()
        result = False
        try:
            cursor = connection.cursor()
            cursor.execute(
                "SELECT 1 FROM pg_extension WHERE extname=%s",
                (extension_name,))
            result = bool(cursor.rowcount)
        finally:
            self.put_connection(connection)
        self._extensions[self.name][extension_name] = result
        return result

    def json_get(self, column, key=None):
        column = Cast(column, 'jsonb')
        if key:
            column = JSONBExtractPath(column, key)
        return column

    def json_key_exists(self, column, key):
        return JSONKeyExists(Cast(column, 'jsonb'), key)

    def json_any_keys_exist(self, column, keys):
        return JSONAnyKeyExist(Cast(column, 'jsonb'), keys)

    def json_all_keys_exist(self, column, keys):
        return JSONAllKeyExist(Cast(column, 'jsonb'), keys)

    def json_contains(self, column, json):
        return JSONContains(Cast(column, 'jsonb'), Cast(json, 'jsonb'))
register_type(UNICODE)
if PYDATE:
    register_type(PYDATE)
if PYDATETIME:
    register_type(PYDATETIME)
if PYTIME:
    register_type(PYTIME)
if PYINTERVAL:
    register_type(PYINTERVAL)

# Non-finite numbers have no bare SQL literal: sent unquoted, PostgreSQL
# would parse nan/inf/NaN/Infinity as identifiers.
_NON_FINITE_FLOATS = {
    'nan': "'NaN'::float",
    'inf': "'Infinity'::float",
    '-inf': "'-Infinity'::float",
    }


def _adapt_float(value):
    "Adapt a float with its shortest round-trip representation"
    literal = repr(value)
    return AsIs(_NON_FINITE_FLOATS.get(literal, literal))


def _adapt_decimal(value):
    "Adapt a Decimal with its exact string representation"
    if value.is_finite():
        return AsIs(str(value))
    if value.is_nan():
        return AsIs("'NaN'::numeric")
    # NOTE(review): numeric infinity requires PostgreSQL 14+
    if value.is_signed():
        return AsIs("'-Infinity'::numeric")
    return AsIs("'Infinity'::numeric")


register_adapter(float, _adapt_float)
register_adapter(Decimal, _adapt_decimal)
def convert_json(value):
    """Decode the JSON text *value* read from a json or jsonb column.

    Every decoded object goes through trytond's JSON-RPC ``JSONDecoder``
    used as ``object_hook``.
    """
    from trytond.protocols.jsonrpc import JSONDecoder
    decoder_hook = JSONDecoder()
    return json.loads(value, object_hook=decoder_hook)
# Decode both JSON column types with convert_json
register_default_jsonb(loads=convert_json)
register_default_json(loads=convert_json)

if is_gevent_monkey_patched():
    # Let psycopg2 wait on sockets through (monkey-patched) select so
    # that gevent can switch greenlets during queries
    from psycopg2.extras import wait_select
    from psycopg2.extensions import set_wait_callback
    set_wait_callback(wait_select)