Files
tradon/tools/misc.py
2025-12-26 13:11:43 +00:00

325 lines
9.4 KiB
Python
Executable File

# -*- coding: utf-8 -*-
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
"""
Miscellaneous tools used by Tryton
"""
import importlib
import io
import os
import re
import types
import unicodedata
import warnings
from array import array
from collections.abc import Iterable, Sized
from functools import wraps
from itertools import chain, islice, tee, zip_longest
from sql import Literal
from sql.conditionals import Case
from sql.operators import Or
from trytond.const import MODULES_GROUP, OPERATORS
try:
from backports.entry_points_selectable import entry_points as _entry_points
except ImportError:
from importlib.metadata import entry_points as _entry_points
# Module-level cache filled by the first call to entry_points().
_ENTRY_POINTS = None


def entry_points():
    """Return the entry points, loading them on first use only.

    The result of the first lookup is kept in the module-level
    ``_ENTRY_POINTS`` and returned as is by the following calls.
    """
    global _ENTRY_POINTS
    if _ENTRY_POINTS is not None:
        return _ENTRY_POINTS
    _ENTRY_POINTS = _entry_points()
    return _ENTRY_POINTS
def import_module(name):
    """Import and return the module registered under ``name``.

    The module is searched among the entry points of the ``MODULES_GROUP``
    group. When the lookup does not yield exactly one entry point, the
    module is imported as ``<MODULES_GROUP>.<name>`` instead.
    """
    try:
        # The single-target unpacking raises ValueError unless there is
        # exactly one matching entry point.
        (entry_point,) = entry_points().select(
            group=MODULES_GROUP, name=name)
    except ValueError:
        module = importlib.import_module(f'{MODULES_GROUP}.{name}')
    else:
        module = entry_point.load()
    return module
def file_open(name, mode="r", subdir='modules', encoding=None):
    """Open a file from the root directory, using subdir folder.

    The location is computed by find_path() with its existence test
    disabled, so a missing file is reported by the open call itself.
    """
    return io.open(
        find_path(name, subdir, _test=None), mode, encoding=encoding)
def find_path(name, subdir='modules', _test=os.path.isfile):
    """Return path from the root directory, using subdir folder.

    :param name: relative name of the file or directory to locate.
    :param subdir: folder below the root directory; with ``'modules'`` the
        first component of ``name`` is treated as a module name.
    :param _test: predicate applied to the resulting path; a false value
        disables the check.
    :raises IOError: when the resulting path would leave its base directory.
    :raises FileNotFoundError: when ``_test`` rejects the resulting path.
    """
    # The root is the directory two levels above this file.
    root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    def secure_join(root, *paths):
        "Join paths and ensure the result is still below root"
        joined = os.path.normpath(os.path.join(root, *paths))
        # Joining with '' appends the separator so that a sibling directory
        # sharing the same prefix is not accepted.
        if not joined.startswith(os.path.join(root, '')):
            raise IOError("Permission denied: %s" % name)
        return joined

    if not subdir:
        path = secure_join(root_path, name)
    elif subdir != 'modules':
        path = secure_join(root_path, subdir, name)
    else:
        # Split the module name from the path inside that module.
        module_name, _, module_path = name.partition(os.sep)
        if module_name in {'ir', 'res', 'tests'}:
            # These names are looked up directly below the root.
            path = secure_join(root_path, module_name, module_path)
        else:
            try:
                module = import_module(module_name)
            except ModuleNotFoundError:
                path = secure_join(root_path, subdir, name)
            else:
                path = os.path.dirname(module.__file__)
                if module_path:
                    path = secure_join(path, module_path)

    if _test and not _test(path):
        raise FileNotFoundError("No such file or directory: %r" % name)
    return path
def find_dir(name, subdir='modules'):
    """Return directory from the root directory, using subdir folder.

    Same lookup as find_path() but the result must be a directory.
    """
    return find_path(name, subdir, _test=os.path.isdir)
def get_smtp_server():
    """
    Instantiate, configure and return a SMTP or SMTP_SSL instance from
    smtplib.

    Deprecated wrapper: it emits a DeprecationWarning and delegates to
    ``trytond.sendmail.get_smtp_server``.

    :return: A SMTP instance. The quit() method must be called when all
        the calls to sendmail() have been made.
    """
    # Imported locally; inside this function the name refers to the
    # trytond.sendmail implementation, not to this wrapper.
    from trytond.sendmail import get_smtp_server
    # stacklevel=2 attributes the warning to the caller of this wrapper
    # instead of to this module.
    warnings.warn(
        'get_smtp_server is deprecated use trytond.sendmail',
        DeprecationWarning, stacklevel=2)
    return get_smtp_server()
def reduce_ids(field, ids):
    """Return a small SQL expression for the list of ids and the sql column.

    Runs of consecutive ids spanning at least 6 values become a
    ``field >= first AND field <= last`` test; the ids of shorter runs are
    gathered in a single ``IN`` test. All tests are combined with ``Or``.
    Without any id the expression is ``Literal(False)``.
    """
    if __debug__:
        def strict_int(value):
            assert not isinstance(value, float) or value.is_integer(), \
                "ids must be integer"
            return int(value)
    else:
        strict_int = int

    ordered = sorted(map(strict_int, ids))
    if not ordered:
        return Literal(False)

    sql = Or()
    scattered = array('l')

    def flush(first, last):
        "Register the finished run first..last as a range or as plain ids"
        if last - first < 5:
            scattered.extend(range(first, last + 1))
        else:
            sql.append((field >= first) & (field <= last))

    start = prev = ordered[0]
    for current in ordered[1:]:
        if current == prev:
            # Duplicated id
            continue
        if current != prev + 1:
            # Gap: the current run is finished and a new one starts
            flush(start, prev)
            start = current
        prev = current
    flush(start, prev)

    if scattered:
        sql.append(field.in_(scattered))
    return sql
def reduce_domain(domain):
    """Reduce domain.

    Return a flat list starting with the boolean operator (``'AND'`` when
    the domain does not start with a string) in which every sub-domain
    using the same operator as its parent is merged into the parent. An
    empty domain gives an empty list.
    """
    if not domain:
        return []
    if isinstance(domain[0], str):
        operator, items = domain[0], domain[1:]
    else:
        operator, items = 'AND', domain

    reduced = [operator]
    for item in items:
        is_clause = isinstance(item, tuple) or (
            isinstance(item, list)
            and len(item) > 2
            and item[1] in OPERATORS)
        if is_clause or not (isinstance(item, list) and item):
            # Clauses and anything that is not a non-empty list are kept
            # untouched.
            reduced.append(item)
            continue
        nested = reduce_domain(item)
        if nested[0] == operator:
            # Same operator: merge the content into the current level
            reduced.extend(nested[1:])
        else:
            reduced.append(nested)
    return reduced
def grouped_slice(records, count=None):
    """Yield successive slices of at most ``count`` records.

    Without ``count`` the ``IN_MAX`` value of the database of the current
    transaction is used; the size is never lower than 1. Records that have
    no length are first converted into a list. Each slice is an ``islice``
    over the records.
    """
    from trytond.transaction import Transaction
    if count is None:
        count = Transaction().database.IN_MAX
    step = max(1, count)
    if not isinstance(records, Sized):
        records = list(records)
    total = len(records)
    start = 0
    while start < total:
        yield islice(records, start, start + step)
        start += step
def pairwise_longest(iterable):
    """Return an iterator over each item paired with the following one.

    The last item is paired with ``None``.
    """
    current, following = tee(iterable)
    # Shift the second iterator one item ahead of the first.
    next(following, None)
    return zip_longest(current, following)
def is_instance_method(cls, method):
    """Tell whether ``method`` is defined on ``cls`` as a plain function.

    The classes of the MRO are searched in order and the first one that
    holds a non-None value under that name decides the result. ``None`` is
    returned when no class holds such a value.
    """
    for klass in cls.__mro__:
        attribute = vars(klass).get(method)
        if attribute is None:
            continue
        return isinstance(attribute, types.FunctionType)
    return None
def resolve(name):
    """Resolve a dotted name to a global object.

    The first component is imported as a module; each following component
    is read as an attribute of the previous object and, when that
    attribute is missing, imported as a module with the dotted path built
    so far.
    """
    module_name, *attributes = name.split('.')
    found = importlib.import_module(module_name)
    path = module_name
    for attribute in attributes:
        path = f'{path}.{attribute}'
        try:
            found = getattr(found, attribute)
        except AttributeError:
            found = importlib.import_module(path)
    return found
def strip_wildcard(string, wildcard='%', escape='\\'):
    """Strip starting and ending wildcard from string.

    The starting wildcard is removed first, then the ending one with
    ``escape`` taken into account.
    """
    return rstrip_wildcard(
        lstrip_wildcard(string, wildcard), wildcard, escape)
def lstrip_wildcard(string, wildcard='%'):
    """Strip starting wildcard from string.

    Only one occurrence of ``wildcard`` is removed. A false ``string``
    (``None`` or empty) is returned unchanged.
    """
    if string and string.startswith(wildcard):
        # Slice by the wildcard length so that a wildcard longer than one
        # character is removed completely.
        string = string[len(wildcard):]
    return string
def rstrip_wildcard(string, wildcard='%', escape='\\'):
    """Strip ending wildcard from string.

    Only one occurrence of ``wildcard`` is removed and only when it is not
    escaped, that is when it is preceded by an even number (possibly zero)
    of ``escape``: in ``a\\%`` the wildcard is escaped and kept, in
    ``a\\\\%`` the escape itself is escaped so the wildcard is stripped.
    An empty ``escape`` means that the wildcard can not be escaped. A false
    ``string`` (``None`` or empty) is returned unchanged.
    """
    if not string or not wildcard or not string.endswith(wildcard):
        return string
    stripped = string[:-len(wildcard)]
    if escape:
        # Count the escapes directly before the wildcard; they pair up to
        # escape each other and only an unpaired one escapes the wildcard.
        escapes = 0
        remaining = stripped
        while remaining.endswith(escape):
            remaining = remaining[:-len(escape)]
            escapes += 1
        if escapes % 2:
            return string
    return stripped
def escape_wildcard(string, wildcards='%_', escape='\\'):
    """Return string with each wildcard and escape character escaped."""
    # The escape character comes first so that the escapes added for the
    # wildcards are not escaped a second time.
    specials = escape + wildcards
    for char in specials:
        string = string.replace(char, escape + char)
    return string
def unescape_wildcard(string, wildcards='%_', escape='\\'):
    """Return string with the escape removed before wildcards and escapes."""
    # The escape character is processed last, after the wildcards.
    specials = wildcards + escape
    for char in specials:
        string = string.replace(escape + char, char)
    return string
def is_full_text(value, escape='\\'):
    """Tell whether value only has wildcards at its ends.

    The result is False when an unescaped ``%`` or ``_`` remains once the
    starting and ending ``%`` are stripped. Otherwise it is True when
    ``value`` either both starts and ends with ``%`` or does neither.
    """
    inner = strip_wildcard(value, escape=escape)
    # Drop the escaped wildcards so that only real ones are detected.
    for wildcard in '%_':
        inner = inner.replace(escape + wildcard, '')
    if any(wildcard in inner for wildcard in '%_'):
        return False
    return value.startswith('%') == value.endswith('%')
def likify(string, escape='\\'):
    """Return string as a pattern matching anything that contains it.

    A false string gives ``'%'``. A string that already contains an
    unescaped ``%`` or ``_`` is returned unchanged; any other string is
    surrounded with ``%``.
    """
    if not string:
        return '%'
    # Drop the escaped wildcards so that only real ones are detected.
    remaining = string
    for wildcard in '%_':
        remaining = remaining.replace(escape + wildcard, '')
    has_wildcard = '%' in remaining or '_' in remaining
    return string if has_wildcard else '%' + string + '%'
# Anything that is neither a word character, a white space nor a hyphen
_slugify_strip_re = re.compile(r'[^\w\s-]')
# Any run of hyphens and white spaces
_slugify_hyphenate_re = re.compile(r'[-\s]+')


def slugify(value, hyphenate='-'):
    """Return value converted into a slug.

    The value is converted to a string if needed and NFKD normalized.
    Characters other than word characters, white spaces and hyphens are
    then removed, surrounding white spaces are stripped and each run of
    hyphens and white spaces is replaced by ``hyphenate``.
    """
    text = value if isinstance(value, str) else str(value)
    decomposed = unicodedata.normalize('NFKD', text)
    cleaned = _slugify_strip_re.sub('', decomposed).strip()
    return _slugify_hyphenate_re.sub(hyphenate, cleaned)
def sortable_values(func):
    """Decorator that makes list of values sortable.

    Each value returned by ``func`` is replaced by a tuple of
    ``(item is None, item)`` pairs, one per item of the value; a value
    that is not iterable is handled as a single item.
    """
    def as_pairs(value):
        "Pair each item of value with its None flag"
        items = value if isinstance(value, Iterable) else [value]
        return tuple((item is None, item) for item in items)

    @wraps(func)
    def wrapper(*args, **kwargs):
        values = list(func(*args, **kwargs))
        return [as_pairs(value) for value in values]
    return wrapper
def sql_pairing(x, y):
    """Return SQL expression to pair x and y

    Pairing function from http://szudzik.com/ElegantPairing.pdf

    The expression is ``y * y + x`` when ``x < y`` and ``x * x + x + y``
    otherwise.
    """
    when_smaller = (y * y) + x
    otherwise = (x * x) + x + y
    return Case((x < y, when_smaller), else_=otherwise)
def firstline(text):
    """Return the first non-empty line of text.

    A line containing only white spaces counts as empty. The line is
    returned as is, and an empty string is returned when there is none.
    """
    for line in text.splitlines():
        if line.strip():
            return line
    return ''
def remove_forbidden_chars(value):
    """Return value with the forbidden characters replaced by spaces.

    The characters are those listed in ``Char.forbidden_chars``. The
    result is stripped of its surrounding white spaces; ``None`` is
    returned unchanged.
    """
    from trytond.model.fields import Char
    if value is None:
        return value
    cleaned = value
    for forbidden in Char.forbidden_chars:
        if forbidden in cleaned:
            cleaned = cleaned.replace(forbidden, ' ')
    return cleaned.strip()