Merge branch 'master' into updates
This commit is contained in:
63
CHANGES.rst
63
CHANGES.rst
@@ -4,6 +4,69 @@ Changelog
|
||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
||||
|
||||
|
||||
0.27.12 (2014-12-xx)
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Fixed PhoneNumber string coercion (#93)
|
||||
|
||||
|
||||
0.27.11 (2014-12-06)
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Added loose typed column checking support for get_column_key
|
||||
- Made get_column_key throw UnmappedColumnError to be consistent with SQLAlchemy
|
||||
|
||||
|
||||
0.27.10 (2014-12-03)
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Fixed column alias handling in dependent_objects
|
||||
|
||||
|
||||
0.27.9 (2014-12-01)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Fixed aggregated decorator many-to-many relationship handling
|
||||
- Fixed aggregated column alias handling
|
||||
|
||||
|
||||
0.27.8 (2014-11-13)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Added is_loaded utility function
|
||||
- Removed deprecated has_any_changes
|
||||
|
||||
|
||||
0.27.7 (2014-11-03)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Added support for Column and ColumnEntity objects in get_mapper
|
||||
- Made make_order_by_deterministic add deterministic column more aggressively
|
||||
|
||||
|
||||
0.27.6 (2014-10-29)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Fixed assert_max_length not working with non nullable columns
|
||||
- Add PostgreSQL < 9.2 support for drop_database
|
||||
|
||||
|
||||
0.27.5 (2014-10-24)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Made assert_* functions automatically rollback session
|
||||
- Changed make_order_by_deterministic attach order by primary key for queries without order by
|
||||
- Fixed alias handling in has_unique_index
|
||||
- Fixed alias handling in has_index
|
||||
- Fixed alias handling in make_order_by_deterministic
|
||||
|
||||
|
||||
0.27.4 (2014-10-23)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Added assert_non_nullable, assert_nullable and assert_max_length testing functions
|
||||
|
||||
|
||||
0.27.3 (2014-10-22)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
@@ -20,4 +20,5 @@ SQLAlchemy-Utils provides custom data types and various utility functions for SQ
|
||||
orm_helpers
|
||||
utility_classes
|
||||
models
|
||||
testing
|
||||
license
|
||||
|
@@ -76,6 +76,12 @@ identity
|
||||
.. autofunction:: identity
|
||||
|
||||
|
||||
is_loaded
|
||||
^^^^^^^^^
|
||||
|
||||
.. autofunction:: is_loaded
|
||||
|
||||
|
||||
make_order_by_deterministic
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
20
docs/testing.rst
Normal file
20
docs/testing.rst
Normal file
@@ -0,0 +1,20 @@
|
||||
Testing
|
||||
=======
|
||||
|
||||
.. automodule:: sqlalchemy_utils.asserts
|
||||
|
||||
|
||||
assert_nullable
|
||||
---------------
|
||||
|
||||
.. autofunction:: assert_nullable
|
||||
|
||||
assert_non_nullable
|
||||
-------------------
|
||||
|
||||
.. autofunction:: assert_non_nullable
|
||||
|
||||
assert_max_length
|
||||
-----------------
|
||||
|
||||
.. autofunction:: assert_max_length
|
2
setup.py
2
setup.py
@@ -47,7 +47,7 @@ extras_require = {
|
||||
'ipaddress': ['ipaddr'] if not PY3 else [],
|
||||
'timezone': ['python-dateutil'],
|
||||
'url': ['furl >= 0.3.5'] if not PY3 else [],
|
||||
'encrypted': ['cryptography==0.6']
|
||||
'encrypted': ['cryptography>=0.6']
|
||||
}
|
||||
|
||||
|
||||
|
@@ -1,4 +1,5 @@
|
||||
from .aggregates import aggregated
|
||||
from .asserts import assert_nullable, assert_non_nullable, assert_max_length
|
||||
from .batch import batch_fetch, with_backrefs
|
||||
from .decorators import generates
|
||||
from .exceptions import ImproperlyConfigured
|
||||
@@ -23,11 +24,11 @@ from .functions import (
|
||||
get_referencing_foreign_keys,
|
||||
get_tables,
|
||||
group_foreign_keys,
|
||||
has_any_changes,
|
||||
has_changes,
|
||||
has_index,
|
||||
has_unique_index,
|
||||
identity,
|
||||
is_loaded,
|
||||
merge_references,
|
||||
mock_engine,
|
||||
naturally_equivalent,
|
||||
@@ -36,6 +37,7 @@ from .functions import (
|
||||
sort_query,
|
||||
table_name,
|
||||
)
|
||||
from .i18n import TranslationHybrid
|
||||
from .listeners import (
|
||||
auto_delete_orphans,
|
||||
coercion_listener,
|
||||
@@ -78,12 +80,15 @@ from .types import (
|
||||
from .models import Timestamp
|
||||
|
||||
|
||||
__version__ = '0.27.3'
|
||||
__version__ = '0.27.11'
|
||||
|
||||
|
||||
__all__ = (
|
||||
aggregated,
|
||||
analyze,
|
||||
assert_max_length,
|
||||
assert_non_nullable,
|
||||
assert_nullable,
|
||||
auto_delete_orphans,
|
||||
batch_fetch,
|
||||
coercion_listener,
|
||||
@@ -109,11 +114,11 @@ __all__ = (
|
||||
get_referencing_foreign_keys,
|
||||
get_tables,
|
||||
group_foreign_keys,
|
||||
has_any_changes,
|
||||
has_changes,
|
||||
has_index,
|
||||
identity,
|
||||
instrumented_list,
|
||||
is_loaded,
|
||||
merge_references,
|
||||
mock_engine,
|
||||
naturally_equivalent,
|
||||
|
@@ -394,13 +394,64 @@ class AggregatedAttribute(declared_attr):
|
||||
self.relationship = relationship
|
||||
|
||||
def __get__(desc, self, cls):
|
||||
value = (desc.fget, desc.relationship, desc.column)
|
||||
if cls not in aggregated_attrs:
|
||||
aggregated_attrs[cls] = [(desc.fget, desc.relationship)]
|
||||
aggregated_attrs[cls] = [value]
|
||||
else:
|
||||
aggregated_attrs[cls].append((desc.fget, desc.relationship))
|
||||
aggregated_attrs[cls].append(value)
|
||||
return desc.column
|
||||
|
||||
|
||||
def get_aggregate_query(agg_expr, relationships):
|
||||
"""
|
||||
Return a subquery for fetching an aggregate value of given aggregate
|
||||
expression and given sequence of relationships.
|
||||
|
||||
The returned aggregate query can be used when updating denormalized column
|
||||
value with query such as:
|
||||
|
||||
UPDATE table SET column = {aggregate_query}
|
||||
WHERE {condition}
|
||||
|
||||
:param agg_expr:
|
||||
an expression to be selected, for example sa.func.count('1')
|
||||
:param relationships:
|
||||
Sequence of relationships to be used for building the aggregate
|
||||
query.
|
||||
"""
|
||||
from_ = relationships[0].mapper.class_.__table__
|
||||
for relationship in relationships[0:-1]:
|
||||
property_ = relationship.property
|
||||
if property_.secondary is not None:
|
||||
from_ = from_.join(
|
||||
property_.secondary,
|
||||
property_.secondaryjoin
|
||||
)
|
||||
|
||||
from_ = (
|
||||
from_
|
||||
.join(
|
||||
property_.parent.class_,
|
||||
property_.primaryjoin
|
||||
)
|
||||
)
|
||||
|
||||
prop = relationships[-1].property
|
||||
condition = prop.primaryjoin
|
||||
if prop.secondary is not None:
|
||||
from_ = from_.join(
|
||||
prop.secondary,
|
||||
prop.secondaryjoin
|
||||
)
|
||||
|
||||
query = sa.select(
|
||||
[agg_expr],
|
||||
from_obj=[from_]
|
||||
)
|
||||
|
||||
return query.where(condition)
|
||||
|
||||
|
||||
class AggregatedValue(object):
|
||||
def __init__(self, class_, attr, relationships, expr):
|
||||
self.class_ = class_
|
||||
@@ -418,23 +469,7 @@ class AggregatedValue(object):
|
||||
|
||||
@property
|
||||
def aggregate_query(self):
|
||||
from_ = self.relationships[0].mapper.class_.__table__
|
||||
for relationship in self.relationships[0:-1]:
|
||||
property_ = relationship.property
|
||||
from_ = (
|
||||
from_
|
||||
.join(
|
||||
property_.parent.class_,
|
||||
property_.primaryjoin
|
||||
)
|
||||
)
|
||||
|
||||
query = sa.select(
|
||||
[self.expr],
|
||||
from_obj=[from_]
|
||||
)
|
||||
|
||||
query = query.where(self.relationships[-1])
|
||||
query = get_aggregate_query(self.expr, self.relationships)
|
||||
|
||||
return query.correlate(self.class_).as_scalar()
|
||||
|
||||
@@ -484,11 +519,22 @@ class AggregatedValue(object):
|
||||
property_ = self.relationships[-1].property
|
||||
|
||||
from_ = property_.mapper.class_.__table__
|
||||
for relationship in reversed(self.relationships[1:-1]):
|
||||
for relationship in reversed(self.relationships[0:-1]):
|
||||
property_ = relationship.property
|
||||
from_ = (
|
||||
from_.join(property_.mapper.class_, property_.primaryjoin)
|
||||
)
|
||||
if property_.secondary is not None:
|
||||
from_ = from_.join(
|
||||
property_.secondary,
|
||||
property_.primaryjoin
|
||||
)
|
||||
from_ = from_.join(
|
||||
property_.mapper.class_,
|
||||
property_.secondaryjoin
|
||||
)
|
||||
else:
|
||||
from_ = from_.join(
|
||||
property_.mapper.class_,
|
||||
property_.primaryjoin
|
||||
)
|
||||
return from_
|
||||
|
||||
def local_condition(self, prop, objects):
|
||||
@@ -532,7 +578,7 @@ class AggregationManager(object):
|
||||
|
||||
def update_generator_registry(self):
|
||||
for class_, attrs in six.iteritems(aggregated_attrs):
|
||||
for expr, relationship in attrs:
|
||||
for expr, relationship, column in attrs:
|
||||
relationships = []
|
||||
rel_class = class_
|
||||
|
||||
@@ -544,7 +590,7 @@ class AggregationManager(object):
|
||||
self.generator_registry[rel_class].append(
|
||||
AggregatedValue(
|
||||
class_=class_,
|
||||
attr=expr.__name__,
|
||||
attr=column,
|
||||
relationships=list(reversed(relationships)),
|
||||
expr=expr(class_)
|
||||
)
|
||||
|
106
sqlalchemy_utils/asserts.py
Normal file
106
sqlalchemy_utils/asserts.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
The functions in this module can be used for testing that the constraints of
|
||||
your models. Each assert function runs SQL UPDATEs that check for the existence
|
||||
of given constraint. Consider the following model::
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String(200), nullable=True)
|
||||
email = sa.Column(sa.String(255), nullable=False)
|
||||
|
||||
|
||||
user = User(name='John Doe', email='john@example.com')
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
|
||||
We can easily test the constraints by assert_* functions::
|
||||
|
||||
|
||||
from sqlalchemy_utils import (
|
||||
assert_nullable,
|
||||
assert_non_nullable,
|
||||
assert_max_length
|
||||
)
|
||||
|
||||
assert_nullable(user, 'name')
|
||||
assert_non_nullable(user, 'email')
|
||||
assert_max_length(user, 'name', 200)
|
||||
|
||||
# raises AssertionError because the max length of email is 255
|
||||
assert_max_length(user, 'email', 300)
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.exc import DataError, IntegrityError
|
||||
|
||||
|
||||
class raises(object):
|
||||
def __init__(self, expected_exc):
|
||||
self.expected_exc = expected_exc
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type != self.expected_exc:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _update_field(obj, field, value):
|
||||
session = sa.orm.object_session(obj)
|
||||
table = sa.inspect(obj.__class__).columns[field].table
|
||||
query = table.update().values(**{field: value})
|
||||
session.execute(query)
|
||||
session.flush()
|
||||
|
||||
|
||||
def _expect_successful_update(obj, field, value, reraise_exc):
|
||||
try:
|
||||
_update_field(obj, field, value)
|
||||
except (reraise_exc) as e:
|
||||
session = sa.orm.object_session(obj)
|
||||
session.rollback()
|
||||
assert False, str(e)
|
||||
|
||||
|
||||
def _expect_failing_update(obj, field, value, expected_exc):
|
||||
with raises(expected_exc):
|
||||
_update_field(obj, field, value)
|
||||
session = sa.orm.object_session(obj)
|
||||
session.rollback()
|
||||
|
||||
|
||||
def assert_nullable(obj, column):
|
||||
"""
|
||||
Assert that given column is nullable. This is checked by running an SQL
|
||||
update that assigns given column as None.
|
||||
|
||||
:param obj: SQLAlchemy declarative model object
|
||||
:param column: Name of the column
|
||||
"""
|
||||
_expect_successful_update(obj, column, None, IntegrityError)
|
||||
|
||||
|
||||
def assert_non_nullable(obj, column):
|
||||
"""
|
||||
Assert that given column is not nullable. This is checked by running an SQL
|
||||
update that assigns given column as None.
|
||||
|
||||
:param obj: SQLAlchemy declarative model object
|
||||
:param column: Name of the column
|
||||
"""
|
||||
_expect_failing_update(obj, column, None, IntegrityError)
|
||||
|
||||
|
||||
def assert_max_length(obj, column, max_length):
|
||||
"""
|
||||
Assert that the given column is of given max length.
|
||||
|
||||
:param obj: SQLAlchemy declarative model object
|
||||
:param column: Name of the column
|
||||
"""
|
||||
_expect_successful_update(obj, column, u'a' * max_length, DataError)
|
||||
_expect_failing_update(obj, column, u'a' * (max_length + 1), DataError)
|
@@ -35,9 +35,9 @@ from .orm import (
|
||||
get_query_entities,
|
||||
get_tables,
|
||||
getdotattr,
|
||||
has_any_changes,
|
||||
has_changes,
|
||||
identity,
|
||||
is_loaded,
|
||||
naturally_equivalent,
|
||||
quote,
|
||||
table_name,
|
||||
@@ -62,9 +62,9 @@ __all__ = (
|
||||
'get_tables',
|
||||
'getdotattr',
|
||||
'group_foreign_keys',
|
||||
'has_any_changes',
|
||||
'has_changes',
|
||||
'identity',
|
||||
'is_loaded',
|
||||
'is_auto_assigned_date_column',
|
||||
'is_indexed_foreign_key',
|
||||
'make_order_by_deterministic',
|
||||
|
@@ -159,12 +159,18 @@ def has_index(column):
|
||||
has_index(table.c.locale) # False
|
||||
has_index(table.c.id) # True
|
||||
"""
|
||||
table = column.table
|
||||
if not isinstance(table, sa.Table):
|
||||
raise TypeError(
|
||||
'Only columns belonging to Table objects are supported. Given '
|
||||
'column belongs to %r.' % table
|
||||
)
|
||||
return (
|
||||
column is column.table.primary_key.columns.values()[0]
|
||||
column is table.primary_key.columns.values()[0]
|
||||
or
|
||||
any(
|
||||
index.columns.values()[0] is column
|
||||
for index in column.table.indexes
|
||||
for index in table.indexes
|
||||
)
|
||||
)
|
||||
|
||||
@@ -198,8 +204,17 @@ def has_unique_index(column):
|
||||
has_unique_index(table.c.is_published) # True
|
||||
has_unique_index(table.c.is_deleted) # False
|
||||
has_unique_index(table.c.id) # True
|
||||
|
||||
|
||||
:raises TypeError: if given column does not belong to a Table object
|
||||
"""
|
||||
pks = column.table.primary_key.columns
|
||||
table = column.table
|
||||
if not isinstance(table, sa.Table):
|
||||
raise TypeError(
|
||||
'Only columns belonging to Table objects are supported. Given '
|
||||
'column belongs to %r.' % table
|
||||
)
|
||||
pks = table.primary_key.columns
|
||||
return (
|
||||
(column is pks.values()[0] and len(pks) == 1)
|
||||
or
|
||||
@@ -384,12 +399,21 @@ def drop_database(url):
|
||||
engine.raw_connection().set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
|
||||
|
||||
# Disconnect all users from the database we are dropping.
|
||||
version = list(
|
||||
map(
|
||||
int,
|
||||
engine.execute('SHOW server_version;').first()[0].split('.')
|
||||
)
|
||||
)
|
||||
pid_column = (
|
||||
'pid' if (version[0] >= 9 and version[1] >= 2) else 'procpid'
|
||||
)
|
||||
text = '''
|
||||
SELECT pg_terminate_backend(pg_stat_activity.pid)
|
||||
SELECT pg_terminate_backend(pg_stat_activity.%(pid_column)s)
|
||||
FROM pg_stat_activity
|
||||
WHERE pg_stat_activity.datname = '%s'
|
||||
AND pid <> pg_backend_pid()
|
||||
''' % database
|
||||
WHERE pg_stat_activity.datname = '%(database)s'
|
||||
AND %(pid_column)s <> pg_backend_pid();
|
||||
''' % {'pid_column': pid_column, 'database': database}
|
||||
engine.execute(text)
|
||||
|
||||
# Drop the database.
|
||||
|
@@ -7,7 +7,7 @@ from sqlalchemy.exc import NoInspectionAvailable
|
||||
from sqlalchemy.orm import object_session
|
||||
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
|
||||
|
||||
from .orm import get_mapper, get_tables
|
||||
from .orm import get_column_key, get_mapper, get_tables
|
||||
from ..query_chain import QueryChain
|
||||
|
||||
|
||||
@@ -279,31 +279,40 @@ def dependent_objects(obj, foreign_keys=None):
|
||||
table in mapper.tables and
|
||||
not (parent_mapper and table in parent_mapper.tables)
|
||||
):
|
||||
criteria = []
|
||||
visited_constraints = []
|
||||
for key in keys:
|
||||
if key.constraint not in visited_constraints:
|
||||
visited_constraints.append(key.constraint)
|
||||
subcriteria = [
|
||||
getattr(class_, column.key) ==
|
||||
getattr(
|
||||
obj,
|
||||
key.constraint.elements[index].column.key
|
||||
)
|
||||
for index, column
|
||||
in enumerate(key.constraint.columns)
|
||||
]
|
||||
criteria.append(sa.and_(*subcriteria))
|
||||
|
||||
query = session.query(class_).filter(
|
||||
sa.or_(
|
||||
*criteria
|
||||
)
|
||||
sa.or_(*_get_criteria(keys, class_, obj))
|
||||
)
|
||||
chain.queries.append(query)
|
||||
return chain
|
||||
|
||||
|
||||
def _get_criteria(keys, class_, obj):
|
||||
criteria = []
|
||||
visited_constraints = []
|
||||
for key in keys:
|
||||
if key.constraint in visited_constraints:
|
||||
continue
|
||||
visited_constraints.append(key.constraint)
|
||||
|
||||
subcriteria = []
|
||||
for index, column in enumerate(key.constraint.columns):
|
||||
foreign_column = (
|
||||
key.constraint.elements[index].column
|
||||
)
|
||||
subcriteria.append(
|
||||
getattr(class_, get_column_key(class_, column)) ==
|
||||
getattr(
|
||||
obj,
|
||||
sa.inspect(type(obj))
|
||||
.get_property_by_column(
|
||||
foreign_column
|
||||
).key
|
||||
)
|
||||
)
|
||||
criteria.append(sa.and_(*subcriteria))
|
||||
return criteria
|
||||
|
||||
|
||||
def non_indexed_foreign_keys(metadata, engine=None):
|
||||
"""
|
||||
Finds all non indexed foreign keys from all tables of given MetaData.
|
||||
@@ -347,10 +356,10 @@ def is_indexed_foreign_key(constraint):
|
||||
|
||||
:param constraint: ForeignKeyConstraint object to check the indexes
|
||||
"""
|
||||
for index in constraint.table.indexes:
|
||||
index_column_names = set(
|
||||
column.name for column in index.columns
|
||||
)
|
||||
if index_column_names == set(constraint.columns):
|
||||
return True
|
||||
return False
|
||||
return any(
|
||||
set(column.name for column in index.columns)
|
||||
==
|
||||
set(constraint.columns)
|
||||
for index
|
||||
in constraint.table.indexes
|
||||
)
|
||||
|
@@ -35,14 +35,21 @@ def get_column_key(model, column):
|
||||
get_column_key(User, User.__table__.c.name) # 'name'
|
||||
|
||||
.. versionadded: 0.26.5
|
||||
|
||||
.. versionchanged: 0.27.11
|
||||
Throws UnmappedColumnError instead of ValueError when no property was
|
||||
found for given column. This is consistent with how SQLAlchemy works.
|
||||
"""
|
||||
for key, c in sa.inspect(model).columns.items():
|
||||
if c is column:
|
||||
return key
|
||||
raise ValueError(
|
||||
"Class %s doesn't have a column '%s'",
|
||||
model.__name__,
|
||||
column
|
||||
mapper = sa.inspect(model)
|
||||
try:
|
||||
return mapper.get_property_by_column(column).key
|
||||
except sa.orm.exc.UnmappedColumnError:
|
||||
for key, c in mapper.columns.items():
|
||||
if c.name == column.name and c.table is column.table:
|
||||
return key
|
||||
raise sa.orm.exc.UnmappedColumnError(
|
||||
'No column %s is configured on mapper %s...' %
|
||||
(column, mapper)
|
||||
)
|
||||
|
||||
|
||||
@@ -77,6 +84,10 @@ def get_mapper(mixed):
|
||||
"""
|
||||
if isinstance(mixed, sa.orm.query._MapperEntity):
|
||||
mixed = mixed.expr
|
||||
elif isinstance(mixed, sa.Column):
|
||||
mixed = mixed.table
|
||||
elif isinstance(mixed, sa.orm.query._ColumnEntity):
|
||||
mixed = mixed.expr
|
||||
|
||||
if isinstance(mixed, sa.orm.Mapper):
|
||||
return mixed
|
||||
@@ -227,8 +238,6 @@ def get_tables(mixed):
|
||||
tables = sum((m.tables for m in polymorphic_mappers), [])
|
||||
else:
|
||||
tables = mapper.tables
|
||||
|
||||
|
||||
return tables
|
||||
|
||||
|
||||
@@ -635,13 +644,7 @@ def getdotattr(obj_or_class, dot_path):
|
||||
for path in dot_path.split('.'):
|
||||
getter = attrgetter(path)
|
||||
if isinstance(last, list):
|
||||
tmp = []
|
||||
for el in last:
|
||||
if isinstance(el, list):
|
||||
tmp.extend(map(getter, el))
|
||||
else:
|
||||
tmp.append(getter(el))
|
||||
last = tmp
|
||||
last = sum((getter(el) for el in last), [])
|
||||
elif isinstance(last, InstrumentedAttribute):
|
||||
last = getter(last.property.mapper.class_)
|
||||
elif last is None:
|
||||
@@ -722,35 +725,37 @@ def has_changes(obj, attrs=None, exclude=None):
|
||||
)
|
||||
|
||||
|
||||
def has_any_changes(obj, columns):
|
||||
def is_loaded(obj, prop):
|
||||
"""
|
||||
Simple shortcut function for checking if any of the given attributes of
|
||||
given declarative model object have changes.
|
||||
|
||||
Return whether or not given property of given object has been loaded.
|
||||
|
||||
::
|
||||
|
||||
|
||||
from sqlalchemy_utils import has_any_changes
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
content = sa.orm.deferred(sa.Column(sa.String))
|
||||
|
||||
|
||||
user = User()
|
||||
article = session.query(Article).get(5)
|
||||
|
||||
has_any_changes(user, ('name', )) # False
|
||||
# name gets loaded since its not a deferred property
|
||||
assert is_loaded(article, 'name')
|
||||
|
||||
user.name = u'someone'
|
||||
|
||||
has_any_changes(user, ('name', 'age')) # True
|
||||
# content has not yet been loaded since its a deferred property
|
||||
assert not is_loaded(article, 'content')
|
||||
|
||||
|
||||
.. versionadded: 0.26.3
|
||||
.. deprecated:: 0.26.6
|
||||
User :func:`has_changes` instead.
|
||||
.. versionadded: 0.27.8
|
||||
|
||||
:param obj: SQLAlchemy declarative model object
|
||||
:param attrs: Names of the attributes
|
||||
:param prop: Name of the property or InstrumentedAttribute
|
||||
"""
|
||||
return any(has_changes(obj, column) for column in columns)
|
||||
return not isinstance(
|
||||
getattr(sa.inspect(obj).attrs, prop).loaded_value,
|
||||
sa.util.langhelpers._symbol
|
||||
)
|
||||
|
||||
|
||||
def identity(obj_or_class):
|
||||
|
@@ -169,23 +169,28 @@ def make_order_by_deterministic(query):
|
||||
|
||||
.. versionadded: 0.27.1
|
||||
"""
|
||||
if not query._order_by:
|
||||
return query
|
||||
order_by_func = sa.asc
|
||||
|
||||
order_by = query._order_by[0]
|
||||
if isinstance(order_by, sa.sql.expression.UnaryExpression):
|
||||
if order_by.modifier == sa.sql.operators.desc_op:
|
||||
order_by_func = sa.desc
|
||||
else:
|
||||
order_by_func = sa.asc
|
||||
column = order_by.get_children()[0]
|
||||
if not query._order_by:
|
||||
column = None
|
||||
else:
|
||||
column = order_by
|
||||
order_by_func = sa.asc
|
||||
order_by = query._order_by[0]
|
||||
if isinstance(order_by, sa.sql.expression.UnaryExpression):
|
||||
if order_by.modifier == sa.sql.operators.desc_op:
|
||||
order_by_func = sa.desc
|
||||
else:
|
||||
order_by_func = sa.asc
|
||||
column = order_by.get_children()[0]
|
||||
else:
|
||||
column = order_by
|
||||
|
||||
# Queries that are ordered by an already
|
||||
if isinstance(column, sa.Column) and has_unique_index(column):
|
||||
return query
|
||||
if isinstance(column, sa.Column):
|
||||
try:
|
||||
if has_unique_index(column):
|
||||
return query
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
base_table = get_tables(query._entities[0])[0]
|
||||
query = query.order_by(
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
|
||||
from .exceptions import ImproperlyConfigured
|
||||
|
||||
|
||||
@@ -21,3 +23,60 @@ except ImportError:
|
||||
'install babel or make a similar function and override it '
|
||||
'in this module.'
|
||||
)
|
||||
|
||||
|
||||
class TranslationHybrid(object):
|
||||
def __init__(self, current_locale, default_locale):
|
||||
self.current_locale = current_locale
|
||||
self.default_locale = default_locale
|
||||
|
||||
def cast_locale(self, obj, locale):
|
||||
"""
|
||||
Cast given locale to string. Supports also callbacks that return
|
||||
locales.
|
||||
"""
|
||||
if callable(locale):
|
||||
try:
|
||||
return str(locale())
|
||||
except TypeError:
|
||||
return str(locale(obj))
|
||||
return str(locale)
|
||||
|
||||
def getter_factory(self, attr):
|
||||
"""
|
||||
Return a hybrid_property getter function for given attribute. The
|
||||
returned getter first checks if object has translation for current
|
||||
locale. If not it tries to get translation for default locale. If there
|
||||
is no translation found for default locale it returns None.
|
||||
"""
|
||||
def getter(obj):
|
||||
current_locale = self.cast_locale(obj, self.current_locale)
|
||||
try:
|
||||
return getattr(obj, attr.key)[current_locale]
|
||||
except (TypeError, KeyError):
|
||||
default_locale = self.cast_locale(
|
||||
obj, self.default_locale
|
||||
)
|
||||
try:
|
||||
return getattr(obj, attr.key)[default_locale]
|
||||
except (TypeError, KeyError):
|
||||
return None
|
||||
return getter
|
||||
|
||||
def setter_factory(self, attr):
|
||||
def setter(obj, value):
|
||||
if getattr(obj, attr.key) is None:
|
||||
setattr(obj, attr.key, {})
|
||||
locale = self.cast_locale(obj, self.current_locale)
|
||||
getattr(obj, attr.key)[locale] = value
|
||||
return setter
|
||||
|
||||
def expr_factory(self, attr):
|
||||
return lambda cls: attr
|
||||
|
||||
def __call__(self, attr):
|
||||
return hybrid_property(
|
||||
fget=self.getter_factory(attr),
|
||||
fset=self.setter_factory(attr),
|
||||
expr=self.expr_factory(attr)
|
||||
)
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import six
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy_utils.exceptions import ImproperlyConfigured
|
||||
from sqlalchemy_utils.utils import str_coercible
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
|
||||
@@ -13,6 +14,7 @@ except ImportError:
|
||||
BasePhoneNumber = object
|
||||
|
||||
|
||||
@str_coercible
|
||||
class PhoneNumber(BasePhoneNumber):
|
||||
'''
|
||||
Extends a PhoneNumber class from `Python phonenumbers library`_. Adds
|
||||
@@ -66,9 +68,6 @@ class PhoneNumber(BasePhoneNumber):
|
||||
def __unicode__(self):
|
||||
return self.national
|
||||
|
||||
def __str__(self):
|
||||
return six.text_type(self.national).encode('utf-8')
|
||||
|
||||
|
||||
class PhoneNumberType(types.TypeDecorator, ScalarCoercible):
|
||||
"""
|
||||
|
@@ -3,7 +3,7 @@ from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
|
||||
class TestAggregateValueGenerationWithBackrefs(TestCase):
|
||||
def create_models(self):
|
||||
class Thread(self.Base):
|
||||
__tablename__ = 'thread'
|
||||
|
80
tests/aggregate/test_m2m_m2m.py
Normal file
80
tests/aggregate/test_m2m_m2m.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregateManyToManyAndManyToMany(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
|
||||
def create_models(self):
|
||||
catalog_products = sa.Table(
|
||||
'catalog_product',
|
||||
self.Base.metadata,
|
||||
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
|
||||
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
|
||||
)
|
||||
|
||||
product_categories = sa.Table(
|
||||
'category_product',
|
||||
self.Base.metadata,
|
||||
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
|
||||
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
|
||||
)
|
||||
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated(
|
||||
'products.categories',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def category_count(self):
|
||||
return sa.func.count(sa.distinct(Category.id))
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
catalog_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey('catalog.id')
|
||||
)
|
||||
|
||||
catalogs = sa.orm.relationship(
|
||||
Catalog,
|
||||
backref='products',
|
||||
secondary=catalog_products
|
||||
)
|
||||
|
||||
categories = sa.orm.relationship(
|
||||
Category,
|
||||
backref='products',
|
||||
secondary=product_categories
|
||||
)
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.Product = Product
|
||||
|
||||
def test_insert(self):
|
||||
category = self.Category()
|
||||
products = [
|
||||
self.Product(categories=[category]),
|
||||
self.Product(categories=[category])
|
||||
]
|
||||
catalog = self.Catalog(products=products)
|
||||
self.session.add(catalog)
|
||||
catalog2 = self.Catalog(products=products)
|
||||
self.session.add(catalog)
|
||||
self.session.commit()
|
||||
assert catalog.category_count == 1
|
||||
assert catalog2.category_count == 1
|
76
tests/aggregate/test_o2m_m2m.py
Normal file
76
tests/aggregate/test_o2m_m2m.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregateOneToManyAndManyToMany(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
|
||||
def create_models(self):
|
||||
product_categories = sa.Table(
|
||||
'category_product',
|
||||
self.Base.metadata,
|
||||
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
|
||||
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
|
||||
)
|
||||
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated(
|
||||
'products.categories',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def category_count(self):
|
||||
return sa.func.count(sa.distinct(Category.id))
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
catalog_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey('catalog.id')
|
||||
)
|
||||
|
||||
catalog = sa.orm.relationship(
|
||||
Catalog,
|
||||
backref='products'
|
||||
)
|
||||
|
||||
categories = sa.orm.relationship(
|
||||
Category,
|
||||
backref='products',
|
||||
secondary=product_categories
|
||||
)
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.Product = Product
|
||||
|
||||
def test_insert(self):
|
||||
category = self.Category()
|
||||
products = [
|
||||
self.Product(categories=[category]),
|
||||
self.Product(categories=[category])
|
||||
]
|
||||
catalog = self.Catalog(products=products)
|
||||
self.session.add(catalog)
|
||||
products2 = [
|
||||
self.Product(categories=[category]),
|
||||
self.Product(categories=[category])
|
||||
]
|
||||
catalog2 = self.Catalog(products=products2)
|
||||
self.session.add(catalog)
|
||||
self.session.commit()
|
||||
assert catalog.category_count == 1
|
||||
assert catalog2.category_count == 1
|
62
tests/aggregate/test_o2m_o2m.py
Normal file
62
tests/aggregate/test_o2m_o2m.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from decimal import Decimal
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregateOneToManyAndOneToMany(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated(
|
||||
'categories.products',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def product_count(self):
|
||||
return sa.func.count('1')
|
||||
|
||||
categories = sa.orm.relationship('Category', backref='catalog')
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
|
||||
products = sa.orm.relationship('Product', backref='category')
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.Product = Product
|
||||
|
||||
def test_assigns_aggregates(self):
|
||||
category = self.Category(name=u'Some category')
|
||||
catalog = self.Catalog(
|
||||
categories=[category]
|
||||
)
|
||||
catalog.name = u'Some catalog'
|
||||
self.session.add(catalog)
|
||||
self.session.commit()
|
||||
product = self.Product(
|
||||
name=u'Some product',
|
||||
price=Decimal('1000'),
|
||||
category=category
|
||||
)
|
||||
self.session.add(product)
|
||||
self.session.commit()
|
||||
self.session.refresh(catalog)
|
||||
assert catalog.product_count == 1
|
@@ -1,76 +1,15 @@
|
||||
from decimal import Decimal
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
|
||||
from sqlalchemy_utils import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestDeepModelPathsForAggregates(TestCase):
|
||||
class Test3LevelDeepOneToMany(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated(
|
||||
'categories.products',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def product_count(self):
|
||||
return sa.func.count('1')
|
||||
|
||||
categories = sa.orm.relationship('Category', backref='catalog')
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
|
||||
products = sa.orm.relationship('Product', backref='category')
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.Product = Product
|
||||
|
||||
def test_assigns_aggregates(self):
|
||||
category = self.Category(name=u'Some category')
|
||||
catalog = self.Catalog(
|
||||
categories=[category]
|
||||
)
|
||||
catalog.name = u'Some catalog'
|
||||
self.session.add(catalog)
|
||||
self.session.commit()
|
||||
product = self.Product(
|
||||
name=u'Some product',
|
||||
price=Decimal('1000'),
|
||||
category=category
|
||||
)
|
||||
self.session.add(product)
|
||||
self.session.commit()
|
||||
self.session.refresh(catalog)
|
||||
assert catalog.product_count == 1
|
||||
|
||||
|
||||
class Test3LevelDeepModelPathsForAggregates(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
n = 1
|
||||
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated(
|
||||
'categories.sub_categories.products',
|
||||
@@ -84,8 +23,6 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
|
||||
sub_categories = sa.orm.relationship(
|
||||
@@ -95,16 +32,12 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
|
||||
class SubCategory(self.Base):
|
||||
__tablename__ = 'sub_category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||
|
||||
products = sa.orm.relationship('Product', backref='sub_category')
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
sub_category_id = sa.Column(
|
||||
@@ -123,23 +56,13 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
|
||||
assert catalog.product_count == 1
|
||||
|
||||
def catalog_factory(self):
|
||||
product = self.Product(
|
||||
name=u'Product %d' % self.n
|
||||
)
|
||||
product = self.Product()
|
||||
sub_category = self.SubCategory(
|
||||
name=u'SubCategory %d' % self.n,
|
||||
products=[product]
|
||||
)
|
||||
category = self.Category(
|
||||
name=u'Category %d' % self.n,
|
||||
sub_categories=[sub_category]
|
||||
)
|
||||
catalog = self.Catalog(
|
||||
categories=[category]
|
||||
)
|
||||
catalog.name = u'Catalog %d' % self.n
|
||||
category = self.Category(sub_categories=[sub_category])
|
||||
catalog = self.Catalog(categories=[category])
|
||||
self.session.add(catalog)
|
||||
self.n += 1
|
||||
return catalog
|
||||
|
||||
def test_only_updates_affected_aggregates(self):
|
||||
@@ -155,7 +78,7 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
|
||||
)
|
||||
|
||||
catalog.categories[0].sub_categories[0].products.append(
|
||||
self.Product(name=u'Product 3')
|
||||
self.Product()
|
||||
)
|
||||
self.session.commit()
|
||||
self.session.refresh(catalog)
|
58
tests/aggregate/test_with_column_alias.py
Normal file
58
tests/aggregate/test_with_column_alias.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregatedWithColumnAlias(TestCase):
|
||||
def create_models(self):
|
||||
class Thread(self.Base):
|
||||
__tablename__ = 'thread'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
@aggregated(
|
||||
'comments',
|
||||
sa.Column('_comment_count', sa.Integer, default=0)
|
||||
)
|
||||
def comment_count(self):
|
||||
return sa.func.count('1')
|
||||
|
||||
comments = sa.orm.relationship('Comment', backref='thread')
|
||||
|
||||
class Comment(self.Base):
|
||||
__tablename__ = 'comment'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
|
||||
|
||||
self.Thread = Thread
|
||||
self.Comment = Comment
|
||||
|
||||
def test_assigns_aggregates_on_insert(self):
|
||||
thread = self.Thread()
|
||||
self.session.add(thread)
|
||||
comment = self.Comment(thread=thread)
|
||||
self.session.add(comment)
|
||||
self.session.commit()
|
||||
self.session.refresh(thread)
|
||||
assert thread.comment_count == 1
|
||||
|
||||
def test_assigns_aggregates_on_separate_insert(self):
|
||||
thread = self.Thread()
|
||||
self.session.add(thread)
|
||||
self.session.commit()
|
||||
comment = self.Comment(thread=thread)
|
||||
self.session.add(comment)
|
||||
self.session.commit()
|
||||
self.session.refresh(thread)
|
||||
assert thread.comment_count == 1
|
||||
|
||||
def test_assigns_aggregates_on_delete(self):
|
||||
thread = self.Thread()
|
||||
self.session.add(thread)
|
||||
self.session.commit()
|
||||
comment = self.Comment(thread=thread)
|
||||
self.session.add(comment)
|
||||
self.session.commit()
|
||||
self.session.delete(comment)
|
||||
self.session.commit()
|
||||
self.session.refresh(thread)
|
||||
assert thread.comment_count == 0
|
@@ -78,6 +78,85 @@ class TestDependentObjects(TestCase):
|
||||
assert objects[3] in deps
|
||||
|
||||
|
||||
class TestDependentObjectsWithColumnAliases(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
first_name = sa.Column(sa.Unicode(255))
|
||||
last_name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
author_id = sa.Column(
|
||||
'_author_id', sa.Integer, sa.ForeignKey('user.id')
|
||||
)
|
||||
owner_id = sa.Column(
|
||||
'_owner_id',
|
||||
sa.Integer, sa.ForeignKey('user.id', ondelete='SET NULL')
|
||||
)
|
||||
|
||||
author = sa.orm.relationship(User, foreign_keys=[author_id])
|
||||
owner = sa.orm.relationship(User, foreign_keys=[owner_id])
|
||||
|
||||
class BlogPost(self.Base):
|
||||
__tablename__ = 'blog_post'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
owner_id = sa.Column(
|
||||
'_owner_id',
|
||||
sa.Integer, sa.ForeignKey('user.id', ondelete='CASCADE')
|
||||
)
|
||||
|
||||
owner = sa.orm.relationship(User)
|
||||
|
||||
self.User = User
|
||||
self.Article = Article
|
||||
self.BlogPost = BlogPost
|
||||
|
||||
def test_returns_all_dependent_objects(self):
|
||||
user = self.User(first_name=u'John')
|
||||
articles = [
|
||||
self.Article(author=user),
|
||||
self.Article(),
|
||||
self.Article(owner=user),
|
||||
self.Article(author=user, owner=user)
|
||||
]
|
||||
self.session.add_all(articles)
|
||||
self.session.commit()
|
||||
|
||||
deps = list(dependent_objects(user))
|
||||
assert len(deps) == 3
|
||||
assert articles[0] in deps
|
||||
assert articles[2] in deps
|
||||
assert articles[3] in deps
|
||||
|
||||
def test_with_foreign_keys_parameter(self):
|
||||
user = self.User(first_name=u'John')
|
||||
objects = [
|
||||
self.Article(author=user),
|
||||
self.Article(),
|
||||
self.Article(owner=user),
|
||||
self.Article(author=user, owner=user),
|
||||
self.BlogPost(owner=user)
|
||||
]
|
||||
self.session.add_all(objects)
|
||||
self.session.commit()
|
||||
|
||||
deps = list(
|
||||
dependent_objects(
|
||||
user,
|
||||
(
|
||||
fk for fk in get_referencing_foreign_keys(self.User)
|
||||
if fk.ondelete == 'RESTRICT' or fk.ondelete is None
|
||||
)
|
||||
).limit(5)
|
||||
)
|
||||
assert len(deps) == 2
|
||||
assert objects[0] in deps
|
||||
assert objects[3] in deps
|
||||
|
||||
|
||||
class TestDependentObjectsWithManyReferences(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
@@ -192,7 +271,6 @@ class TestDependentObjectsWithSingleTableInheritance(TestCase):
|
||||
'polymorphic_identity': u'blog_post'
|
||||
}
|
||||
|
||||
|
||||
self.Category = Category
|
||||
self.TextItem = TextItem
|
||||
self.Article = Article
|
||||
|
@@ -1,3 +1,4 @@
|
||||
from copy import copy
|
||||
from pytest import raises
|
||||
|
||||
import sqlalchemy as sa
|
||||
@@ -15,7 +16,12 @@ class TestGetColumnKey(object):
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column('_name', sa.Unicode(255))
|
||||
|
||||
class Movie(Base):
|
||||
__tablename__ = 'movie'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
self.Building = Building
|
||||
self.Movie = Movie
|
||||
|
||||
def test_supports_aliases(self):
|
||||
assert (
|
||||
@@ -29,6 +35,10 @@ class TestGetColumnKey(object):
|
||||
'name'
|
||||
)
|
||||
|
||||
def test_supports_vague_matching_of_column_objects(self):
|
||||
column = copy(self.Building.__table__.c._name)
|
||||
assert get_column_key(self.Building, column) == 'name'
|
||||
|
||||
def test_throws_value_error_for_unknown_column(self):
|
||||
with raises(ValueError):
|
||||
get_column_key(self.Building, 'unknown')
|
||||
with raises(sa.orm.exc.UnmappedColumnError):
|
||||
get_column_key(self.Building, self.Movie.__table__.c.id)
|
||||
|
@@ -56,6 +56,18 @@ class TestGetMapper(object):
|
||||
sa.inspect(self.Building)
|
||||
)
|
||||
|
||||
def test_column(self):
|
||||
assert (
|
||||
get_mapper(self.Building.__table__.c.id) ==
|
||||
sa.inspect(self.Building)
|
||||
)
|
||||
|
||||
def test_column_of_an_alias(self):
|
||||
assert (
|
||||
get_mapper(sa.orm.aliased(self.Building.__table__).c.id) ==
|
||||
sa.inspect(self.Building)
|
||||
)
|
||||
|
||||
|
||||
class TestGetMapperWithQueryEntities(TestCase):
|
||||
def create_models(self):
|
||||
@@ -79,6 +91,10 @@ class TestGetMapperWithQueryEntities(TestCase):
|
||||
sa.inspect(self.Building)
|
||||
)
|
||||
|
||||
def test_column_entity(self):
|
||||
query = self.session.query(self.Building.id)
|
||||
assert get_mapper(query._entities[0]) == sa.inspect(self.Building)
|
||||
|
||||
|
||||
class TestGetMapperWithMultipleMappersFound(object):
|
||||
def setup_method(self, method):
|
||||
|
@@ -67,11 +67,12 @@ class TestGetDotAttr(TestCase):
|
||||
subsection = self.SubSection(section=section)
|
||||
subsubsection = self.SubSubSection(subsection=subsection)
|
||||
|
||||
assert getdotattr(document, 'sections') == [section]
|
||||
assert getdotattr(document, 'sections.subsections') == [
|
||||
[subsection]
|
||||
subsection
|
||||
]
|
||||
assert getdotattr(document, 'sections.subsections.subsubsections') == [
|
||||
[subsubsection]
|
||||
subsubsection
|
||||
]
|
||||
|
||||
def test_class_paths(self):
|
||||
|
@@ -1,24 +0,0 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import has_any_changes
|
||||
|
||||
|
||||
class TestHasAnyChanges(object):
|
||||
def setup_method(self, method):
|
||||
Base = declarative_base()
|
||||
|
||||
class Article(Base):
|
||||
__tablename__ = 'article_translation'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
title = sa.Column(sa.String(100))
|
||||
|
||||
self.Article = Article
|
||||
|
||||
def test_without_changed_attr(self):
|
||||
article = self.Article()
|
||||
assert not has_any_changes(article, ['title'])
|
||||
|
||||
def test_with_changed_attr(self):
|
||||
article = self.Article(title='Some title')
|
||||
assert has_any_changes(article, ['title', 'id'])
|
@@ -1,4 +1,5 @@
|
||||
import sqlalchemy as sa
|
||||
from pytest import raises
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import has_index
|
||||
@@ -23,6 +24,11 @@ class TestHasIndex(object):
|
||||
|
||||
self.table = ArticleTranslation.__table__
|
||||
|
||||
def test_column_that_belongs_to_an_alias(self):
|
||||
alias = sa.orm.aliased(self.table)
|
||||
with raises(TypeError):
|
||||
assert has_index(alias.c.id)
|
||||
|
||||
def test_compound_primary_key(self):
|
||||
assert has_index(self.table.c.id)
|
||||
assert not has_index(self.table.c.locale)
|
||||
|
@@ -1,10 +1,11 @@
|
||||
from pytest import raises
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import has_unique_index
|
||||
|
||||
|
||||
class TestHasIndex(object):
|
||||
class TestHasUniqueIndex(object):
|
||||
def setup_method(self, method):
|
||||
Base = declarative_base()
|
||||
|
||||
@@ -31,6 +32,11 @@ class TestHasIndex(object):
|
||||
def test_primary_key(self):
|
||||
assert has_unique_index(self.articles.c.id)
|
||||
|
||||
def test_column_of_aliased_table(self):
|
||||
alias = sa.orm.aliased(self.articles)
|
||||
with raises(TypeError):
|
||||
assert has_unique_index(alias.c.id)
|
||||
|
||||
def test_unique_index(self):
|
||||
assert has_unique_index(self.article_translations.c.is_deleted)
|
||||
|
||||
|
25
tests/functions/test_is_loaded.py
Normal file
25
tests/functions/test_is_loaded.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import is_loaded
|
||||
|
||||
|
||||
class TestIsLoaded(object):
|
||||
def setup_method(self, method):
|
||||
Base = declarative_base()
|
||||
|
||||
class Article(Base):
|
||||
__tablename__ = 'article_translation'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
title = sa.orm.deferred(sa.Column(sa.String(100)))
|
||||
|
||||
self.Article = Article
|
||||
|
||||
def test_loaded_property(self):
|
||||
article = self.Article(id=1)
|
||||
assert is_loaded(article, 'id')
|
||||
|
||||
def test_unloaded_property(self):
|
||||
article = self.Article(id=4)
|
||||
assert not is_loaded(article, 'title')
|
||||
|
@@ -17,7 +17,6 @@ class TestMakeOrderByDeterministic(TestCase):
|
||||
sa.func.lower(name)
|
||||
)
|
||||
|
||||
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
@@ -31,6 +30,7 @@ class TestMakeOrderByDeterministic(TestCase):
|
||||
)
|
||||
|
||||
self.User = User
|
||||
self.Article = Article
|
||||
|
||||
def test_column_property(self):
|
||||
query = self.session.query(self.User).order_by(self.User.email_lower)
|
||||
@@ -82,4 +82,10 @@ class TestMakeOrderByDeterministic(TestCase):
|
||||
def test_query_without_order_by(self):
|
||||
query = self.session.query(self.User)
|
||||
query = make_order_by_deterministic(query)
|
||||
assert 'ORDER BY' not in str(query)
|
||||
assert 'ORDER BY "user".id' in str(query)
|
||||
|
||||
def test_alias(self):
|
||||
alias = sa.orm.aliased(self.User.__table__)
|
||||
query = self.session.query(alias).order_by(alias.c.name)
|
||||
query = make_order_by_deterministic(query)
|
||||
assert str(query).endswith('ORDER BY user_1.name, "user".id ASC')
|
||||
|
90
tests/test_asserts.py
Normal file
90
tests/test_asserts.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import sqlalchemy as sa
|
||||
import pytest
|
||||
from sqlalchemy_utils import (
|
||||
assert_nullable,
|
||||
assert_non_nullable,
|
||||
assert_max_length
|
||||
)
|
||||
from sqlalchemy_utils.asserts import raises
|
||||
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestRaises(object):
|
||||
def test_matching_exception(self):
|
||||
with raises(Exception):
|
||||
raise Exception()
|
||||
assert True
|
||||
|
||||
def test_non_matchin_exception(self):
|
||||
with pytest.raises(Exception):
|
||||
with raises(ValueError):
|
||||
raise Exception()
|
||||
|
||||
|
||||
class AssertionTestCase(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String(20))
|
||||
age = sa.Column(sa.Integer, nullable=False)
|
||||
email = sa.Column(sa.String(200), nullable=False, unique=True)
|
||||
|
||||
self.User = User
|
||||
|
||||
def setup_method(self, method):
|
||||
TestCase.setup_method(self, method)
|
||||
user = self.User(name='Someone', email='someone@example.com', age=15)
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
self.user = user
|
||||
|
||||
|
||||
class TestAssertNonNullable(AssertionTestCase):
|
||||
def test_non_nullable_column(self):
|
||||
# Test everything twice so that session gets rolled back properly
|
||||
assert_non_nullable(self.user, 'age')
|
||||
assert_non_nullable(self.user, 'age')
|
||||
|
||||
def test_nullable_column(self):
|
||||
with raises(AssertionError):
|
||||
assert_non_nullable(self.user, 'name')
|
||||
with raises(AssertionError):
|
||||
assert_non_nullable(self.user, 'name')
|
||||
|
||||
|
||||
class TestAssertNullable(AssertionTestCase):
|
||||
def test_nullable_column(self):
|
||||
assert_nullable(self.user, 'name')
|
||||
assert_nullable(self.user, 'name')
|
||||
|
||||
def test_non_nullable_column(self):
|
||||
with raises(AssertionError):
|
||||
assert_nullable(self.user, 'age')
|
||||
with raises(AssertionError):
|
||||
assert_nullable(self.user, 'age')
|
||||
|
||||
|
||||
class TestAssertMaxLength(AssertionTestCase):
|
||||
def test_with_max_length(self):
|
||||
assert_max_length(self.user, 'name', 20)
|
||||
assert_max_length(self.user, 'name', 20)
|
||||
|
||||
def test_with_non_nullable_column(self):
|
||||
assert_max_length(self.user, 'email', 200)
|
||||
assert_max_length(self.user, 'email', 200)
|
||||
|
||||
def test_smaller_than_max_length(self):
|
||||
with raises(AssertionError):
|
||||
assert_max_length(self.user, 'name', 19)
|
||||
with raises(AssertionError):
|
||||
assert_max_length(self.user, 'name', 19)
|
||||
|
||||
def test_bigger_than_max_length(self):
|
||||
with raises(AssertionError):
|
||||
assert_max_length(self.user, 'name', 21)
|
||||
with raises(AssertionError):
|
||||
assert_max_length(self.user, 'name', 21)
|
68
tests/test_translation_hybrid.py
Normal file
68
tests/test_translation_hybrid.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import JSON
|
||||
from sqlalchemy_utils import TranslationHybrid
|
||||
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestTranslationHybrid(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
|
||||
def create_models(self):
|
||||
class City(self.Base):
|
||||
__tablename__ = 'city'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name_translations = sa.Column(JSON())
|
||||
name = self.translation_hybrid(name_translations)
|
||||
locale = 'en'
|
||||
|
||||
self.City = City
|
||||
|
||||
def setup_method(self, method):
|
||||
self.translation_hybrid = TranslationHybrid('fi', 'en')
|
||||
TestCase.setup_method(self, method)
|
||||
|
||||
def test_using_hybrid_as_constructor(self):
|
||||
city = self.City(name='Helsinki')
|
||||
assert city.name_translations['fi'] == 'Helsinki'
|
||||
|
||||
def test_hybrid_as_expression(self):
|
||||
assert self.City.name == self.City.name_translations
|
||||
|
||||
def test_if_no_translation_exists_returns_none(self):
|
||||
city = self.City()
|
||||
assert city.name is None
|
||||
|
||||
def test_fall_back_to_default_translation(self):
|
||||
city = self.City(name_translations={'en': 'Helsinki'})
|
||||
self.translation_hybrid.current_locale = 'sv'
|
||||
assert city.name == 'Helsinki'
|
||||
|
||||
|
||||
class TestTranslationHybridWithDynamicDefaultLocale(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
|
||||
def create_models(self):
|
||||
class City(self.Base):
|
||||
__tablename__ = 'city'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name_translations = sa.Column(JSON)
|
||||
name = self.translation_hybrid(name_translations)
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
self.City = City
|
||||
|
||||
def setup_method(self, method):
|
||||
self.translation_hybrid = TranslationHybrid(
|
||||
'fi',
|
||||
lambda self: self.locale
|
||||
)
|
||||
TestCase.setup_method(self, method)
|
||||
|
||||
def test_fallback_to_dynamic_locale(self):
|
||||
self.translation_hybrid.current_locale = 'en'
|
||||
city = self.City(name_translations={})
|
||||
city.locale = 'fi'
|
||||
city.name_translations['fi'] = 'Helsinki'
|
||||
|
||||
assert city.name == 'Helsinki'
|
@@ -1,6 +1,8 @@
|
||||
from pytest import mark
|
||||
from tests import TestCase
|
||||
import six
|
||||
import sqlalchemy as sa
|
||||
from pytest import mark
|
||||
|
||||
from tests import TestCase
|
||||
from sqlalchemy_utils import PhoneNumberType, PhoneNumber
|
||||
from sqlalchemy_utils.types import phone_number
|
||||
|
||||
@@ -45,8 +47,11 @@ class TestPhoneNumber(object):
|
||||
|
||||
def test_phone_number_str_repr(self):
|
||||
number = PhoneNumber('+358401234567')
|
||||
assert number.__unicode__() == number.national
|
||||
assert number.__str__() == number.national.encode('utf-8')
|
||||
if six.PY2:
|
||||
assert unicode(number) == number.national
|
||||
assert str(number) == number.national.encode('utf-8')
|
||||
else:
|
||||
assert str(number) == number.national
|
||||
|
||||
|
||||
@mark.skipif('phone_number.phonenumbers is None')
|
||||
|
Reference in New Issue
Block a user