# This file is part of Tryton.  The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
import logging
import re

from psycopg2.sql import SQL, Composable, Identifier

from trytond.backend.table import (
    IndexTranslatorInterface, TableHandlerInterface)
from trytond.transaction import Transaction

__all__ = ['TableHandler']

logger = logging.getLogger(__name__)
VARCHAR_SIZE_RE = re.compile(r'VARCHAR\(([0-9]+)\)')


class TableHandler(TableHandlerInterface):
    """PostgreSQL table handler.

    Creates and migrates the table backing a model: columns, identity
    primary key, constraints, foreign keys and indexes. Catalog information
    (columns, constraints, foreign-key delete rules, indexes) is read lazily
    and cached on the instance; methods that alter the table reset the
    relevant cache through ``_update_definitions``.
    """
    namedatalen = 64
    # Index translator classes; filled by IndexMixin.__init_subclass__.
    index_translators = []

    def _init(self, model, history=False):
        """Ensure the table for *model* exists with an identity key.

        Creates the table when missing, records its schema and whether the
        current user owns it, and sets the table comment from the model
        docstring when owner. It then adds the ``id`` column, or migrates it
        from a legacy ``<table>_id_seq`` sequence to
        ``GENERATED BY DEFAULT AS IDENTITY``. For history tables the same is
        done for ``__id``, and ``id`` is added as a plain INTEGER.
        """
        super()._init(model, history=history)
        self.__columns = None
        self.__constraints = None
        self.__fk_deltypes = None
        self.__indexes = None

        transaction = Transaction()
        cursor = transaction.connection.cursor()
        # Create new table if necessary
        if not self.table_exist(self.table_name):
            cursor.execute(SQL('CREATE TABLE {} ()').format(
                    Identifier(self.table_name)))
        self.table_schema = transaction.database.get_table_schema(
            transaction.connection, self.table_name)

        cursor.execute('SELECT tableowner = current_user FROM pg_tables '
            'WHERE tablename = %s AND schemaname = %s',
            (self.table_name, self.table_schema))
        self.is_owner, = cursor.fetchone()

        if model.__doc__ and self.is_owner:
            cursor.execute(SQL('COMMENT ON TABLE {} IS %s').format(
                    Identifier(self.table_name)),
                (model.__doc__,))

        def migrate_to_identity(table, column):
            # Replace the legacy "<table>_<column>_seq" sequence by an
            # identity column whose sequence continues from the next value
            # and keeps the legacy increment/min/max/cache settings.
            previous_seq_name = f"{table}_{column}_seq"
            cursor.execute(
                "SELECT nextval(format(%s, %s))", ('%I', previous_seq_name,))
            next_val, = cursor.fetchone()
            # Quote the name the same way as for nextval above so both
            # statements resolve the same relation.
            cursor.execute(
                "SELECT seqincrement, seqmax, seqmin, seqcache "
                "FROM pg_sequence WHERE seqrelid = format(%s, %s)::regclass",
                ('%I', previous_seq_name))
            increment, s_max, s_min, cache = cursor.fetchone()
            # Previously created sequences were setting bigint values for
            # those identity column mimic the type of the underlying column
            if (s_max > 2 ** 31 - 1
                    and self._columns[column]['typname'] != 'int8'):
                s_max = 2 ** 31 - 1
            if (s_min < -(2 ** 31)
                    and self._columns[column]['typname'] != 'int8'):
                s_min = -(2 ** 31)
            cursor.execute(
                SQL("ALTER TABLE {} ALTER COLUMN {} DROP DEFAULT").format(
                    Identifier(table), Identifier(column)))
            cursor.execute(
                SQL("DROP SEQUENCE {}").format(
                    Identifier(previous_seq_name)))
            cursor.execute(
                SQL("ALTER TABLE {} ALTER COLUMN {} "
                    "ADD GENERATED BY DEFAULT AS IDENTITY").format(
                    Identifier(table), Identifier(column)))
            cursor.execute(
                "SELECT pg_get_serial_sequence(format(%s, %s), %s)",
                ('%I', table, column))
            serial_seq_name, = cursor.fetchone()
            # pg_get_serial_sequence is documented to return a name already
            # formatted for use as a sequence reference, so it is inlined.
            cursor.execute(
                (f"ALTER SEQUENCE {serial_seq_name} INCREMENT BY %s "
                    "MINVALUE %s MAXVALUE %s RESTART WITH %s CACHE %s"),
                (increment, s_min, s_max, next_val, cache))

        update_definitions = False
        if 'id' not in self._columns:
            update_definitions = True
            if not self.history:
                cursor.execute(
                    SQL(
                        "ALTER TABLE {} ADD COLUMN id INTEGER "
                        "GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY").format(
                        Identifier(self.table_name)))
            else:
                cursor.execute(
                    SQL('ALTER TABLE {} ADD COLUMN id INTEGER')
                    .format(Identifier(self.table_name)))
        else:
            if not self.history and not self._columns['id']['identity']:
                update_definitions = True
                migrate_to_identity(self.table_name, 'id')
        if self.history and '__id' not in self._columns:
            update_definitions = True
            cursor.execute(
                SQL(
                    "ALTER TABLE {} ADD COLUMN __id INTEGER "
                    "GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY").format(
                    Identifier(self.table_name)))
        elif self.history:
            if not self._columns['__id']['identity']:
                update_definitions = True
                cursor.execute(
                    SQL("ALTER TABLE {} ALTER COLUMN id DROP DEFAULT").format(
                        Identifier(self.table_name)))
                migrate_to_identity(self.table_name, '__id')
        if update_definitions:
            self._update_definitions(columns=True)

    @classmethod
    def table_exist(cls, table_name):
        """Return whether a schema is found for *table_name*."""
        transaction = Transaction()
        return bool(transaction.database.get_table_schema(
                transaction.connection, table_name))

    @classmethod
    def table_rename(cls, old_name, new_name):
        """Rename table *old_name* and its history table to *new_name*.

        Each table is only renamed when the source exists and the target
        does not. ``sequence_rename`` is also called for the legacy
        ``<name>_id_seq`` sequence (migration from 6.6).
        """
        transaction = Transaction()
        cursor = transaction.connection.cursor()
        # Rename table
        if (cls.table_exist(old_name)
                and not cls.table_exist(new_name)):
            cursor.execute(SQL('ALTER TABLE {} RENAME TO {}').format(
                    Identifier(old_name), Identifier(new_name)))
        # Migrate from 6.6: rename old sequence
        old_sequence = old_name + '_id_seq'
        new_sequence = new_name + '_id_seq'
        transaction.database.sequence_rename(
            transaction.connection, old_sequence, new_sequence)
        # Rename history table
        old_history = old_name + "__history"
        new_history = new_name + "__history"
        if (cls.table_exist(old_history)
                and not cls.table_exist(new_history)):
            # Use Identifier like every other statement so names are
            # properly quoted/escaped.
            cursor.execute(SQL('ALTER TABLE {} RENAME TO {}').format(
                    Identifier(old_history), Identifier(new_history)))

    def column_exist(self, column_name):
        """Return whether *column_name* is a column of the table."""
        return column_name in self._columns

    def column_rename(self, old_name, new_name):
        """Rename column *old_name* to *new_name*.

        Does nothing if *old_name* is missing. Logs a warning instead of
        renaming when *new_name* already exists.
        """
        cursor = Transaction().connection.cursor()
        if self.column_exist(old_name):
            if not self.column_exist(new_name):
                cursor.execute(SQL(
                        'ALTER TABLE {} RENAME COLUMN {} TO {}').format(
                        Identifier(self.table_name),
                        Identifier(old_name),
                        Identifier(new_name)))
                self._update_definitions(columns=True)
            else:
                logger.warning(
                    'Unable to rename column %s on table %s to %s.',
                    old_name, self.table_name, new_name)

    @property
    def _columns(self):
        """Cached mapping of column name to its definition.

        Each value is a dict with ``typname`` (``udt_name``), ``notnull``,
        ``size`` (``character_maximum_length``), ``default``
        (``column_default`` text) and ``identity``, read from
        ``information_schema.columns``.
        """
        if self.__columns is None:
            cursor = Transaction().connection.cursor()
            self.__columns = {}
            # Fetch columns definitions from the table
            cursor.execute('SELECT '
                'column_name, udt_name, is_nullable, '
                'character_maximum_length, '
                'column_default, is_identity '
                'FROM information_schema.columns '
                'WHERE table_name = %s AND table_schema = %s',
                (self.table_name, self.table_schema))
            for column, typname, nullable, size, default, identity in cursor:
                self.__columns[column] = {
                    'typname': typname,
                    'notnull': nullable == 'NO',
                    'size': size,
                    'default': default,
                    'identity': identity != 'NO',
                    }
        return self.__columns

    @property
    def _constraints(self):
        """Cached list of constraint names defined on the table.

        Combines ``information_schema.table_constraints`` with exclusion
        constraints (``contype = 'x'``), which the standard view does not
        list.
        """
        if self.__constraints is None:
            cursor = Transaction().connection.cursor()
            # fetch constraints for the table
            cursor.execute('SELECT constraint_name '
                'FROM information_schema.table_constraints '
                'WHERE table_name = %s AND table_schema = %s',
                (self.table_name, self.table_schema))
            self.__constraints = [c for c, in cursor]

            # add nonstandard exclude constraint
            cursor.execute('SELECT c.conname '
                'FROM pg_namespace nc, '
                    'pg_namespace nr, '
                    'pg_constraint c, '
                    'pg_class r '
                'WHERE nc.oid = c.connamespace AND nr.oid = r.relnamespace '
                    'AND c.conrelid = r.oid '
                    "AND c.contype = 'x' "  # exclude type
                    "AND r.relkind IN ('r', 'p') "
                    'AND r.relname = %s AND nr.nspname = %s',
                    (self.table_name, self.table_schema))
            self.__constraints.extend((c for c, in cursor))
        return self.__constraints

    @property
    def _fk_deltypes(self):
        """Cached mapping of column name to its foreign key delete rule."""
        if self.__fk_deltypes is None:
            cursor = Transaction().connection.cursor()
            cursor.execute('SELECT k.column_name, r.delete_rule '
                'FROM information_schema.key_column_usage AS k '
                'JOIN information_schema.referential_constraints AS r '
                'ON r.constraint_schema = k.constraint_schema '
                'AND r.constraint_name = k.constraint_name '
                'WHERE k.table_name = %s AND k.table_schema = %s',
                (self.table_name, self.table_schema))
            self.__fk_deltypes = dict(cursor)
        return self.__fk_deltypes

    @property
    def _indexes(self):
        """Cached names of the table's non-primary, non-unique indexes."""
        if self.__indexes is None:
            cursor = Transaction().connection.cursor()
            # Fetch indexes defined for the table
            cursor.execute("SELECT cl2.relname "
                "FROM pg_index ind "
                    "JOIN pg_class cl on (cl.oid = ind.indrelid) "
                    "JOIN pg_namespace n ON (cl.relnamespace = n.oid) "
                    "JOIN pg_class cl2 on (cl2.oid = ind.indexrelid) "
                "WHERE cl.relname = %s AND n.nspname = %s "
                "AND NOT ind.indisprimary AND NOT ind.indisunique",
                (self.table_name, self.table_schema))
            self.__indexes = [name for name, in cursor]
        return self.__indexes

    def _update_definitions(self, columns=None, constraints=None):
        """Reset cached catalog data so it is re-read on next access.

        With no arguments both column and constraint caches are reset.
        Resetting constraints also resets the foreign-key delete rules.
        """
        if columns is None and constraints is None:
            columns = constraints = True
        if columns:
            self.__columns = None
        if constraints:
            self.__constraints = None
            self.__fk_deltypes = None

    def alter_size(self, column_name, column_type):
        """Change *column_name* to *column_type* (e.g. a new varchar size)."""
        cursor = Transaction().connection.cursor()
        cursor.execute(
            SQL("ALTER TABLE {} ALTER COLUMN {} TYPE {}").format(
                Identifier(self.table_name),
                Identifier(column_name),
                SQL(column_type)))
        self._update_definitions(columns=True)

    def alter_type(self, column_name, column_type):
        """Change the type of *column_name* to *column_type*."""
        cursor = Transaction().connection.cursor()
        cursor.execute(SQL('ALTER TABLE {} ALTER {} TYPE {}').format(
                Identifier(self.table_name),
                Identifier(column_name),
                SQL(column_type)))
        self._update_definitions(columns=True)

    def column_is_type(self, column_name, type_, *, size=-1):
        """Return whether *column_name* has the database type of *type_*.

        For VARCHAR the size is compared too, unless *size* is negative.
        ``None`` means unbounded.
        """
        db_type = self._columns[column_name]['typname'].upper()

        database = Transaction().database
        base_type = database.sql_type(type_).base.upper()
        if base_type == 'VARCHAR' and (size is None or size >= 0):
            same_size = self._columns[column_name]['size'] == size
        else:
            same_size = True

        return base_type == db_type and same_size

    def db_default(self, column_name, value):
        """Set the database default of *column_name* to *value* if changed.

        The new value is compared with the cached ``column_default`` text.
        Booleans are compared as ``'true'``/``'false'``. Note that ``0`` and
        ``1`` also match ``[True, False]`` and are compared as ``'0'``/``'1'``.
        """
        if value in [True, False]:
            test = str(value).lower()
        else:
            test = value
        if self._columns[column_name]['default'] != test:
            cursor = Transaction().connection.cursor()
            cursor.execute(
                SQL(
                    'ALTER TABLE {} ALTER COLUMN {} SET DEFAULT %s').format(
                    Identifier(self.table_name),
                    Identifier(column_name)),
                (value,))
            # Like the other altering methods, refresh the cached columns so
            # later checks see the new default.
            self._update_definitions(columns=True)

    def add_column(self, column_name, sql_type, default=None, comment=''):
        """Add *column_name* of *sql_type*, or migrate it if it exists.

        For an existing column:

        - ``create_date``/``write_date`` are converted to ``timestamp``.
        - The comment is refreshed.
        - Widening type changes are applied, and known narrowing ones are
          skipped. Other type changes are only logged.
        - For varchar, the size is enlarged or removed, and shrinking is
          only logged.

        For a new column, *default* is a callable whose result fills every
        existing row.
        """
        cursor = Transaction().connection.cursor()
        database = Transaction().database

        column_type = database.sql_type(sql_type)
        match = VARCHAR_SIZE_RE.match(sql_type)
        field_size = int(match.group(1)) if match else None

        def add_comment():
            if comment and self.is_owner:
                cursor.execute(
                    SQL('COMMENT ON COLUMN {}.{} IS %s').format(
                        Identifier(self.table_name),
                        Identifier(column_name)),
                    (comment,))
        if self.column_exist(column_name):
            if (column_name in ('create_date', 'write_date')
                    and column_type[1].lower() != 'timestamp(6)'):
                # Migrate dates from timestamp(0) to timestamp
                cursor.execute(
                    SQL(
                        'ALTER TABLE {} ALTER COLUMN {} TYPE timestamp')
                    .format(
                        Identifier(self.table_name),
                        Identifier(column_name)))

            add_comment()
            base_type = column_type[0].lower()
            typname = self._columns[column_name]['typname']
            if base_type != typname:
                if (typname, base_type) in [
                        ('varchar', 'text'),
                        ('text', 'varchar'),
                        ('date', 'timestamp'),
                        ('int2', 'int4'),
                        ('int2', 'float4'),
                        ('int2', 'int8'),
                        ('int2', 'float8'),
                        ('int2', 'numeric'),
                        ('int4', 'int8'),
                        ('int4', 'float8'),
                        ('int4', 'numeric'),
                        ('int8', 'float8'),
                        ('int8', 'numeric'),
                        ('float4', 'numeric'),
                        ('float4', 'float8'),
                        ('float8', 'numeric'),
                        ]:
                    self.alter_type(column_name, base_type)
                elif (typname, base_type) in [
                        ('int8', 'int4'),
                        ('int8', 'int2'),
                        ('int4', 'int2'),
                        ('float8', 'float4'),
                        ]:
                    pass
                else:
                    logger.warning(
                        'Unable to migrate column %s on table %s '
                        'from %s to %s.',
                        column_name, self.table_name, typname, base_type)

            if base_type == typname == 'varchar':
                # Migrate size
                from_size = self._columns[column_name]['size']
                if field_size is None:
                    if from_size:
                        self.alter_size(column_name, base_type)
                elif from_size == field_size:
                    pass
                elif from_size and from_size < field_size:
                    self.alter_size(column_name, column_type[1])
                else:
                    logger.warning(
                        'Unable to migrate column %s on table %s '
                        'from varchar(%s) to varchar(%s).',
                        column_name, self.table_name,
                        from_size if from_size and from_size > 0 else "",
                        field_size)
            return

        column_type = column_type[1]
        cursor.execute(
            SQL('ALTER TABLE {} ADD COLUMN {} {}').format(
                Identifier(self.table_name),
                Identifier(column_name),
                SQL(column_type)))
        add_comment()

        if default:
            # check if table is non-empty (quoted with Identifier like the
            # other statements instead of string interpolation):
            cursor.execute(SQL('SELECT 1 FROM {} LIMIT 1').format(
                    Identifier(self.table_name)))
            if cursor.rowcount:
                # Populate column with default values:
                cursor.execute(
                    SQL('UPDATE {} SET {} = %s').format(
                        Identifier(self.table_name),
                        Identifier(column_name)),
                    (default(),))

        self._update_definitions(columns=True)

    def add_fk(self, columns, reference, ref_columns=None, on_delete=None):
        """Add a foreign key from *columns* to *reference*.

        *ref_columns* defaults to ``id`` and *on_delete* to ``SET NULL``. If
        a constraint with the same name exists, it is dropped and recreated
        when any column's delete rule differs from *on_delete*.
        """
        if on_delete is not None:
            on_delete = on_delete.upper()
        else:
            on_delete = 'SET NULL'
        if isinstance(columns, str):
            columns = [columns]

        cursor = Transaction().connection.cursor()
        if ref_columns:
            ref_columns_name = '_' + '_'.join(ref_columns)
        else:
            ref_columns_name = ''
        name = self.convert_name(
            self.table_name + '_' + '_'.join(columns)
            + ref_columns_name + '_fkey')
        if name in self._constraints:
            for column_name in columns:
                if self._fk_deltypes.get(column_name) != on_delete:
                    self.drop_fk(columns, ref_columns)
                    add = True
                    break
            else:
                add = False
        else:
            add = True

        if add:
            columns = SQL(', ').join(map(Identifier, columns))
            if not ref_columns:
                ref_columns = ['id']
            ref_columns = SQL(', ').join(map(Identifier, ref_columns))
            cursor.execute(
                SQL(
                    "ALTER TABLE {table} "
                    "ADD CONSTRAINT {constraint} "
                    "FOREIGN KEY ({columns}) "
                    "REFERENCES {reference} ({ref_columns}) "
                    "ON DELETE {action}"
                    )
                .format(
                    table=Identifier(self.table_name),
                    constraint=Identifier(name),
                    columns=columns,
                    reference=Identifier(reference),
                    ref_columns=ref_columns,
                    action=SQL(on_delete)))
        self._update_definitions(constraints=True)

    def drop_fk(self, columns, ref_columns=None, table=None):
        """Drop the foreign key that ``add_fk`` names for these columns."""
        if isinstance(columns, str):
            columns = [columns]
        if ref_columns:
            ref_columns_name = '_' + '_'.join(ref_columns)
        else:
            ref_columns_name = ''
        self.drop_constraint(
            '_'.join(columns) + ref_columns_name + '_fkey', table=table)

    def not_null_action(self, column_name, action='add'):
        """Add (``'add'``) or drop (``'remove'``) NOT NULL on *column_name*.

        NOT NULL is only added when no row holds NULL. Otherwise a warning
        is logged. Any other *action* raises ``Exception``.
        """
        if not self.column_exist(column_name):
            return

        with Transaction().connection.cursor() as cursor:
            if action == 'add':
                if self._columns[column_name]['notnull']:
                    return
                cursor.execute(SQL(
                        'SELECT id FROM {} WHERE {} IS NULL LIMIT 1').format(
                        Identifier(self.table_name),
                        Identifier(column_name)))
                if not cursor.rowcount:
                    cursor.execute(
                        SQL(
                            'ALTER TABLE {} ALTER COLUMN {} SET NOT NULL')
                        .format(
                            Identifier(self.table_name),
                            Identifier(column_name)))
                    self._update_definitions(columns=True)
                else:
                    logger.warning(
                        "Unable to set not null on column %s of table %s.\n"
                        "Try restarting one more time.\n"
                        "If that doesn't work update the records and restart "
                        "again.",
                        column_name, self.table_name)
            elif action == 'remove':
                if not self._columns[column_name]['notnull']:
                    return
                cursor.execute(
                    SQL('ALTER TABLE {} ALTER COLUMN {} DROP NOT NULL')
                    .format(
                        Identifier(self.table_name),
                        Identifier(column_name)))
                self._update_definitions(columns=True)
            else:
                raise Exception('Not null action not supported!')

    def add_constraint(self, ident, constraint):
        """Add *constraint* named ``<table>_<ident>`` unless it exists.

        ``str(constraint)`` is used as the SQL and ``constraint.params`` as
        its parameters.
        """
        ident = self.convert_name(self.table_name + "_" + ident)
        if ident in self._constraints:
            # This constrain already exist
            return
        cursor = Transaction().connection.cursor()
        cursor.execute(
            SQL('ALTER TABLE {} ADD CONSTRAINT {} {}').format(
                Identifier(self.table_name),
                Identifier(ident),
                SQL(str(constraint))),
            constraint.params)
        self._update_definitions(constraints=True)

    def drop_constraint(self, ident, table=None):
        """Drop constraint ``<table or table_name>_<ident>`` if it exists.

        *table* only affects the computed constraint name. The statement
        always targets this handler's table.
        """
        ident = self.convert_name((table or self.table_name) + "_" + ident)
        if ident not in self._constraints:
            return
        cursor = Transaction().connection.cursor()
        cursor.execute(
            SQL('ALTER TABLE {} DROP CONSTRAINT {}').format(
                Identifier(self.table_name), Identifier(ident)))
        self._update_definitions(constraints=True)

    def set_indexes(self, indexes, concurrently=False):
        """Create the indexes in *indexes* and drop stale managed ones.

        Each index is translated with ``index_translator_for`` and created as
        ``idx_<converted table_name_name>`` with ``IF NOT EXISTS``
        (optionally ``CONCURRENTLY``). An existing index of that name marked
        invalid (``indisvalid`` false) is dropped first so it gets rebuilt.
        Previously existing indexes named ``idx_*`` or ``*_index`` that were
        not produced here are dropped.
        """
        cursor = Transaction().connection.cursor()
        old = set(self._indexes)
        for index in indexes:
            translator = self.index_translator_for(index)
            if translator:
                name, query, params = translator.definition(index)
                name = '_'.join([self.table_name, name])
                name = 'idx_' + self.convert_name(name, reserved=len('idx_'))
                cursor.execute(
                    'SELECT idx.indisvalid '
                    'FROM pg_index idx '
                    'JOIN pg_class cls ON cls.oid = idx.indexrelid '
                    'WHERE cls.relname = %s',
                    (name,))
                if (idx_valid := cursor.fetchone()) and not idx_valid[0]:
                    cursor.execute(
                        SQL("DROP INDEX {}").format(Identifier(name)))
                cursor.execute(
                    SQL('CREATE INDEX {} IF NOT EXISTS {} ON {} USING {}')
                    .format(
                        SQL('CONCURRENTLY' if concurrently else ''),
                        Identifier(name),
                        Identifier(self.table_name),
                        query),
                    params)
                old.discard(name)
        for name in old:
            if name.startswith('idx_') or name.endswith('_index'):
                cursor.execute(SQL('DROP INDEX {}').format(Identifier(name)))
        self.__indexes = None

    def drop_column(self, column_name):
        """Drop *column_name* if it exists."""
        if not self.column_exist(column_name):
            return
        cursor = Transaction().connection.cursor()
        cursor.execute(SQL('ALTER TABLE {} DROP COLUMN {}').format(
                Identifier(self.table_name),
                Identifier(column_name)))
        self._update_definitions(columns=True)

    @classmethod
    def drop_table(cls, model, table, cascade=False):
        """Delete *model*'s ``ir_model_data`` rows and drop *table*.

        ``CASCADE`` is appended when *cascade* is true.
        """
        cursor = Transaction().connection.cursor()
        cursor.execute('DELETE FROM ir_model_data WHERE model = %s', (model,))

        query = 'DROP TABLE {}'
        if cascade:
            query = query + ' CASCADE'
        cursor.execute(SQL(query).format(Identifier(table)))


class IndexMixin:
    """Shared SQL generation for PostgreSQL index translators.

    Subclasses set ``_type`` to the access method and are registered in
    ``TableHandler.index_translators`` when defined.
    """
    _type = None

    def __init_subclass__(cls):
        TableHandler.index_translators.append(cls)

    @classmethod
    def definition(cls, index):
        """Return ``(name, query, params)`` for creating *index*.

        *query* is the part following ``USING``: access method, indexed
        expressions, optional ``INCLUDE`` and ``WHERE`` clauses.
        """
        expr_template = SQL('{expression} {collate} {opclass} {order}')
        indexed_expressions = cls._get_indexed_expressions(index)
        expressions = []
        params = []
        for expression, usage in indexed_expressions:
            expressions.append(expr_template.format(
                    **cls._get_expression_variables(expression, usage)))
            params.extend(expression.params)

        include = SQL('')
        if index.options.get('include'):
            include = SQL('INCLUDE ({columns})').format(
                columns=SQL(',').join(map(
                        lambda c: SQL(str(c)),
                        index.options.get('include'))))

        where = SQL('')
        if index.options.get('where'):
            where = SQL('WHERE {where}').format(
                where=SQL(str(index.options['where'])))
            params.extend(index.options['where'].params)

        query = SQL('{type} ({expressions}) {include} {where}').format(
            type=SQL(cls._type),
            expressions=SQL(',').join(expressions),
            include=include,
            where=where)
        name = cls._get_name(query, params)
        return name, query, params

    @classmethod
    def _get_indexed_expressions(cls, index):
        """Return the ``(expression, usage)`` pairs to index."""
        return index.expressions

    @classmethod
    def _get_expression_variables(cls, expression, usage):
        """Return the SQL fragments that fill the per-expression template."""
        variables = {
            'expression': SQL(str(expression)),
            'collate': SQL(''),
            'opclass': SQL(''),
            'order': SQL(''),
            }
        if usage.options.get('collation'):
            collation = usage.options['collation']
            # SQL.format only accepts Composable arguments: a plain string
            # raised TypeError. Quote collation names as identifiers.
            if not isinstance(collation, Composable):
                collation = Identifier(collation)
            variables['collate'] = SQL('COLLATE {}').format(collation)
        if usage.options.get('order'):
            order = usage.options['order'].upper()
            variables['order'] = SQL(order)
        return variables


class HashTranslator(IndexMixin, IndexTranslatorInterface):
    """HASH index for a single equality expression without INCLUDE."""
    _type = 'HASH'

    @classmethod
    def score(cls, index):
        if (len(index.expressions) > 1
                or index.expressions[0][1].__class__.__name__ != 'Equality'):
            return 0
        if index.options.get('include'):
            return 0
        return 100

    @classmethod
    def _get_indexed_expressions(cls, index):
        return [
            (e, u) for e, u in index.expressions
            if u.__class__.__name__ == 'Equality'][:1]


class BTreeTranslator(IndexMixin, IndexTranslatorInterface):
    """BTREE index for equality, range and prefix-similarity usages."""
    _type = 'BTREE'

    @classmethod
    def score(cls, index):
        score = 0
        for _, usage in index.expressions:
            if usage.__class__.__name__ == 'Range':
                score += 100
            elif usage.__class__.__name__ == 'Equality':
                score += 50
            elif usage.__class__.__name__ == 'Similarity':
                score += 20
                if usage.options.get('begin'):
                    score += 100
        return score

    @classmethod
    def _get_indexed_expressions(cls, index):
        return [
            (e, u) for e, u in index.expressions
            if u.__class__.__name__ in {'Equality', 'Range', 'Similarity'}]

    @classmethod
    def _get_expression_variables(cls, expression, usage):
        params = super()._get_expression_variables(expression, usage)
        if (usage.__class__.__name__ == 'Similarity'
                and not usage.options.get('collation')):
            # text_pattern_ops and varchar_pattern_ops are the same
            params['opclass'] = SQL('varchar_pattern_ops')
        return params


class TrigramTranslator(IndexMixin, IndexTranslatorInterface):
    """GIN index using the pg_trgm and/or btree_gin extensions."""
    _type = 'GIN'

    @classmethod
    def score(cls, index):
        database = Transaction().database
        has_btree_gin = database.has_extension('btree_gin')
        has_trigram = database.has_extension('pg_trgm')
        if not has_btree_gin and not has_trigram:
            return 0

        score = 0
        for _, usage in index.expressions:
            if usage.__class__.__name__ == 'Similarity':
                if has_trigram:
                    score += 100
                else:
                    score += 50
            elif has_btree_gin:
                if usage.__class__.__name__ == 'Range':
                    score += 90
                elif usage.__class__.__name__ == 'Equality':
                    score += 40
            else:
                return 0
        return score

    @classmethod
    def _get_indexed_expressions(cls, index):
        database = Transaction().database
        has_btree_gin = database.has_extension('btree_gin')
        has_trigram = database.has_extension('pg_trgm')

        def is_supported(usage):
            # Keep only usages the installed extensions can index.
            if usage.__class__.__name__ == 'Similarity':
                return has_trigram
            elif usage.__class__.__name__ in {'Range', 'Equality'}:
                return has_btree_gin
            else:
                return False
        return [(e, u) for e, u in index.expressions if is_supported(u)]

    @classmethod
    def _get_expression_variables(cls, expression, usage):
        params = super()._get_expression_variables(expression, usage)
        if usage.__class__.__name__ == 'Similarity':
            params['opclass'] = SQL('gin_trgm_ops')
        return params