# -*- coding: ascii -*-
u"""
==========================================
Constraint inspection and representation
==========================================
Constraint inspection and representation.
:Copyright:
Copyright 2010 - 2023
Andr\xe9 Malo or his licensors, as applicable
:License:
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
__author__ = u"Andr\xe9 Malo"
import keyword as _keyword
import re as _re
import tokenize as _tokenize
from . import _util
class Constraint(object):
    """
    Reflected Constraint

    Calling ``Constraint(...)`` directly acts as a factory: it instantiates
    the module-level class whose name equals the class name of the given
    SQLAlchemy constraint, or returns ``None`` for ``CheckConstraint``
    objects (which are not represented).

    Attributes:
      constraint (SA Constraint):
        Constraint
    """

    # Symbol name and import template registered with the symbol table on
    # initialization; concrete subclasses override both.
    _SYMBOL, _IMPORT = None, None

    # Sort order of the SA constraint types, used by ``__cmp__``.
    _TYPE_ORDER = (
        'PrimaryKeyConstraint',
        'UniqueConstraint',
        'ForeignKeyConstraint',
        'CheckConstraint',
    )

    def __new__(cls, constraint, table, symbols, options=None):
        """Constraint factory"""
        if cls is Constraint:
            name = constraint.__class__.__name__
            if name == 'CheckConstraint':
                return None
            # Dispatch to the same-named class defined in this module. Since
            # the result is an instance of ``cls``, Python calls its
            # ``__init__`` once more with the same arguments (harmless).
            return globals()[name](
                constraint, table, symbols, options=options
            )
        return object.__new__(cls)

    def __init__(self, constraint, table, symbols, options=None):
        """
        Initialization

        Registers this class's import template under its symbol name in
        ``symbols.imports``.

        Parameters:
          constraint (SA Constraint):
            Constraint

          table (str):
            Table varname

          symbols (Symbols):
            Symbol table

          options (str):
            Options (``None`` if not set); interpreted by subclasses and
            taken into account when sorting
        """
        self.constraint = constraint
        self.table = table
        self._symbols = symbols
        self._symbols.imports[self._SYMBOL] = self._IMPORT
        self.options = options

    def copy(self):
        """Create shallow copy"""
        return self.__class__(
            self.constraint, self.table, self._symbols, self.options
        )

    @classmethod
    def _type_rank(cls, const):
        """
        Sort rank of an SA constraint by its type

        Parameters:
          const (SA Constraint):
            Constraint

        Returns:
          int: Position of the type in ``_TYPE_ORDER``, -1 if not listed
        """
        try:
            return cls._TYPE_ORDER.index(const.__class__.__name__)
        except ValueError:  # sequence.index signals "not found" this way
            return -1

    def __cmp__(self, other):
        """
        Compare

        Orders by constraint type, then constraints without options before
        those with options, then by constraint name, then by repr.
        NOTE(review): ``constraint.name`` may be ``None``; on Python 3,
        comparing ``None`` with a str raises TypeError if all earlier keys
        tie (depends on ``_util.cmp``) -- confirm this cannot happen.
        """
        return _util.cmp(
            (
                self._type_rank(self.constraint),
                self.options is not None,
                self.constraint.name,
                repr(self),
            ),
            (
                self._type_rank(other.constraint),
                other.options is not None,
                other.constraint.name,
                repr(other),
            ),
        )

    def __lt__(self, other, _cmp=__cmp__):
        """Check for '<'"""
        return _cmp(self, other) < 0

    def repr(self, symbol, args, keywords=(), short=False):
        """
        Base repr for all constraints

        ``name``, ``deferrable``, ``initially`` and the attributes listed in
        `keywords` are emitted as keyword parameters unless they are
        ``None``.

        Parameters:
          symbol (str):
            Symbol name

          args (iterable):
            Positional arguments

          keywords (iterable):
            Keywords arguments to specify

          short (bool):
            Short representation (i.e. one-line)? Only applied if there are
            not too many parameters.

        Returns:
          str: The constraint repr
        """
        params = []
        for attr in ('name', 'deferrable', 'initially') + tuple(keywords):
            value = getattr(self.constraint, attr)
            if value is not None:
                params.append('%s=%r' % (attr, value))

        # A one-line rendering is only used with at most one parameter.
        if short and len(params) > 1:
            short = False

        if not args:
            args = ''
        elif short:
            args = ', '.join(args)
        else:
            args = '\n ' + ',\n '.join(args) + ','

        if short:
            params = ', '.join(params)
            if args and params:
                params = ', ' + params
        else:
            params = ',\n '.join(params)
            if params:
                params = '\n ' + params + ','
            if args or params:
                params += '\n'
        return "%s(%s%s)" % (self._symbols[symbol], args, params)
def access_col(col):
    """
    Generate column access string (either as attribute or via dict access)

    Attribute access (``.c.name``) is used only if the name is a plain ASCII
    Python identifier that is not a keyword. Every other name falls back to
    item access with the repr of the name (``.c['name']``).

    Parameters:
      col (SA Column or str):
        Column (anything with a ``name`` attribute) or the column name itself

    Returns:
      str: Access string
    """
    # Accept column objects as well as bare names.
    name = getattr(col, 'name', col)

    # Explicit ASCII identifier pattern instead of ``tokenize.Name``: on
    # Python 3 that is ``\w+``, which accepts a leading digit (``.c.1st`` is a
    # syntax error), and ``$`` would also accept a trailing newline. The
    # ASCII-only character classes make a separate ASCII check unnecessary.
    if (
        _re.match(r'[A-Za-z_][A-Za-z0-9_]*\Z', name)
        and not _keyword.iskeyword(name)
    ):
        return ".c.%s" % name
    return ".c[%r]" % name
class UniqueConstraint(Constraint):
    """Unique constraint"""

    # Symbol name and import template registered with the symbol table.
    _SYMBOL = 'uk'
    _IMPORT = 'from %(constraints)s import Unique as %(uk)s'

    def __repr__(self):
        """
        Make string representation

        A constraint without any columns is still rendered, but commented
        out. Every line is commented, because the rendering may span
        several lines.

        Returns:
          str: The string representation
        """
        columns = self.constraint.columns
        empty = len(columns) == 0
        result = self.repr(
            self._SYMBOL,
            ["%s%s" % (self.table, access_col(col)) for col in columns],
            # Prefer a one-line rendering for zero or one column.
            short=len(columns) <= 1,
        )
        if empty:
            # Prefixing only the first line would leave the remaining lines
            # of a multi-line rendering as live code.
            result = ''.join(
                ['# %s' % line for line in result.splitlines(True)]
            )
        return result
class PrimaryKeyConstraint(UniqueConstraint):
    """
    Primary Key constraint

    Shares the unique constraint rendering and only swaps in the ``pk``
    symbol and its import line.
    """

    _IMPORT = 'from %(constraints)s import PrimaryKey as %(pk)s'
    _SYMBOL = 'pk'
class ForeignKeyConstraint(Constraint):
    """ForeignKey constraint"""

    # Symbol name and import template registered with the symbol table.
    _SYMBOL = 'fk'
    _IMPORT = 'from %(constraints)s import ForeignKey as %(fk)s'

    def __repr__(self):
        """
        Make string representation

        Renders the local columns and the referenced columns as two list
        arguments. Referenced columns are addressed through the
        ``table_<name>`` symbol of their table. ``onupdate``/``ondelete``
        are always considered, and ``use_alter`` is added when it is set.
        A set ``use_alter`` also marks the key as cyclic in the annotations.

        ``options`` (if set) controls annotation:

        - ``'seen: <table>'``: an explanatory comment is prepended
        - ``'unseen: <table>'``: the whole constraint is commented out,
          preceded by a header line naming the table

        Returns:
          str: The string representation
        """
        local_columns = [
            "%s%s" % (self.table, access_col(col))
            for col in self.constraint.columns
        ]
        referenced_columns = [
            "%s%s"
            % (
                self._symbols[u'table_%s' % key.column.table.name],
                access_col(key.column),
            )
            for key in self.constraint.elements
        ]
        columns = "[%s]" % ',\n '.join(local_columns)
        refcolumns = "[%s]" % ',\n '.join(referenced_columns)

        keywords = ['onupdate', 'ondelete']
        if self.constraint.use_alter:
            keywords.append('use_alter')
        result = self.repr(self._SYMBOL, [columns, refcolumns], keywords)

        if not self.options:
            return result

        cyclic = self.constraint.use_alter
        if self.options.startswith('seen:'):
            table = self.options.split(None, 1)[1]
            if cyclic:
                return '\n# Cyclic foreign key:\n' + result
            return '\n# Foreign key belongs to %r:\n%s' % (table, result)

        if self.options.startswith('unseen:'):
            table = self.options.split(None, 1)[1]
            if cyclic:
                header = 'Cyclic foreign key, defined at table %r:\n' % table
            else:
                header = 'Defined at table %r:\n' % table
            lines = [header] + result.splitlines(True)
            return '\n' + ''.join(['# %s' % item for item in lines])

        return result