732 lines
29 KiB
Python
Executable File
732 lines
29 KiB
Python
Executable File
# This file is part of Tryton. The COPYRIGHT file at the top level of
|
|
# this repository contains the full copyright notices and license terms.
|
|
import logging
|
|
import re
|
|
|
|
from psycopg2.sql import SQL, Identifier
|
|
|
|
from trytond.backend.table import (
|
|
IndexTranslatorInterface, TableHandlerInterface)
|
|
from trytond.transaction import Transaction
|
|
|
|
# Public API of this module.
__all__ = ['TableHandler']

# Module-level logger, used for the migration warnings emitted below.
logger = logging.getLogger(__name__)
# Extracts the declared length from a type string starting with e.g.
# "VARCHAR(64)" (group 1 is "64"); used by TableHandler.add_column.
VARCHAR_SIZE_RE = re.compile(r'VARCHAR\(([0-9]+)\)')
|
|
|
|
|
|
class TableHandler(TableHandlerInterface):
    """PostgreSQL implementation of ``TableHandlerInterface``.

    Creates and migrates a single table: its columns, constraints, foreign
    keys and indexes.  Catalog lookups (``_columns``, ``_constraints``,
    ``_fk_deltypes`` and ``_indexes``) are cached on the instance and reset
    after the statements that change them.
    """
    # NOTE(review): presumably PostgreSQL's NAMEDATALEN, used by
    # convert_name in the base class to shorten identifiers — confirm.
    namedatalen = 64
    # Translator classes, registered by IndexMixin.__init_subclass__.
    index_translators = []

    def _init(self, model, history=False):
        """Create the table if needed and set up its identifier column(s).

        A regular table gets an ``id`` identity primary key; a history
        table gets a plain ``id`` column and an ``__id`` identity primary
        key.  Existing sequence-backed keys are migrated to identities.
        """
        super()._init(model, history=history)
        self.__columns = None
        self.__constraints = None
        self.__fk_deltypes = None
        self.__indexes = None

        transaction = Transaction()
        cursor = transaction.connection.cursor()

        # Create new table if necessary
        if not self.table_exist(self.table_name):
            cursor.execute(SQL('CREATE TABLE {} ()').format(
                    Identifier(self.table_name)))
        self.table_schema = transaction.database.get_table_schema(
            transaction.connection, self.table_name)

        cursor.execute('SELECT tableowner = current_user FROM pg_tables '
            'WHERE tablename = %s AND schemaname = %s',
            (self.table_name, self.table_schema))
        self.is_owner, = cursor.fetchone()

        # Only the owner of a table may comment on it
        if model.__doc__ and self.is_owner:
            cursor.execute(SQL('COMMENT ON TABLE {} IS %s').format(
                    Identifier(self.table_name)),
                (model.__doc__,))

        def migrate_to_identity(table, column):
            """Turn the sequence-backed *column* into an identity column.

            The increment, bounds, cache and next value of the former
            ``<table>_<column>_seq`` sequence are copied onto the sequence
            of the new identity after the old one is dropped.
            """
            previous_seq_name = f"{table}_{column}_seq"
            # format('%I', ...) quotes the name so that it is not
            # case-folded when converted to a regclass.
            cursor.execute(
                "SELECT nextval(format(%s, %s))", ('%I', previous_seq_name,))
            next_val, = cursor.fetchone()
            cursor.execute(
                "SELECT seqincrement, seqmax, seqmin, seqcache "
                "FROM pg_sequence "
                "WHERE seqrelid = format(%s, %s)::regclass",
                ('%I', previous_seq_name))
            increment, s_max, s_min, cache = cursor.fetchone()
            # Previously created sequences were setting bigint values for
            # those; the identity column mimics the type of the underlying
            # column, so clamp the bounds to the integer range.
            if self._columns[column]['typname'] != 'int8':
                s_max = min(s_max, 2 ** 31 - 1)
                s_min = max(s_min, -(2 ** 31))
            cursor.execute(
                SQL("ALTER TABLE {} ALTER COLUMN {} DROP DEFAULT").format(
                    Identifier(table), Identifier(column)))
            cursor.execute(
                SQL("DROP SEQUENCE {}").format(
                    Identifier(previous_seq_name)))
            cursor.execute(
                SQL("ALTER TABLE {} ALTER COLUMN {} "
                    "ADD GENERATED BY DEFAULT AS IDENTITY").format(
                    Identifier(table), Identifier(column)))
            cursor.execute(
                "SELECT pg_get_serial_sequence(format(%s, %s), %s)",
                ('%I', table, column))
            serial_seq_name, = cursor.fetchone()
            # Per the PostgreSQL docs, pg_get_serial_sequence returns a
            # name that is already schema-qualified and quoted as needed.
            cursor.execute(
                (f"ALTER SEQUENCE {serial_seq_name} INCREMENT BY %s "
                    "MINVALUE %s MAXVALUE %s RESTART WITH %s CACHE %s"),
                (increment, s_min, s_max, next_val, cache))

        update_definitions = False
        if 'id' not in self._columns:
            update_definitions = True
            if not self.history:
                cursor.execute(
                    SQL(
                        "ALTER TABLE {} ADD COLUMN id INTEGER "
                        "GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY").format(
                        Identifier(self.table_name)))
            else:
                cursor.execute(
                    SQL('ALTER TABLE {} ADD COLUMN id INTEGER')
                    .format(Identifier(self.table_name)))
        elif not self.history and not self._columns['id']['identity']:
            update_definitions = True
            migrate_to_identity(self.table_name, 'id')

        if self.history and '__id' not in self._columns:
            update_definitions = True
            cursor.execute(
                SQL(
                    "ALTER TABLE {} ADD COLUMN __id INTEGER "
                    "GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY").format(
                    Identifier(self.table_name)))
        elif self.history and not self._columns['__id']['identity']:
            update_definitions = True
            cursor.execute(
                SQL("ALTER TABLE {} ALTER COLUMN id DROP DEFAULT").format(
                    Identifier(self.table_name)))
            migrate_to_identity(self.table_name, '__id')

        if update_definitions:
            self._update_definitions(columns=True)

    @classmethod
    def table_exist(cls, table_name):
        """Return True when the database reports a schema for *table_name*."""
        transaction = Transaction()
        return bool(transaction.database.get_table_schema(
                transaction.connection, table_name))

    @classmethod
    def table_rename(cls, old_name, new_name):
        """Rename table *old_name*, its id sequence and its history table.

        Each rename happens only when the source exists and the target
        does not.
        """
        transaction = Transaction()
        cursor = transaction.connection.cursor()
        # Rename table
        if (cls.table_exist(old_name)
                and not cls.table_exist(new_name)):
            cursor.execute(SQL('ALTER TABLE {} RENAME TO {}').format(
                    Identifier(old_name), Identifier(new_name)))
        # Migrate from 6.6: rename old sequence
        old_sequence = old_name + '_id_seq'
        new_sequence = new_name + '_id_seq'
        transaction.database.sequence_rename(
            transaction.connection, old_sequence, new_sequence)
        # Rename history table
        old_history = old_name + "__history"
        new_history = new_name + "__history"
        if (cls.table_exist(old_history)
                and not cls.table_exist(new_history)):
            # Quote through Identifier like every other statement instead
            # of interpolating the names into the query text.
            cursor.execute(SQL('ALTER TABLE {} RENAME TO {}').format(
                    Identifier(old_history), Identifier(new_history)))

    def column_exist(self, column_name):
        """Return True when *column_name* is a column of the table."""
        return column_name in self._columns

    def column_rename(self, old_name, new_name):
        """Rename column *old_name* to *new_name*.

        Does nothing when *old_name* is missing and logs a warning when
        *new_name* already exists.
        """
        cursor = Transaction().connection.cursor()
        if self.column_exist(old_name):
            if not self.column_exist(new_name):
                cursor.execute(SQL(
                        'ALTER TABLE {} RENAME COLUMN {} TO {}').format(
                        Identifier(self.table_name),
                        Identifier(old_name),
                        Identifier(new_name)))
                self._update_definitions(columns=True)
            else:
                logger.warning(
                    'Unable to rename column %s on table %s to %s.',
                    old_name, self.table_name, new_name)

    @property
    def _columns(self):
        """Column definitions keyed by column name (cached).

        Each value holds ``typname`` (udt_name), ``notnull``, ``size``
        (character_maximum_length), ``default`` (column_default text) and
        ``identity``, read from information_schema.columns.
        """
        if self.__columns is None:
            cursor = Transaction().connection.cursor()
            self.__columns = {}
            # Fetch columns definitions from the table
            cursor.execute('SELECT '
                'column_name, udt_name, is_nullable, '
                'character_maximum_length, '
                'column_default, is_identity '
                'FROM information_schema.columns '
                'WHERE table_name = %s AND table_schema = %s',
                (self.table_name, self.table_schema))
            for column, typname, nullable, size, default, identity in cursor:
                self.__columns[column] = {
                    'typname': typname,
                    'notnull': nullable == 'NO',
                    'size': size,
                    'default': default,
                    'identity': identity != 'NO',
                    }
        return self.__columns

    @property
    def _constraints(self):
        """Names of the table's constraints, including EXCLUDE ones (cached)."""
        if self.__constraints is None:
            cursor = Transaction().connection.cursor()
            # fetch constraints for the table
            cursor.execute('SELECT constraint_name '
                'FROM information_schema.table_constraints '
                'WHERE table_name = %s AND table_schema = %s',
                (self.table_name, self.table_schema))
            self.__constraints = [c for c, in cursor]

            # add nonstandard exclude constraint
            cursor.execute('SELECT c.conname '
                'FROM pg_namespace nc, '
                'pg_namespace nr, '
                'pg_constraint c, '
                'pg_class r '
                'WHERE nc.oid = c.connamespace AND nr.oid = r.relnamespace '
                'AND c.conrelid = r.oid '
                "AND c.contype = 'x' "  # exclude type
                "AND r.relkind IN ('r', 'p') "
                'AND r.relname = %s AND nr.nspname = %s',
                (self.table_name, self.table_schema))
            self.__constraints.extend(c for c, in cursor)
        return self.__constraints

    @property
    def _fk_deltypes(self):
        """Map of column name to the delete rule of its foreign key (cached)."""
        if self.__fk_deltypes is None:
            cursor = Transaction().connection.cursor()
            cursor.execute('SELECT k.column_name, r.delete_rule '
                'FROM information_schema.key_column_usage AS k '
                'JOIN information_schema.referential_constraints AS r '
                'ON r.constraint_schema = k.constraint_schema '
                'AND r.constraint_name = k.constraint_name '
                'WHERE k.table_name = %s AND k.table_schema = %s',
                (self.table_name, self.table_schema))
            self.__fk_deltypes = dict(cursor)
        return self.__fk_deltypes

    @property
    def _indexes(self):
        """Names of the non-primary, non-unique indexes of the table (cached)."""
        if self.__indexes is None:
            cursor = Transaction().connection.cursor()
            # Fetch indexes defined for the table
            cursor.execute("SELECT cl2.relname "
                "FROM pg_index ind "
                "JOIN pg_class cl on (cl.oid = ind.indrelid) "
                "JOIN pg_namespace n ON (cl.relnamespace = n.oid) "
                "JOIN pg_class cl2 on (cl2.oid = ind.indexrelid) "
                "WHERE cl.relname = %s AND n.nspname = %s "
                "AND NOT ind.indisprimary AND NOT ind.indisunique",
                (self.table_name, self.table_schema))
            self.__indexes = [index_name for index_name, in cursor]
        return self.__indexes

    def _update_definitions(self, columns=None, constraints=None):
        """Reset cached catalog data; with no argument, reset both kinds.

        Resetting constraints also resets the foreign key delete rules.
        """
        if columns is None and constraints is None:
            columns = constraints = True
        if columns:
            self.__columns = None
        if constraints:
            self.__constraints = None
            self.__fk_deltypes = None

    def alter_size(self, column_name, column_type):
        """Change the type of *column_name* to *column_type* (a size change)."""
        cursor = Transaction().connection.cursor()
        cursor.execute(
            SQL("ALTER TABLE {} ALTER COLUMN {} TYPE {}").format(
                Identifier(self.table_name),
                Identifier(column_name),
                SQL(column_type)))
        self._update_definitions(columns=True)

    def alter_type(self, column_name, column_type):
        """Change the type of *column_name* to *column_type*."""
        cursor = Transaction().connection.cursor()
        cursor.execute(SQL('ALTER TABLE {} ALTER {} TYPE {}').format(
                Identifier(self.table_name),
                Identifier(column_name),
                SQL(column_type)))
        self._update_definitions(columns=True)

    def column_is_type(self, column_name, type_, *, size=-1):
        """Return whether *column_name* has the base type of *type_*.

        For VARCHAR, the size is compared too unless *size* is negative;
        ``size=None`` matches a column without a maximum length.
        """
        db_type = self._columns[column_name]['typname'].upper()

        database = Transaction().database
        base_type = database.sql_type(type_).base.upper()
        if base_type == 'VARCHAR' and (size is None or size >= 0):
            same_size = self._columns[column_name]['size'] == size
        else:
            same_size = True

        return base_type == db_type and same_size

    def db_default(self, column_name, value):
        """Set the database default of *column_name* to *value*.

        Nothing is executed when the cached default already matches.
        Values equal to True/False (so also 1/0) are compared by their
        lower-case text.
        """
        if value in [True, False]:
            test = str(value).lower()
        else:
            test = value
        if self._columns[column_name]['default'] != test:
            cursor = Transaction().connection.cursor()
            cursor.execute(
                SQL(
                    'ALTER TABLE {} ALTER COLUMN {} SET DEFAULT %s').format(
                    Identifier(self.table_name),
                    Identifier(column_name)),
                (value,))
            # The cached default is stale now, as after any other DDL
            self._update_definitions(columns=True)

    def add_column(self, column_name, sql_type, default=None, comment=''):
        """Add *column_name* or migrate the existing column to *sql_type*.

        For an existing column, the type is widened when listed as safe,
        a VARCHAR size is enlarged or removed, and any other change is only
        logged.  For a new column, the callable *default* fills the rows
        already present.  *comment* is set when the table is owned.
        """
        cursor = Transaction().connection.cursor()
        database = Transaction().database

        column_type = database.sql_type(sql_type)
        match = VARCHAR_SIZE_RE.match(sql_type)
        field_size = int(match.group(1)) if match else None

        def add_comment():
            if comment and self.is_owner:
                cursor.execute(
                    SQL('COMMENT ON COLUMN {}.{} IS %s').format(
                        Identifier(self.table_name),
                        Identifier(column_name)),
                    (comment,))

        if self.column_exist(column_name):
            if (column_name in ('create_date', 'write_date')
                    and column_type[1].lower() != 'timestamp(6)'):
                # Migrate dates from timestamp(0) to timestamp
                cursor.execute(
                    SQL(
                        'ALTER TABLE {} ALTER COLUMN {} TYPE timestamp')
                    .format(
                        Identifier(self.table_name),
                        Identifier(column_name)))

            add_comment()
            base_type = column_type[0].lower()
            typname = self._columns[column_name]['typname']
            if base_type != typname:
                if (typname, base_type) in [
                        ('varchar', 'text'),
                        ('text', 'varchar'),
                        ('date', 'timestamp'),
                        ('int2', 'int4'),
                        ('int2', 'float4'),
                        ('int2', 'int8'),
                        ('int2', 'float8'),
                        ('int2', 'numeric'),
                        ('int4', 'int8'),
                        ('int4', 'float8'),
                        ('int4', 'numeric'),
                        ('int8', 'float8'),
                        ('int8', 'numeric'),
                        ('float4', 'numeric'),
                        ('float4', 'float8'),
                        ('float8', 'numeric'),
                        ]:
                    self.alter_type(column_name, base_type)
                elif (typname, base_type) in [
                        ('int8', 'int4'),
                        ('int8', 'int2'),
                        ('int4', 'int2'),
                        ('float8', 'float4'),
                        ]:
                    # Narrowing is not applied: the wider type is kept
                    pass
                else:
                    logger.warning(
                        'Unable to migrate column %s on table %s '
                        'from %s to %s.',
                        column_name, self.table_name, typname, base_type)

            if base_type == typname == 'varchar':
                # Migrate size
                from_size = self._columns[column_name]['size']
                if field_size is None:
                    if from_size:
                        self.alter_size(column_name, base_type)
                elif from_size == field_size:
                    pass
                elif from_size and from_size < field_size:
                    self.alter_size(column_name, column_type[1])
                else:
                    logger.warning(
                        'Unable to migrate column %s on table %s '
                        'from varchar(%s) to varchar(%s).',
                        column_name, self.table_name,
                        from_size if from_size and from_size > 0 else "",
                        field_size)
            return

        column_type = column_type[1]
        cursor.execute(
            SQL('ALTER TABLE {} ADD COLUMN {} {}').format(
                Identifier(self.table_name),
                Identifier(column_name),
                SQL(column_type)))
        add_comment()

        if default:
            # check if table is non-empty (quoted like the other queries):
            cursor.execute(
                SQL('SELECT 1 FROM {} LIMIT 1').format(
                    Identifier(self.table_name)))
            if cursor.rowcount:
                # Populate column with default values:
                cursor.execute(
                    SQL('UPDATE {} SET {} = %s').format(
                        Identifier(self.table_name),
                        Identifier(column_name)),
                    (default(),))

        self._update_definitions(columns=True)

    def add_fk(self, columns, reference, ref_columns=None, on_delete=None):
        """Add a foreign key from *columns* to *reference*(*ref_columns*).

        *on_delete* defaults to ``SET NULL``; *ref_columns* defaults to
        ``id``.  An existing key with a different delete rule on any of the
        columns is dropped and recreated.
        """
        if on_delete is not None:
            on_delete = on_delete.upper()
        else:
            on_delete = 'SET NULL'
        if isinstance(columns, str):
            columns = [columns]

        cursor = Transaction().connection.cursor()
        if ref_columns:
            ref_columns_name = '_' + '_'.join(ref_columns)
        else:
            ref_columns_name = ''
        name = self.convert_name(
            self.table_name + '_' + '_'.join(columns)
            + ref_columns_name + '_fkey')
        if name in self._constraints:
            for column_name in columns:
                if self._fk_deltypes.get(column_name) != on_delete:
                    self.drop_fk(columns, ref_columns)
                    add = True
                    break
            else:
                add = False
        else:
            add = True
        if add:
            columns = SQL(', ').join(map(Identifier, columns))
            if not ref_columns:
                ref_columns = ['id']
            ref_columns = SQL(', ').join(map(Identifier, ref_columns))
            cursor.execute(
                SQL(
                    "ALTER TABLE {table} "
                    "ADD CONSTRAINT {constraint} "
                    "FOREIGN KEY ({columns}) "
                    "REFERENCES {reference} ({ref_columns}) "
                    "ON DELETE {action}"
                    )
                .format(
                    table=Identifier(self.table_name),
                    constraint=Identifier(name),
                    columns=columns,
                    reference=Identifier(reference),
                    ref_columns=ref_columns,
                    action=SQL(on_delete)))
            self._update_definitions(constraints=True)

    def drop_fk(self, columns, ref_columns=None, table=None):
        """Drop the foreign key named as add_fk would name it.

        *table* replaces the table name used to build the constraint name.
        """
        if isinstance(columns, str):
            columns = [columns]
        if ref_columns:
            ref_columns_name = '_' + '_'.join(ref_columns)
        else:
            ref_columns_name = ''
        self.drop_constraint(
            '_'.join(columns) + ref_columns_name + '_fkey', table=table)

    def not_null_action(self, column_name, action='add'):
        """Add or remove the NOT NULL constraint on *column_name*.

        With ``action='add'``, nothing is changed (a warning is logged)
        while the column still contains NULL values.

        Raises:
            Exception: when *action* is neither 'add' nor 'remove'.
        """
        if not self.column_exist(column_name):
            return

        with Transaction().connection.cursor() as cursor:
            if action == 'add':
                if self._columns[column_name]['notnull']:
                    return
                cursor.execute(SQL(
                        'SELECT id FROM {} WHERE {} IS NULL LIMIT 1').format(
                        Identifier(self.table_name),
                        Identifier(column_name)))
                if not cursor.rowcount:
                    cursor.execute(
                        SQL(
                            'ALTER TABLE {} ALTER COLUMN {} SET NOT NULL')
                        .format(
                            Identifier(self.table_name),
                            Identifier(column_name)))
                    self._update_definitions(columns=True)
                else:
                    logger.warning(
                        "Unable to set not null on column %s of table %s.\n"
                        "Try restarting one more time.\n"
                        "If that doesn't work update the records and restart "
                        "again.",
                        column_name, self.table_name)
            elif action == 'remove':
                if not self._columns[column_name]['notnull']:
                    return
                cursor.execute(
                    SQL('ALTER TABLE {} ALTER COLUMN {} DROP NOT NULL')
                    .format(
                        Identifier(self.table_name),
                        Identifier(column_name)))
                self._update_definitions(columns=True)
            else:
                raise Exception('Not null action not supported!')

    def add_constraint(self, ident, constraint):
        """Add *constraint* as ``<table>_<ident>`` unless it already exists."""
        ident = self.convert_name(self.table_name + "_" + ident)
        if ident in self._constraints:
            # This constrain already exist
            return
        cursor = Transaction().connection.cursor()
        cursor.execute(
            SQL('ALTER TABLE {} ADD CONSTRAINT {} {}').format(
                Identifier(self.table_name),
                Identifier(ident),
                SQL(str(constraint))),
            constraint.params)
        self._update_definitions(constraints=True)

    def drop_constraint(self, ident, table=None):
        """Drop constraint ``<table>_<ident>`` if it exists.

        *table* defaults to the handled table's name and only affects how
        the constraint name is built.
        """
        ident = self.convert_name((table or self.table_name) + "_" + ident)
        if ident not in self._constraints:
            return
        cursor = Transaction().connection.cursor()
        cursor.execute(
            SQL('ALTER TABLE {} DROP CONSTRAINT {}').format(
                Identifier(self.table_name), Identifier(ident)))
        self._update_definitions(constraints=True)

    def set_indexes(self, indexes, concurrently=False):
        """Create the requested *indexes* and drop obsolete managed ones.

        An existing index with the same name that is marked invalid is
        dropped and recreated.  Among the previously existing indexes that
        were not requested, only those named ``idx_*`` or ``*_index`` are
        dropped.
        """
        cursor = Transaction().connection.cursor()
        old = set(self._indexes)
        for index in indexes:
            translator = self.index_translator_for(index)
            if translator:
                name, query, params = translator.definition(index)
                name = '_'.join([self.table_name, name])
                name = 'idx_' + self.convert_name(name, reserved=len('idx_'))
                # Restrict to the table's schema (where its indexes live)
                # so a same-named index elsewhere is never matched.
                cursor.execute(
                    'SELECT idx.indisvalid '
                    'FROM pg_index idx '
                    'JOIN pg_class cls ON cls.oid = idx.indexrelid '
                    'JOIN pg_namespace n ON n.oid = cls.relnamespace '
                    'WHERE cls.relname = %s AND n.nspname = %s',
                    (name, self.table_schema))
                if (idx_valid := cursor.fetchone()) and not idx_valid[0]:
                    cursor.execute(
                        SQL("DROP INDEX {}").format(Identifier(name)))
                cursor.execute(
                    SQL('CREATE INDEX {} IF NOT EXISTS {} ON {} USING {}')
                    .format(
                        SQL('CONCURRENTLY' if concurrently else ''),
                        Identifier(name),
                        Identifier(self.table_name),
                        query),
                    params)
                old.discard(name)
        for name in old:
            if name.startswith('idx_') or name.endswith('_index'):
                cursor.execute(SQL('DROP INDEX {}').format(Identifier(name)))
        self.__indexes = None

    def drop_column(self, column_name):
        """Drop *column_name* if it exists."""
        if not self.column_exist(column_name):
            return
        cursor = Transaction().connection.cursor()
        cursor.execute(SQL('ALTER TABLE {} DROP COLUMN {}').format(
                Identifier(self.table_name),
                Identifier(column_name)))
        self._update_definitions(columns=True)

    @classmethod
    def drop_table(cls, model, table, cascade=False):
        """Delete the ir_model_data rows of *model* and drop *table*."""
        cursor = Transaction().connection.cursor()
        cursor.execute('DELETE FROM ir_model_data WHERE model = %s', (model,))

        query = 'DROP TABLE {}'
        if cascade:
            query = query + ' CASCADE'
        cursor.execute(SQL(query).format(Identifier(table)))
|
|
|
|
|
|
class IndexMixin:
    """Shared SQL generation for the PostgreSQL index translators.

    Subclasses set ``_type`` to the index method (e.g. ``'BTREE'``) and are
    registered in ``TableHandler.index_translators`` when defined.
    """

    _type = None

    def __init_subclass__(cls, **kwargs):
        # Chain to the next class in the MRO so cooperative
        # __init_subclass__ hooks of other bases still run.
        super().__init_subclass__(**kwargs)
        TableHandler.index_translators.append(cls)

    @classmethod
    def definition(cls, index):
        """Return ``(name, query, params)`` describing *index*.

        *query* is the part following ``USING`` in CREATE INDEX: the index
        method, the expression list and optional INCLUDE / WHERE clauses.
        """
        expr_template = SQL('{expression} {collate} {opclass} {order}')
        indexed_expressions = cls._get_indexed_expressions(index)
        expressions = []
        params = []
        for expression, usage in indexed_expressions:
            expressions.append(expr_template.format(
                    **cls._get_expression_variables(expression, usage)))
            params.extend(expression.params)

        include = SQL('')
        if index.options.get('include'):
            include = SQL('INCLUDE ({columns})').format(
                columns=SQL(',').join(map(
                        lambda c: SQL(str(c)),
                        index.options.get('include'))))

        where = SQL('')
        if index.options.get('where'):
            where = SQL('WHERE {where}').format(
                where=SQL(str(index.options['where'])))
            params.extend(index.options['where'].params)

        query = SQL('{type} ({expressions}) {include} {where}').format(
            type=SQL(cls._type),
            expressions=SQL(',').join(expressions),
            include=include,
            where=where)
        name = cls._get_name(query, params)
        return name, query, params

    @classmethod
    def _get_indexed_expressions(cls, index):
        """Return the ``(expression, usage)`` pairs this index covers."""
        return index.expressions

    @classmethod
    def _get_expression_variables(cls, expression, usage):
        """Return the SQL fragments that fill ``definition``'s template."""
        variables = {
            'expression': SQL(str(expression)),
            'collate': SQL(''),
            'opclass': SQL(''),
            'order': SQL(''),
            }
        collation = usage.options.get('collation')
        if collation:
            if isinstance(collation, str):
                # SQL.format only accepts Composable arguments, and an
                # unquoted name such as C would be case-folded by PostgreSQL.
                collation = Identifier(collation)
            variables['collate'] = SQL('COLLATE {}').format(collation)
        if usage.options.get('order'):
            order = usage.options['order'].upper()
            variables['order'] = SQL(order)
        return variables
|
|
|
|
|
|
class HashTranslator(IndexMixin, IndexTranslatorInterface):
    """Translate a single-expression equality index into a HASH index."""

    _type = 'HASH'

    @classmethod
    def score(cls, index):
        """Return 100 for one Equality expression without INCLUDE, else 0.

        Hash indexes are single-column, equality-only and cannot carry
        INCLUDE columns.  An empty expression list scores 0 instead of
        raising IndexError.
        """
        if (len(index.expressions) != 1
                or index.expressions[0][1].__class__.__name__ != 'Equality'):
            return 0
        if index.options.get('include'):
            return 0
        return 100

    @classmethod
    def _get_indexed_expressions(cls, index):
        """Return at most the first Equality ``(expression, usage)`` pair."""
        return [
            (e, u) for e, u in index.expressions
            if u.__class__.__name__ == 'Equality'][:1]
|
|
|
|
|
|
class BTreeTranslator(IndexMixin, IndexTranslatorInterface):
    """Translate an index into a B-tree index.

    Equality, range and similarity usages are indexable; prefix
    ("begin") similarity is favoured.
    """

    _type = 'BTREE'

    @classmethod
    def score(cls, index):
        """Sum a fixed weight per usage kind, plus a bonus for prefix search."""
        weights = {'Range': 100, 'Equality': 50, 'Similarity': 20}
        total = 0
        for _, usage in index.expressions:
            kind = usage.__class__.__name__
            total += weights.get(kind, 0)
            if kind == 'Similarity' and usage.options.get('begin'):
                total += 100
        return total

    @classmethod
    def _get_indexed_expressions(cls, index):
        """Keep only the pairs whose usage a B-tree can serve."""
        supported = {'Equality', 'Range', 'Similarity'}
        selected = []
        for expression, usage in index.expressions:
            if usage.__class__.__name__ in supported:
                selected.append((expression, usage))
        return selected

    @classmethod
    def _get_expression_variables(cls, expression, usage):
        """Use a pattern operator class for similarity without collation."""
        variables = super()._get_expression_variables(expression, usage)
        is_similarity = usage.__class__.__name__ == 'Similarity'
        if is_similarity and not usage.options.get('collation'):
            # text_pattern_ops and varchar_pattern_ops are the same
            variables['opclass'] = SQL('varchar_pattern_ops')
        return variables
|
|
|
|
|
|
class TrigramTranslator(IndexMixin, IndexTranslatorInterface):
    """Translate an index into a GIN index.

    Similarity usages rely on the pg_trgm extension; equality and range
    usages rely on the btree_gin extension.
    """

    _type = 'GIN'

    @classmethod
    def score(cls, index):
        """Score the index according to the installed GIN extensions."""
        database = Transaction().database
        has_btree_gin = database.has_extension('btree_gin')
        has_trigram = database.has_extension('pg_trgm')
        if not (has_btree_gin or has_trigram):
            return 0

        total = 0
        for _, usage in index.expressions:
            kind = usage.__class__.__name__
            if kind == 'Similarity':
                total += 100 if has_trigram else 50
            elif not has_btree_gin:
                # Any other usage needs btree_gin to be part of a GIN index
                return 0
            elif kind == 'Range':
                total += 90
            elif kind == 'Equality':
                total += 40
        return total

    @classmethod
    def _get_indexed_expressions(cls, index):
        """Keep the pairs whose usage an installed extension supports."""
        database = Transaction().database
        has_btree_gin = database.has_extension('btree_gin')
        has_trigram = database.has_extension('pg_trgm')
        supported = {
            'Similarity': has_trigram,
            'Range': has_btree_gin,
            'Equality': has_btree_gin,
            }
        return [
            (e, u) for e, u in index.expressions
            if supported.get(u.__class__.__name__, False)]

    @classmethod
    def _get_expression_variables(cls, expression, usage):
        """Use the trigram operator class for similarity usages."""
        variables = super()._get_expression_variables(expression, usage)
        if usage.__class__.__name__ == 'Similarity':
            variables['opclass'] = SQL('gin_trgm_ops')
        return variables
|