Merge branch 'master' into updates
This commit is contained in:
63
CHANGES.rst
63
CHANGES.rst
@@ -4,6 +4,69 @@ Changelog
|
|||||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
||||||
|
|
||||||
|
|
||||||
|
0.27.12 (2014-12-xx)
|
||||||
|
^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
- Fixed PhoneNumber string coercion (#93)
|
||||||
|
|
||||||
|
|
||||||
|
0.27.11 (2014-12-06)
|
||||||
|
^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
- Added loose typed column checking support for get_column_key
|
||||||
|
- Made get_column_key throw UnmappedColumnError to be consistent with SQLAlchemy
|
||||||
|
|
||||||
|
|
||||||
|
0.27.10 (2014-12-03)
|
||||||
|
^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
- Fixed column alias handling in dependent_objects
|
||||||
|
|
||||||
|
|
||||||
|
0.27.9 (2014-12-01)
|
||||||
|
^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
- Fixed aggregated decorator many-to-many relationship handling
|
||||||
|
- Fixed aggregated column alias handling
|
||||||
|
|
||||||
|
|
||||||
|
0.27.8 (2014-11-13)
|
||||||
|
^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
- Added is_loaded utility function
|
||||||
|
- Removed deprecated has_any_changes
|
||||||
|
|
||||||
|
|
||||||
|
0.27.7 (2014-11-03)
|
||||||
|
^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
- Added support for Column and ColumnEntity objects in get_mapper
|
||||||
|
- Made make_order_by_deterministic add deterministic column more aggressively
|
||||||
|
|
||||||
|
|
||||||
|
0.27.6 (2014-10-29)
|
||||||
|
^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
- Fixed assert_max_length not working with non nullable columns
|
||||||
|
- Add PostgreSQL < 9.2 support for drop_database
|
||||||
|
|
||||||
|
|
||||||
|
0.27.5 (2014-10-24)
|
||||||
|
^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
- Made assert_* functions automatically rollback session
|
||||||
|
- Changed make_order_by_deterministic attach order by primary key for queries without order by
|
||||||
|
- Fixed alias handling in has_unique_index
|
||||||
|
- Fixed alias handling in has_index
|
||||||
|
- Fixed alias handling in make_order_by_deterministic
|
||||||
|
|
||||||
|
|
||||||
|
0.27.4 (2014-10-23)
|
||||||
|
^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
- Added assert_non_nullable, assert_nullable and assert_max_length testing functions
|
||||||
|
|
||||||
|
|
||||||
0.27.3 (2014-10-22)
|
0.27.3 (2014-10-22)
|
||||||
^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
@@ -20,4 +20,5 @@ SQLAlchemy-Utils provides custom data types and various utility functions for SQ
|
|||||||
orm_helpers
|
orm_helpers
|
||||||
utility_classes
|
utility_classes
|
||||||
models
|
models
|
||||||
|
testing
|
||||||
license
|
license
|
||||||
|
@@ -76,6 +76,12 @@ identity
|
|||||||
.. autofunction:: identity
|
.. autofunction:: identity
|
||||||
|
|
||||||
|
|
||||||
|
is_loaded
|
||||||
|
^^^^^^^^^
|
||||||
|
|
||||||
|
.. autofunction:: is_loaded
|
||||||
|
|
||||||
|
|
||||||
make_order_by_deterministic
|
make_order_by_deterministic
|
||||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
20
docs/testing.rst
Normal file
20
docs/testing.rst
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
Testing
|
||||||
|
=======
|
||||||
|
|
||||||
|
.. automodule:: sqlalchemy_utils.asserts
|
||||||
|
|
||||||
|
|
||||||
|
assert_nullable
|
||||||
|
---------------
|
||||||
|
|
||||||
|
.. autofunction:: assert_nullable
|
||||||
|
|
||||||
|
assert_non_nullable
|
||||||
|
-------------------
|
||||||
|
|
||||||
|
.. autofunction:: assert_non_nullable
|
||||||
|
|
||||||
|
assert_max_length
|
||||||
|
-----------------
|
||||||
|
|
||||||
|
.. autofunction:: assert_max_length
|
2
setup.py
2
setup.py
@@ -47,7 +47,7 @@ extras_require = {
|
|||||||
'ipaddress': ['ipaddr'] if not PY3 else [],
|
'ipaddress': ['ipaddr'] if not PY3 else [],
|
||||||
'timezone': ['python-dateutil'],
|
'timezone': ['python-dateutil'],
|
||||||
'url': ['furl >= 0.3.5'] if not PY3 else [],
|
'url': ['furl >= 0.3.5'] if not PY3 else [],
|
||||||
'encrypted': ['cryptography==0.6']
|
'encrypted': ['cryptography>=0.6']
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
from .aggregates import aggregated
|
from .aggregates import aggregated
|
||||||
|
from .asserts import assert_nullable, assert_non_nullable, assert_max_length
|
||||||
from .batch import batch_fetch, with_backrefs
|
from .batch import batch_fetch, with_backrefs
|
||||||
from .decorators import generates
|
from .decorators import generates
|
||||||
from .exceptions import ImproperlyConfigured
|
from .exceptions import ImproperlyConfigured
|
||||||
@@ -23,11 +24,11 @@ from .functions import (
|
|||||||
get_referencing_foreign_keys,
|
get_referencing_foreign_keys,
|
||||||
get_tables,
|
get_tables,
|
||||||
group_foreign_keys,
|
group_foreign_keys,
|
||||||
has_any_changes,
|
|
||||||
has_changes,
|
has_changes,
|
||||||
has_index,
|
has_index,
|
||||||
has_unique_index,
|
has_unique_index,
|
||||||
identity,
|
identity,
|
||||||
|
is_loaded,
|
||||||
merge_references,
|
merge_references,
|
||||||
mock_engine,
|
mock_engine,
|
||||||
naturally_equivalent,
|
naturally_equivalent,
|
||||||
@@ -36,6 +37,7 @@ from .functions import (
|
|||||||
sort_query,
|
sort_query,
|
||||||
table_name,
|
table_name,
|
||||||
)
|
)
|
||||||
|
from .i18n import TranslationHybrid
|
||||||
from .listeners import (
|
from .listeners import (
|
||||||
auto_delete_orphans,
|
auto_delete_orphans,
|
||||||
coercion_listener,
|
coercion_listener,
|
||||||
@@ -78,12 +80,15 @@ from .types import (
|
|||||||
from .models import Timestamp
|
from .models import Timestamp
|
||||||
|
|
||||||
|
|
||||||
__version__ = '0.27.3'
|
__version__ = '0.27.11'
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
aggregated,
|
aggregated,
|
||||||
analyze,
|
analyze,
|
||||||
|
assert_max_length,
|
||||||
|
assert_non_nullable,
|
||||||
|
assert_nullable,
|
||||||
auto_delete_orphans,
|
auto_delete_orphans,
|
||||||
batch_fetch,
|
batch_fetch,
|
||||||
coercion_listener,
|
coercion_listener,
|
||||||
@@ -109,11 +114,11 @@ __all__ = (
|
|||||||
get_referencing_foreign_keys,
|
get_referencing_foreign_keys,
|
||||||
get_tables,
|
get_tables,
|
||||||
group_foreign_keys,
|
group_foreign_keys,
|
||||||
has_any_changes,
|
|
||||||
has_changes,
|
has_changes,
|
||||||
has_index,
|
has_index,
|
||||||
identity,
|
identity,
|
||||||
instrumented_list,
|
instrumented_list,
|
||||||
|
is_loaded,
|
||||||
merge_references,
|
merge_references,
|
||||||
mock_engine,
|
mock_engine,
|
||||||
naturally_equivalent,
|
naturally_equivalent,
|
||||||
|
@@ -394,13 +394,64 @@ class AggregatedAttribute(declared_attr):
|
|||||||
self.relationship = relationship
|
self.relationship = relationship
|
||||||
|
|
||||||
def __get__(desc, self, cls):
|
def __get__(desc, self, cls):
|
||||||
|
value = (desc.fget, desc.relationship, desc.column)
|
||||||
if cls not in aggregated_attrs:
|
if cls not in aggregated_attrs:
|
||||||
aggregated_attrs[cls] = [(desc.fget, desc.relationship)]
|
aggregated_attrs[cls] = [value]
|
||||||
else:
|
else:
|
||||||
aggregated_attrs[cls].append((desc.fget, desc.relationship))
|
aggregated_attrs[cls].append(value)
|
||||||
return desc.column
|
return desc.column
|
||||||
|
|
||||||
|
|
||||||
|
def get_aggregate_query(agg_expr, relationships):
|
||||||
|
"""
|
||||||
|
Return a subquery for fetching an aggregate value of given aggregate
|
||||||
|
expression and given sequence of relationships.
|
||||||
|
|
||||||
|
The returned aggregate query can be used when updating denormalized column
|
||||||
|
value with query such as:
|
||||||
|
|
||||||
|
UPDATE table SET column = {aggregate_query}
|
||||||
|
WHERE {condition}
|
||||||
|
|
||||||
|
:param agg_expr:
|
||||||
|
an expression to be selected, for example sa.func.count('1')
|
||||||
|
:param relationships:
|
||||||
|
Sequence of relationships to be used for building the aggregate
|
||||||
|
query.
|
||||||
|
"""
|
||||||
|
from_ = relationships[0].mapper.class_.__table__
|
||||||
|
for relationship in relationships[0:-1]:
|
||||||
|
property_ = relationship.property
|
||||||
|
if property_.secondary is not None:
|
||||||
|
from_ = from_.join(
|
||||||
|
property_.secondary,
|
||||||
|
property_.secondaryjoin
|
||||||
|
)
|
||||||
|
|
||||||
|
from_ = (
|
||||||
|
from_
|
||||||
|
.join(
|
||||||
|
property_.parent.class_,
|
||||||
|
property_.primaryjoin
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
prop = relationships[-1].property
|
||||||
|
condition = prop.primaryjoin
|
||||||
|
if prop.secondary is not None:
|
||||||
|
from_ = from_.join(
|
||||||
|
prop.secondary,
|
||||||
|
prop.secondaryjoin
|
||||||
|
)
|
||||||
|
|
||||||
|
query = sa.select(
|
||||||
|
[agg_expr],
|
||||||
|
from_obj=[from_]
|
||||||
|
)
|
||||||
|
|
||||||
|
return query.where(condition)
|
||||||
|
|
||||||
|
|
||||||
class AggregatedValue(object):
|
class AggregatedValue(object):
|
||||||
def __init__(self, class_, attr, relationships, expr):
|
def __init__(self, class_, attr, relationships, expr):
|
||||||
self.class_ = class_
|
self.class_ = class_
|
||||||
@@ -418,23 +469,7 @@ class AggregatedValue(object):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def aggregate_query(self):
|
def aggregate_query(self):
|
||||||
from_ = self.relationships[0].mapper.class_.__table__
|
query = get_aggregate_query(self.expr, self.relationships)
|
||||||
for relationship in self.relationships[0:-1]:
|
|
||||||
property_ = relationship.property
|
|
||||||
from_ = (
|
|
||||||
from_
|
|
||||||
.join(
|
|
||||||
property_.parent.class_,
|
|
||||||
property_.primaryjoin
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
query = sa.select(
|
|
||||||
[self.expr],
|
|
||||||
from_obj=[from_]
|
|
||||||
)
|
|
||||||
|
|
||||||
query = query.where(self.relationships[-1])
|
|
||||||
|
|
||||||
return query.correlate(self.class_).as_scalar()
|
return query.correlate(self.class_).as_scalar()
|
||||||
|
|
||||||
@@ -484,11 +519,22 @@ class AggregatedValue(object):
|
|||||||
property_ = self.relationships[-1].property
|
property_ = self.relationships[-1].property
|
||||||
|
|
||||||
from_ = property_.mapper.class_.__table__
|
from_ = property_.mapper.class_.__table__
|
||||||
for relationship in reversed(self.relationships[1:-1]):
|
for relationship in reversed(self.relationships[0:-1]):
|
||||||
property_ = relationship.property
|
property_ = relationship.property
|
||||||
from_ = (
|
if property_.secondary is not None:
|
||||||
from_.join(property_.mapper.class_, property_.primaryjoin)
|
from_ = from_.join(
|
||||||
)
|
property_.secondary,
|
||||||
|
property_.primaryjoin
|
||||||
|
)
|
||||||
|
from_ = from_.join(
|
||||||
|
property_.mapper.class_,
|
||||||
|
property_.secondaryjoin
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from_ = from_.join(
|
||||||
|
property_.mapper.class_,
|
||||||
|
property_.primaryjoin
|
||||||
|
)
|
||||||
return from_
|
return from_
|
||||||
|
|
||||||
def local_condition(self, prop, objects):
|
def local_condition(self, prop, objects):
|
||||||
@@ -532,7 +578,7 @@ class AggregationManager(object):
|
|||||||
|
|
||||||
def update_generator_registry(self):
|
def update_generator_registry(self):
|
||||||
for class_, attrs in six.iteritems(aggregated_attrs):
|
for class_, attrs in six.iteritems(aggregated_attrs):
|
||||||
for expr, relationship in attrs:
|
for expr, relationship, column in attrs:
|
||||||
relationships = []
|
relationships = []
|
||||||
rel_class = class_
|
rel_class = class_
|
||||||
|
|
||||||
@@ -544,7 +590,7 @@ class AggregationManager(object):
|
|||||||
self.generator_registry[rel_class].append(
|
self.generator_registry[rel_class].append(
|
||||||
AggregatedValue(
|
AggregatedValue(
|
||||||
class_=class_,
|
class_=class_,
|
||||||
attr=expr.__name__,
|
attr=column,
|
||||||
relationships=list(reversed(relationships)),
|
relationships=list(reversed(relationships)),
|
||||||
expr=expr(class_)
|
expr=expr(class_)
|
||||||
)
|
)
|
||||||
|
106
sqlalchemy_utils/asserts.py
Normal file
106
sqlalchemy_utils/asserts.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""
|
||||||
|
The functions in this module can be used for testing that the constraints of
|
||||||
|
your models. Each assert function runs SQL UPDATEs that check for the existence
|
||||||
|
of given constraint. Consider the following model::
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.String(200), nullable=True)
|
||||||
|
email = sa.Column(sa.String(255), nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
user = User(name='John Doe', email='john@example.com')
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
We can easily test the constraints by assert_* functions::
|
||||||
|
|
||||||
|
|
||||||
|
from sqlalchemy_utils import (
|
||||||
|
assert_nullable,
|
||||||
|
assert_non_nullable,
|
||||||
|
assert_max_length
|
||||||
|
)
|
||||||
|
|
||||||
|
assert_nullable(user, 'name')
|
||||||
|
assert_non_nullable(user, 'email')
|
||||||
|
assert_max_length(user, 'name', 200)
|
||||||
|
|
||||||
|
# raises AssertionError because the max length of email is 255
|
||||||
|
assert_max_length(user, 'email', 300)
|
||||||
|
"""
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.exc import DataError, IntegrityError
|
||||||
|
|
||||||
|
|
||||||
|
class raises(object):
|
||||||
|
def __init__(self, expected_exc):
|
||||||
|
self.expected_exc = expected_exc
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
if exc_type != self.expected_exc:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _update_field(obj, field, value):
|
||||||
|
session = sa.orm.object_session(obj)
|
||||||
|
table = sa.inspect(obj.__class__).columns[field].table
|
||||||
|
query = table.update().values(**{field: value})
|
||||||
|
session.execute(query)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def _expect_successful_update(obj, field, value, reraise_exc):
|
||||||
|
try:
|
||||||
|
_update_field(obj, field, value)
|
||||||
|
except (reraise_exc) as e:
|
||||||
|
session = sa.orm.object_session(obj)
|
||||||
|
session.rollback()
|
||||||
|
assert False, str(e)
|
||||||
|
|
||||||
|
|
||||||
|
def _expect_failing_update(obj, field, value, expected_exc):
|
||||||
|
with raises(expected_exc):
|
||||||
|
_update_field(obj, field, value)
|
||||||
|
session = sa.orm.object_session(obj)
|
||||||
|
session.rollback()
|
||||||
|
|
||||||
|
|
||||||
|
def assert_nullable(obj, column):
|
||||||
|
"""
|
||||||
|
Assert that given column is nullable. This is checked by running an SQL
|
||||||
|
update that assigns given column as None.
|
||||||
|
|
||||||
|
:param obj: SQLAlchemy declarative model object
|
||||||
|
:param column: Name of the column
|
||||||
|
"""
|
||||||
|
_expect_successful_update(obj, column, None, IntegrityError)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_non_nullable(obj, column):
|
||||||
|
"""
|
||||||
|
Assert that given column is not nullable. This is checked by running an SQL
|
||||||
|
update that assigns given column as None.
|
||||||
|
|
||||||
|
:param obj: SQLAlchemy declarative model object
|
||||||
|
:param column: Name of the column
|
||||||
|
"""
|
||||||
|
_expect_failing_update(obj, column, None, IntegrityError)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_max_length(obj, column, max_length):
|
||||||
|
"""
|
||||||
|
Assert that the given column is of given max length.
|
||||||
|
|
||||||
|
:param obj: SQLAlchemy declarative model object
|
||||||
|
:param column: Name of the column
|
||||||
|
"""
|
||||||
|
_expect_successful_update(obj, column, u'a' * max_length, DataError)
|
||||||
|
_expect_failing_update(obj, column, u'a' * (max_length + 1), DataError)
|
@@ -35,9 +35,9 @@ from .orm import (
|
|||||||
get_query_entities,
|
get_query_entities,
|
||||||
get_tables,
|
get_tables,
|
||||||
getdotattr,
|
getdotattr,
|
||||||
has_any_changes,
|
|
||||||
has_changes,
|
has_changes,
|
||||||
identity,
|
identity,
|
||||||
|
is_loaded,
|
||||||
naturally_equivalent,
|
naturally_equivalent,
|
||||||
quote,
|
quote,
|
||||||
table_name,
|
table_name,
|
||||||
@@ -62,9 +62,9 @@ __all__ = (
|
|||||||
'get_tables',
|
'get_tables',
|
||||||
'getdotattr',
|
'getdotattr',
|
||||||
'group_foreign_keys',
|
'group_foreign_keys',
|
||||||
'has_any_changes',
|
|
||||||
'has_changes',
|
'has_changes',
|
||||||
'identity',
|
'identity',
|
||||||
|
'is_loaded',
|
||||||
'is_auto_assigned_date_column',
|
'is_auto_assigned_date_column',
|
||||||
'is_indexed_foreign_key',
|
'is_indexed_foreign_key',
|
||||||
'make_order_by_deterministic',
|
'make_order_by_deterministic',
|
||||||
|
@@ -159,12 +159,18 @@ def has_index(column):
|
|||||||
has_index(table.c.locale) # False
|
has_index(table.c.locale) # False
|
||||||
has_index(table.c.id) # True
|
has_index(table.c.id) # True
|
||||||
"""
|
"""
|
||||||
|
table = column.table
|
||||||
|
if not isinstance(table, sa.Table):
|
||||||
|
raise TypeError(
|
||||||
|
'Only columns belonging to Table objects are supported. Given '
|
||||||
|
'column belongs to %r.' % table
|
||||||
|
)
|
||||||
return (
|
return (
|
||||||
column is column.table.primary_key.columns.values()[0]
|
column is table.primary_key.columns.values()[0]
|
||||||
or
|
or
|
||||||
any(
|
any(
|
||||||
index.columns.values()[0] is column
|
index.columns.values()[0] is column
|
||||||
for index in column.table.indexes
|
for index in table.indexes
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -198,8 +204,17 @@ def has_unique_index(column):
|
|||||||
has_unique_index(table.c.is_published) # True
|
has_unique_index(table.c.is_published) # True
|
||||||
has_unique_index(table.c.is_deleted) # False
|
has_unique_index(table.c.is_deleted) # False
|
||||||
has_unique_index(table.c.id) # True
|
has_unique_index(table.c.id) # True
|
||||||
|
|
||||||
|
|
||||||
|
:raises TypeError: if given column does not belong to a Table object
|
||||||
"""
|
"""
|
||||||
pks = column.table.primary_key.columns
|
table = column.table
|
||||||
|
if not isinstance(table, sa.Table):
|
||||||
|
raise TypeError(
|
||||||
|
'Only columns belonging to Table objects are supported. Given '
|
||||||
|
'column belongs to %r.' % table
|
||||||
|
)
|
||||||
|
pks = table.primary_key.columns
|
||||||
return (
|
return (
|
||||||
(column is pks.values()[0] and len(pks) == 1)
|
(column is pks.values()[0] and len(pks) == 1)
|
||||||
or
|
or
|
||||||
@@ -384,12 +399,21 @@ def drop_database(url):
|
|||||||
engine.raw_connection().set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
|
engine.raw_connection().set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
|
||||||
|
|
||||||
# Disconnect all users from the database we are dropping.
|
# Disconnect all users from the database we are dropping.
|
||||||
|
version = list(
|
||||||
|
map(
|
||||||
|
int,
|
||||||
|
engine.execute('SHOW server_version;').first()[0].split('.')
|
||||||
|
)
|
||||||
|
)
|
||||||
|
pid_column = (
|
||||||
|
'pid' if (version[0] >= 9 and version[1] >= 2) else 'procpid'
|
||||||
|
)
|
||||||
text = '''
|
text = '''
|
||||||
SELECT pg_terminate_backend(pg_stat_activity.pid)
|
SELECT pg_terminate_backend(pg_stat_activity.%(pid_column)s)
|
||||||
FROM pg_stat_activity
|
FROM pg_stat_activity
|
||||||
WHERE pg_stat_activity.datname = '%s'
|
WHERE pg_stat_activity.datname = '%(database)s'
|
||||||
AND pid <> pg_backend_pid()
|
AND %(pid_column)s <> pg_backend_pid();
|
||||||
''' % database
|
''' % {'pid_column': pid_column, 'database': database}
|
||||||
engine.execute(text)
|
engine.execute(text)
|
||||||
|
|
||||||
# Drop the database.
|
# Drop the database.
|
||||||
|
@@ -7,7 +7,7 @@ from sqlalchemy.exc import NoInspectionAvailable
|
|||||||
from sqlalchemy.orm import object_session
|
from sqlalchemy.orm import object_session
|
||||||
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
|
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
|
||||||
|
|
||||||
from .orm import get_mapper, get_tables
|
from .orm import get_column_key, get_mapper, get_tables
|
||||||
from ..query_chain import QueryChain
|
from ..query_chain import QueryChain
|
||||||
|
|
||||||
|
|
||||||
@@ -279,31 +279,40 @@ def dependent_objects(obj, foreign_keys=None):
|
|||||||
table in mapper.tables and
|
table in mapper.tables and
|
||||||
not (parent_mapper and table in parent_mapper.tables)
|
not (parent_mapper and table in parent_mapper.tables)
|
||||||
):
|
):
|
||||||
criteria = []
|
|
||||||
visited_constraints = []
|
|
||||||
for key in keys:
|
|
||||||
if key.constraint not in visited_constraints:
|
|
||||||
visited_constraints.append(key.constraint)
|
|
||||||
subcriteria = [
|
|
||||||
getattr(class_, column.key) ==
|
|
||||||
getattr(
|
|
||||||
obj,
|
|
||||||
key.constraint.elements[index].column.key
|
|
||||||
)
|
|
||||||
for index, column
|
|
||||||
in enumerate(key.constraint.columns)
|
|
||||||
]
|
|
||||||
criteria.append(sa.and_(*subcriteria))
|
|
||||||
|
|
||||||
query = session.query(class_).filter(
|
query = session.query(class_).filter(
|
||||||
sa.or_(
|
sa.or_(*_get_criteria(keys, class_, obj))
|
||||||
*criteria
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
chain.queries.append(query)
|
chain.queries.append(query)
|
||||||
return chain
|
return chain
|
||||||
|
|
||||||
|
|
||||||
|
def _get_criteria(keys, class_, obj):
|
||||||
|
criteria = []
|
||||||
|
visited_constraints = []
|
||||||
|
for key in keys:
|
||||||
|
if key.constraint in visited_constraints:
|
||||||
|
continue
|
||||||
|
visited_constraints.append(key.constraint)
|
||||||
|
|
||||||
|
subcriteria = []
|
||||||
|
for index, column in enumerate(key.constraint.columns):
|
||||||
|
foreign_column = (
|
||||||
|
key.constraint.elements[index].column
|
||||||
|
)
|
||||||
|
subcriteria.append(
|
||||||
|
getattr(class_, get_column_key(class_, column)) ==
|
||||||
|
getattr(
|
||||||
|
obj,
|
||||||
|
sa.inspect(type(obj))
|
||||||
|
.get_property_by_column(
|
||||||
|
foreign_column
|
||||||
|
).key
|
||||||
|
)
|
||||||
|
)
|
||||||
|
criteria.append(sa.and_(*subcriteria))
|
||||||
|
return criteria
|
||||||
|
|
||||||
|
|
||||||
def non_indexed_foreign_keys(metadata, engine=None):
|
def non_indexed_foreign_keys(metadata, engine=None):
|
||||||
"""
|
"""
|
||||||
Finds all non indexed foreign keys from all tables of given MetaData.
|
Finds all non indexed foreign keys from all tables of given MetaData.
|
||||||
@@ -347,10 +356,10 @@ def is_indexed_foreign_key(constraint):
|
|||||||
|
|
||||||
:param constraint: ForeignKeyConstraint object to check the indexes
|
:param constraint: ForeignKeyConstraint object to check the indexes
|
||||||
"""
|
"""
|
||||||
for index in constraint.table.indexes:
|
return any(
|
||||||
index_column_names = set(
|
set(column.name for column in index.columns)
|
||||||
column.name for column in index.columns
|
==
|
||||||
)
|
set(constraint.columns)
|
||||||
if index_column_names == set(constraint.columns):
|
for index
|
||||||
return True
|
in constraint.table.indexes
|
||||||
return False
|
)
|
||||||
|
@@ -35,14 +35,21 @@ def get_column_key(model, column):
|
|||||||
get_column_key(User, User.__table__.c.name) # 'name'
|
get_column_key(User, User.__table__.c.name) # 'name'
|
||||||
|
|
||||||
.. versionadded: 0.26.5
|
.. versionadded: 0.26.5
|
||||||
|
|
||||||
|
.. versionchanged: 0.27.11
|
||||||
|
Throws UnmappedColumnError instead of ValueError when no property was
|
||||||
|
found for given column. This is consistent with how SQLAlchemy works.
|
||||||
"""
|
"""
|
||||||
for key, c in sa.inspect(model).columns.items():
|
mapper = sa.inspect(model)
|
||||||
if c is column:
|
try:
|
||||||
return key
|
return mapper.get_property_by_column(column).key
|
||||||
raise ValueError(
|
except sa.orm.exc.UnmappedColumnError:
|
||||||
"Class %s doesn't have a column '%s'",
|
for key, c in mapper.columns.items():
|
||||||
model.__name__,
|
if c.name == column.name and c.table is column.table:
|
||||||
column
|
return key
|
||||||
|
raise sa.orm.exc.UnmappedColumnError(
|
||||||
|
'No column %s is configured on mapper %s...' %
|
||||||
|
(column, mapper)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -77,6 +84,10 @@ def get_mapper(mixed):
|
|||||||
"""
|
"""
|
||||||
if isinstance(mixed, sa.orm.query._MapperEntity):
|
if isinstance(mixed, sa.orm.query._MapperEntity):
|
||||||
mixed = mixed.expr
|
mixed = mixed.expr
|
||||||
|
elif isinstance(mixed, sa.Column):
|
||||||
|
mixed = mixed.table
|
||||||
|
elif isinstance(mixed, sa.orm.query._ColumnEntity):
|
||||||
|
mixed = mixed.expr
|
||||||
|
|
||||||
if isinstance(mixed, sa.orm.Mapper):
|
if isinstance(mixed, sa.orm.Mapper):
|
||||||
return mixed
|
return mixed
|
||||||
@@ -227,8 +238,6 @@ def get_tables(mixed):
|
|||||||
tables = sum((m.tables for m in polymorphic_mappers), [])
|
tables = sum((m.tables for m in polymorphic_mappers), [])
|
||||||
else:
|
else:
|
||||||
tables = mapper.tables
|
tables = mapper.tables
|
||||||
|
|
||||||
|
|
||||||
return tables
|
return tables
|
||||||
|
|
||||||
|
|
||||||
@@ -635,13 +644,7 @@ def getdotattr(obj_or_class, dot_path):
|
|||||||
for path in dot_path.split('.'):
|
for path in dot_path.split('.'):
|
||||||
getter = attrgetter(path)
|
getter = attrgetter(path)
|
||||||
if isinstance(last, list):
|
if isinstance(last, list):
|
||||||
tmp = []
|
last = sum((getter(el) for el in last), [])
|
||||||
for el in last:
|
|
||||||
if isinstance(el, list):
|
|
||||||
tmp.extend(map(getter, el))
|
|
||||||
else:
|
|
||||||
tmp.append(getter(el))
|
|
||||||
last = tmp
|
|
||||||
elif isinstance(last, InstrumentedAttribute):
|
elif isinstance(last, InstrumentedAttribute):
|
||||||
last = getter(last.property.mapper.class_)
|
last = getter(last.property.mapper.class_)
|
||||||
elif last is None:
|
elif last is None:
|
||||||
@@ -722,35 +725,37 @@ def has_changes(obj, attrs=None, exclude=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def has_any_changes(obj, columns):
|
def is_loaded(obj, prop):
|
||||||
"""
|
"""
|
||||||
Simple shortcut function for checking if any of the given attributes of
|
Return whether or not given property of given object has been loaded.
|
||||||
given declarative model object have changes.
|
|
||||||
|
|
||||||
|
|
||||||
::
|
::
|
||||||
|
|
||||||
|
class Article(Base):
|
||||||
from sqlalchemy_utils import has_any_changes
|
__tablename__ = 'article'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.String)
|
||||||
|
content = sa.orm.deferred(sa.Column(sa.String))
|
||||||
|
|
||||||
|
|
||||||
user = User()
|
article = session.query(Article).get(5)
|
||||||
|
|
||||||
has_any_changes(user, ('name', )) # False
|
# name gets loaded since its not a deferred property
|
||||||
|
assert is_loaded(article, 'name')
|
||||||
|
|
||||||
user.name = u'someone'
|
# content has not yet been loaded since its a deferred property
|
||||||
|
assert not is_loaded(article, 'content')
|
||||||
has_any_changes(user, ('name', 'age')) # True
|
|
||||||
|
|
||||||
|
|
||||||
.. versionadded: 0.26.3
|
.. versionadded: 0.27.8
|
||||||
.. deprecated:: 0.26.6
|
|
||||||
User :func:`has_changes` instead.
|
|
||||||
|
|
||||||
:param obj: SQLAlchemy declarative model object
|
:param obj: SQLAlchemy declarative model object
|
||||||
:param attrs: Names of the attributes
|
:param prop: Name of the property or InstrumentedAttribute
|
||||||
"""
|
"""
|
||||||
return any(has_changes(obj, column) for column in columns)
|
return not isinstance(
|
||||||
|
getattr(sa.inspect(obj).attrs, prop).loaded_value,
|
||||||
|
sa.util.langhelpers._symbol
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def identity(obj_or_class):
|
def identity(obj_or_class):
|
||||||
|
@@ -169,23 +169,28 @@ def make_order_by_deterministic(query):
|
|||||||
|
|
||||||
.. versionadded: 0.27.1
|
.. versionadded: 0.27.1
|
||||||
"""
|
"""
|
||||||
if not query._order_by:
|
order_by_func = sa.asc
|
||||||
return query
|
|
||||||
|
|
||||||
order_by = query._order_by[0]
|
if not query._order_by:
|
||||||
if isinstance(order_by, sa.sql.expression.UnaryExpression):
|
column = None
|
||||||
if order_by.modifier == sa.sql.operators.desc_op:
|
|
||||||
order_by_func = sa.desc
|
|
||||||
else:
|
|
||||||
order_by_func = sa.asc
|
|
||||||
column = order_by.get_children()[0]
|
|
||||||
else:
|
else:
|
||||||
column = order_by
|
order_by = query._order_by[0]
|
||||||
order_by_func = sa.asc
|
if isinstance(order_by, sa.sql.expression.UnaryExpression):
|
||||||
|
if order_by.modifier == sa.sql.operators.desc_op:
|
||||||
|
order_by_func = sa.desc
|
||||||
|
else:
|
||||||
|
order_by_func = sa.asc
|
||||||
|
column = order_by.get_children()[0]
|
||||||
|
else:
|
||||||
|
column = order_by
|
||||||
|
|
||||||
# Queries that are ordered by an already
|
# Queries that are ordered by an already
|
||||||
if isinstance(column, sa.Column) and has_unique_index(column):
|
if isinstance(column, sa.Column):
|
||||||
return query
|
try:
|
||||||
|
if has_unique_index(column):
|
||||||
|
return query
|
||||||
|
except TypeError:
|
||||||
|
pass
|
||||||
|
|
||||||
base_table = get_tables(query._entities[0])[0]
|
base_table = get_tables(query._entities[0])[0]
|
||||||
query = query.order_by(
|
query = query.order_by(
|
||||||
|
@@ -1,3 +1,5 @@
|
|||||||
|
from sqlalchemy.ext.hybrid import hybrid_property
|
||||||
|
|
||||||
from .exceptions import ImproperlyConfigured
|
from .exceptions import ImproperlyConfigured
|
||||||
|
|
||||||
|
|
||||||
@@ -21,3 +23,60 @@ except ImportError:
|
|||||||
'install babel or make a similar function and override it '
|
'install babel or make a similar function and override it '
|
||||||
'in this module.'
|
'in this module.'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TranslationHybrid(object):
|
||||||
|
def __init__(self, current_locale, default_locale):
|
||||||
|
self.current_locale = current_locale
|
||||||
|
self.default_locale = default_locale
|
||||||
|
|
||||||
|
def cast_locale(self, obj, locale):
|
||||||
|
"""
|
||||||
|
Cast given locale to string. Supports also callbacks that return
|
||||||
|
locales.
|
||||||
|
"""
|
||||||
|
if callable(locale):
|
||||||
|
try:
|
||||||
|
return str(locale())
|
||||||
|
except TypeError:
|
||||||
|
return str(locale(obj))
|
||||||
|
return str(locale)
|
||||||
|
|
||||||
|
def getter_factory(self, attr):
|
||||||
|
"""
|
||||||
|
Return a hybrid_property getter function for given attribute. The
|
||||||
|
returned getter first checks if object has translation for current
|
||||||
|
locale. If not it tries to get translation for default locale. If there
|
||||||
|
is no translation found for default locale it returns None.
|
||||||
|
"""
|
||||||
|
def getter(obj):
|
||||||
|
current_locale = self.cast_locale(obj, self.current_locale)
|
||||||
|
try:
|
||||||
|
return getattr(obj, attr.key)[current_locale]
|
||||||
|
except (TypeError, KeyError):
|
||||||
|
default_locale = self.cast_locale(
|
||||||
|
obj, self.default_locale
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
return getattr(obj, attr.key)[default_locale]
|
||||||
|
except (TypeError, KeyError):
|
||||||
|
return None
|
||||||
|
return getter
|
||||||
|
|
||||||
|
def setter_factory(self, attr):
|
||||||
|
def setter(obj, value):
|
||||||
|
if getattr(obj, attr.key) is None:
|
||||||
|
setattr(obj, attr.key, {})
|
||||||
|
locale = self.cast_locale(obj, self.current_locale)
|
||||||
|
getattr(obj, attr.key)[locale] = value
|
||||||
|
return setter
|
||||||
|
|
||||||
|
def expr_factory(self, attr):
|
||||||
|
return lambda cls: attr
|
||||||
|
|
||||||
|
def __call__(self, attr):
|
||||||
|
return hybrid_property(
|
||||||
|
fget=self.getter_factory(attr),
|
||||||
|
fset=self.setter_factory(attr),
|
||||||
|
expr=self.expr_factory(attr)
|
||||||
|
)
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
import six
|
import six
|
||||||
from sqlalchemy import types
|
from sqlalchemy import types
|
||||||
from sqlalchemy_utils.exceptions import ImproperlyConfigured
|
from sqlalchemy_utils.exceptions import ImproperlyConfigured
|
||||||
|
from sqlalchemy_utils.utils import str_coercible
|
||||||
from .scalar_coercible import ScalarCoercible
|
from .scalar_coercible import ScalarCoercible
|
||||||
|
|
||||||
|
|
||||||
@@ -13,6 +14,7 @@ except ImportError:
|
|||||||
BasePhoneNumber = object
|
BasePhoneNumber = object
|
||||||
|
|
||||||
|
|
||||||
|
@str_coercible
|
||||||
class PhoneNumber(BasePhoneNumber):
|
class PhoneNumber(BasePhoneNumber):
|
||||||
'''
|
'''
|
||||||
Extends a PhoneNumber class from `Python phonenumbers library`_. Adds
|
Extends a PhoneNumber class from `Python phonenumbers library`_. Adds
|
||||||
@@ -66,9 +68,6 @@ class PhoneNumber(BasePhoneNumber):
|
|||||||
def __unicode__(self):
|
def __unicode__(self):
|
||||||
return self.national
|
return self.national
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return six.text_type(self.national).encode('utf-8')
|
|
||||||
|
|
||||||
|
|
||||||
class PhoneNumberType(types.TypeDecorator, ScalarCoercible):
|
class PhoneNumberType(types.TypeDecorator, ScalarCoercible):
|
||||||
"""
|
"""
|
||||||
|
@@ -3,7 +3,7 @@ from sqlalchemy_utils.aggregates import aggregated
|
|||||||
from tests import TestCase
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
|
class TestAggregateValueGenerationWithBackrefs(TestCase):
|
||||||
def create_models(self):
|
def create_models(self):
|
||||||
class Thread(self.Base):
|
class Thread(self.Base):
|
||||||
__tablename__ = 'thread'
|
__tablename__ = 'thread'
|
||||||
|
80
tests/aggregate/test_m2m_m2m.py
Normal file
80
tests/aggregate/test_m2m_m2m.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from sqlalchemy_utils import aggregated
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestAggregateManyToManyAndManyToMany(TestCase):
|
||||||
|
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||||
|
|
||||||
|
def create_models(self):
|
||||||
|
catalog_products = sa.Table(
|
||||||
|
'catalog_product',
|
||||||
|
self.Base.metadata,
|
||||||
|
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
|
||||||
|
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
|
||||||
|
)
|
||||||
|
|
||||||
|
product_categories = sa.Table(
|
||||||
|
'category_product',
|
||||||
|
self.Base.metadata,
|
||||||
|
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
|
||||||
|
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
|
||||||
|
)
|
||||||
|
|
||||||
|
class Catalog(self.Base):
|
||||||
|
__tablename__ = 'catalog'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
@aggregated(
|
||||||
|
'products.categories',
|
||||||
|
sa.Column(sa.Integer, default=0)
|
||||||
|
)
|
||||||
|
def category_count(self):
|
||||||
|
return sa.func.count(sa.distinct(Category.id))
|
||||||
|
|
||||||
|
class Category(self.Base):
|
||||||
|
__tablename__ = 'category'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
class Product(self.Base):
|
||||||
|
__tablename__ = 'product'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
price = sa.Column(sa.Numeric)
|
||||||
|
|
||||||
|
catalog_id = sa.Column(
|
||||||
|
sa.Integer, sa.ForeignKey('catalog.id')
|
||||||
|
)
|
||||||
|
|
||||||
|
catalogs = sa.orm.relationship(
|
||||||
|
Catalog,
|
||||||
|
backref='products',
|
||||||
|
secondary=catalog_products
|
||||||
|
)
|
||||||
|
|
||||||
|
categories = sa.orm.relationship(
|
||||||
|
Category,
|
||||||
|
backref='products',
|
||||||
|
secondary=product_categories
|
||||||
|
)
|
||||||
|
|
||||||
|
self.Catalog = Catalog
|
||||||
|
self.Category = Category
|
||||||
|
self.Product = Product
|
||||||
|
|
||||||
|
def test_insert(self):
|
||||||
|
category = self.Category()
|
||||||
|
products = [
|
||||||
|
self.Product(categories=[category]),
|
||||||
|
self.Product(categories=[category])
|
||||||
|
]
|
||||||
|
catalog = self.Catalog(products=products)
|
||||||
|
self.session.add(catalog)
|
||||||
|
catalog2 = self.Catalog(products=products)
|
||||||
|
self.session.add(catalog)
|
||||||
|
self.session.commit()
|
||||||
|
assert catalog.category_count == 1
|
||||||
|
assert catalog2.category_count == 1
|
76
tests/aggregate/test_o2m_m2m.py
Normal file
76
tests/aggregate/test_o2m_m2m.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from sqlalchemy_utils import aggregated
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestAggregateOneToManyAndManyToMany(TestCase):
|
||||||
|
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||||
|
|
||||||
|
def create_models(self):
|
||||||
|
product_categories = sa.Table(
|
||||||
|
'category_product',
|
||||||
|
self.Base.metadata,
|
||||||
|
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
|
||||||
|
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
|
||||||
|
)
|
||||||
|
|
||||||
|
class Catalog(self.Base):
|
||||||
|
__tablename__ = 'catalog'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
@aggregated(
|
||||||
|
'products.categories',
|
||||||
|
sa.Column(sa.Integer, default=0)
|
||||||
|
)
|
||||||
|
def category_count(self):
|
||||||
|
return sa.func.count(sa.distinct(Category.id))
|
||||||
|
|
||||||
|
class Category(self.Base):
|
||||||
|
__tablename__ = 'category'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
class Product(self.Base):
|
||||||
|
__tablename__ = 'product'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
price = sa.Column(sa.Numeric)
|
||||||
|
|
||||||
|
catalog_id = sa.Column(
|
||||||
|
sa.Integer, sa.ForeignKey('catalog.id')
|
||||||
|
)
|
||||||
|
|
||||||
|
catalog = sa.orm.relationship(
|
||||||
|
Catalog,
|
||||||
|
backref='products'
|
||||||
|
)
|
||||||
|
|
||||||
|
categories = sa.orm.relationship(
|
||||||
|
Category,
|
||||||
|
backref='products',
|
||||||
|
secondary=product_categories
|
||||||
|
)
|
||||||
|
|
||||||
|
self.Catalog = Catalog
|
||||||
|
self.Category = Category
|
||||||
|
self.Product = Product
|
||||||
|
|
||||||
|
def test_insert(self):
|
||||||
|
category = self.Category()
|
||||||
|
products = [
|
||||||
|
self.Product(categories=[category]),
|
||||||
|
self.Product(categories=[category])
|
||||||
|
]
|
||||||
|
catalog = self.Catalog(products=products)
|
||||||
|
self.session.add(catalog)
|
||||||
|
products2 = [
|
||||||
|
self.Product(categories=[category]),
|
||||||
|
self.Product(categories=[category])
|
||||||
|
]
|
||||||
|
catalog2 = self.Catalog(products=products2)
|
||||||
|
self.session.add(catalog)
|
||||||
|
self.session.commit()
|
||||||
|
assert catalog.category_count == 1
|
||||||
|
assert catalog2.category_count == 1
|
62
tests/aggregate/test_o2m_o2m.py
Normal file
62
tests/aggregate/test_o2m_o2m.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
from decimal import Decimal
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils.aggregates import aggregated
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestAggregateOneToManyAndOneToMany(TestCase):
|
||||||
|
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||||
|
|
||||||
|
def create_models(self):
|
||||||
|
class Catalog(self.Base):
|
||||||
|
__tablename__ = 'catalog'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
@aggregated(
|
||||||
|
'categories.products',
|
||||||
|
sa.Column(sa.Integer, default=0)
|
||||||
|
)
|
||||||
|
def product_count(self):
|
||||||
|
return sa.func.count('1')
|
||||||
|
|
||||||
|
categories = sa.orm.relationship('Category', backref='catalog')
|
||||||
|
|
||||||
|
class Category(self.Base):
|
||||||
|
__tablename__ = 'category'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||||
|
|
||||||
|
products = sa.orm.relationship('Product', backref='category')
|
||||||
|
|
||||||
|
class Product(self.Base):
|
||||||
|
__tablename__ = 'product'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
price = sa.Column(sa.Numeric)
|
||||||
|
|
||||||
|
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||||
|
|
||||||
|
self.Catalog = Catalog
|
||||||
|
self.Category = Category
|
||||||
|
self.Product = Product
|
||||||
|
|
||||||
|
def test_assigns_aggregates(self):
|
||||||
|
category = self.Category(name=u'Some category')
|
||||||
|
catalog = self.Catalog(
|
||||||
|
categories=[category]
|
||||||
|
)
|
||||||
|
catalog.name = u'Some catalog'
|
||||||
|
self.session.add(catalog)
|
||||||
|
self.session.commit()
|
||||||
|
product = self.Product(
|
||||||
|
name=u'Some product',
|
||||||
|
price=Decimal('1000'),
|
||||||
|
category=category
|
||||||
|
)
|
||||||
|
self.session.add(product)
|
||||||
|
self.session.commit()
|
||||||
|
self.session.refresh(catalog)
|
||||||
|
assert catalog.product_count == 1
|
@@ -1,76 +1,15 @@
|
|||||||
from decimal import Decimal
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy_utils.aggregates import aggregated
|
|
||||||
|
from sqlalchemy_utils import aggregated
|
||||||
from tests import TestCase
|
from tests import TestCase
|
||||||
|
|
||||||
|
class Test3LevelDeepOneToMany(TestCase):
|
||||||
class TestDeepModelPathsForAggregates(TestCase):
|
|
||||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||||
|
|
||||||
def create_models(self):
|
def create_models(self):
|
||||||
class Catalog(self.Base):
|
class Catalog(self.Base):
|
||||||
__tablename__ = 'catalog'
|
__tablename__ = 'catalog'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
@aggregated(
|
|
||||||
'categories.products',
|
|
||||||
sa.Column(sa.Integer, default=0)
|
|
||||||
)
|
|
||||||
def product_count(self):
|
|
||||||
return sa.func.count('1')
|
|
||||||
|
|
||||||
categories = sa.orm.relationship('Category', backref='catalog')
|
|
||||||
|
|
||||||
class Category(self.Base):
|
|
||||||
__tablename__ = 'category'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
|
||||||
|
|
||||||
products = sa.orm.relationship('Product', backref='category')
|
|
||||||
|
|
||||||
class Product(self.Base):
|
|
||||||
__tablename__ = 'product'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
price = sa.Column(sa.Numeric)
|
|
||||||
|
|
||||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
|
||||||
|
|
||||||
self.Catalog = Catalog
|
|
||||||
self.Category = Category
|
|
||||||
self.Product = Product
|
|
||||||
|
|
||||||
def test_assigns_aggregates(self):
|
|
||||||
category = self.Category(name=u'Some category')
|
|
||||||
catalog = self.Catalog(
|
|
||||||
categories=[category]
|
|
||||||
)
|
|
||||||
catalog.name = u'Some catalog'
|
|
||||||
self.session.add(catalog)
|
|
||||||
self.session.commit()
|
|
||||||
product = self.Product(
|
|
||||||
name=u'Some product',
|
|
||||||
price=Decimal('1000'),
|
|
||||||
category=category
|
|
||||||
)
|
|
||||||
self.session.add(product)
|
|
||||||
self.session.commit()
|
|
||||||
self.session.refresh(catalog)
|
|
||||||
assert catalog.product_count == 1
|
|
||||||
|
|
||||||
|
|
||||||
class Test3LevelDeepModelPathsForAggregates(TestCase):
|
|
||||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
|
||||||
n = 1
|
|
||||||
|
|
||||||
def create_models(self):
|
|
||||||
class Catalog(self.Base):
|
|
||||||
__tablename__ = 'catalog'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
@aggregated(
|
@aggregated(
|
||||||
'categories.sub_categories.products',
|
'categories.sub_categories.products',
|
||||||
@@ -84,8 +23,6 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
|
|||||||
class Category(self.Base):
|
class Category(self.Base):
|
||||||
__tablename__ = 'category'
|
__tablename__ = 'category'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||||
|
|
||||||
sub_categories = sa.orm.relationship(
|
sub_categories = sa.orm.relationship(
|
||||||
@@ -95,16 +32,12 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
|
|||||||
class SubCategory(self.Base):
|
class SubCategory(self.Base):
|
||||||
__tablename__ = 'sub_category'
|
__tablename__ = 'sub_category'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||||
|
|
||||||
products = sa.orm.relationship('Product', backref='sub_category')
|
products = sa.orm.relationship('Product', backref='sub_category')
|
||||||
|
|
||||||
class Product(self.Base):
|
class Product(self.Base):
|
||||||
__tablename__ = 'product'
|
__tablename__ = 'product'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
price = sa.Column(sa.Numeric)
|
price = sa.Column(sa.Numeric)
|
||||||
|
|
||||||
sub_category_id = sa.Column(
|
sub_category_id = sa.Column(
|
||||||
@@ -123,23 +56,13 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
|
|||||||
assert catalog.product_count == 1
|
assert catalog.product_count == 1
|
||||||
|
|
||||||
def catalog_factory(self):
|
def catalog_factory(self):
|
||||||
product = self.Product(
|
product = self.Product()
|
||||||
name=u'Product %d' % self.n
|
|
||||||
)
|
|
||||||
sub_category = self.SubCategory(
|
sub_category = self.SubCategory(
|
||||||
name=u'SubCategory %d' % self.n,
|
|
||||||
products=[product]
|
products=[product]
|
||||||
)
|
)
|
||||||
category = self.Category(
|
category = self.Category(sub_categories=[sub_category])
|
||||||
name=u'Category %d' % self.n,
|
catalog = self.Catalog(categories=[category])
|
||||||
sub_categories=[sub_category]
|
|
||||||
)
|
|
||||||
catalog = self.Catalog(
|
|
||||||
categories=[category]
|
|
||||||
)
|
|
||||||
catalog.name = u'Catalog %d' % self.n
|
|
||||||
self.session.add(catalog)
|
self.session.add(catalog)
|
||||||
self.n += 1
|
|
||||||
return catalog
|
return catalog
|
||||||
|
|
||||||
def test_only_updates_affected_aggregates(self):
|
def test_only_updates_affected_aggregates(self):
|
||||||
@@ -155,7 +78,7 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
catalog.categories[0].sub_categories[0].products.append(
|
catalog.categories[0].sub_categories[0].products.append(
|
||||||
self.Product(name=u'Product 3')
|
self.Product()
|
||||||
)
|
)
|
||||||
self.session.commit()
|
self.session.commit()
|
||||||
self.session.refresh(catalog)
|
self.session.refresh(catalog)
|
58
tests/aggregate/test_with_column_alias.py
Normal file
58
tests/aggregate/test_with_column_alias.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils.aggregates import aggregated
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestAggregatedWithColumnAlias(TestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class Thread(self.Base):
|
||||||
|
__tablename__ = 'thread'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
@aggregated(
|
||||||
|
'comments',
|
||||||
|
sa.Column('_comment_count', sa.Integer, default=0)
|
||||||
|
)
|
||||||
|
def comment_count(self):
|
||||||
|
return sa.func.count('1')
|
||||||
|
|
||||||
|
comments = sa.orm.relationship('Comment', backref='thread')
|
||||||
|
|
||||||
|
class Comment(self.Base):
|
||||||
|
__tablename__ = 'comment'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
|
||||||
|
|
||||||
|
self.Thread = Thread
|
||||||
|
self.Comment = Comment
|
||||||
|
|
||||||
|
def test_assigns_aggregates_on_insert(self):
|
||||||
|
thread = self.Thread()
|
||||||
|
self.session.add(thread)
|
||||||
|
comment = self.Comment(thread=thread)
|
||||||
|
self.session.add(comment)
|
||||||
|
self.session.commit()
|
||||||
|
self.session.refresh(thread)
|
||||||
|
assert thread.comment_count == 1
|
||||||
|
|
||||||
|
def test_assigns_aggregates_on_separate_insert(self):
|
||||||
|
thread = self.Thread()
|
||||||
|
self.session.add(thread)
|
||||||
|
self.session.commit()
|
||||||
|
comment = self.Comment(thread=thread)
|
||||||
|
self.session.add(comment)
|
||||||
|
self.session.commit()
|
||||||
|
self.session.refresh(thread)
|
||||||
|
assert thread.comment_count == 1
|
||||||
|
|
||||||
|
def test_assigns_aggregates_on_delete(self):
|
||||||
|
thread = self.Thread()
|
||||||
|
self.session.add(thread)
|
||||||
|
self.session.commit()
|
||||||
|
comment = self.Comment(thread=thread)
|
||||||
|
self.session.add(comment)
|
||||||
|
self.session.commit()
|
||||||
|
self.session.delete(comment)
|
||||||
|
self.session.commit()
|
||||||
|
self.session.refresh(thread)
|
||||||
|
assert thread.comment_count == 0
|
@@ -78,6 +78,85 @@ class TestDependentObjects(TestCase):
|
|||||||
assert objects[3] in deps
|
assert objects[3] in deps
|
||||||
|
|
||||||
|
|
||||||
|
class TestDependentObjectsWithColumnAliases(TestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class User(self.Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
first_name = sa.Column(sa.Unicode(255))
|
||||||
|
last_name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
class Article(self.Base):
|
||||||
|
__tablename__ = 'article'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
author_id = sa.Column(
|
||||||
|
'_author_id', sa.Integer, sa.ForeignKey('user.id')
|
||||||
|
)
|
||||||
|
owner_id = sa.Column(
|
||||||
|
'_owner_id',
|
||||||
|
sa.Integer, sa.ForeignKey('user.id', ondelete='SET NULL')
|
||||||
|
)
|
||||||
|
|
||||||
|
author = sa.orm.relationship(User, foreign_keys=[author_id])
|
||||||
|
owner = sa.orm.relationship(User, foreign_keys=[owner_id])
|
||||||
|
|
||||||
|
class BlogPost(self.Base):
|
||||||
|
__tablename__ = 'blog_post'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
owner_id = sa.Column(
|
||||||
|
'_owner_id',
|
||||||
|
sa.Integer, sa.ForeignKey('user.id', ondelete='CASCADE')
|
||||||
|
)
|
||||||
|
|
||||||
|
owner = sa.orm.relationship(User)
|
||||||
|
|
||||||
|
self.User = User
|
||||||
|
self.Article = Article
|
||||||
|
self.BlogPost = BlogPost
|
||||||
|
|
||||||
|
def test_returns_all_dependent_objects(self):
|
||||||
|
user = self.User(first_name=u'John')
|
||||||
|
articles = [
|
||||||
|
self.Article(author=user),
|
||||||
|
self.Article(),
|
||||||
|
self.Article(owner=user),
|
||||||
|
self.Article(author=user, owner=user)
|
||||||
|
]
|
||||||
|
self.session.add_all(articles)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
deps = list(dependent_objects(user))
|
||||||
|
assert len(deps) == 3
|
||||||
|
assert articles[0] in deps
|
||||||
|
assert articles[2] in deps
|
||||||
|
assert articles[3] in deps
|
||||||
|
|
||||||
|
def test_with_foreign_keys_parameter(self):
|
||||||
|
user = self.User(first_name=u'John')
|
||||||
|
objects = [
|
||||||
|
self.Article(author=user),
|
||||||
|
self.Article(),
|
||||||
|
self.Article(owner=user),
|
||||||
|
self.Article(author=user, owner=user),
|
||||||
|
self.BlogPost(owner=user)
|
||||||
|
]
|
||||||
|
self.session.add_all(objects)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
deps = list(
|
||||||
|
dependent_objects(
|
||||||
|
user,
|
||||||
|
(
|
||||||
|
fk for fk in get_referencing_foreign_keys(self.User)
|
||||||
|
if fk.ondelete == 'RESTRICT' or fk.ondelete is None
|
||||||
|
)
|
||||||
|
).limit(5)
|
||||||
|
)
|
||||||
|
assert len(deps) == 2
|
||||||
|
assert objects[0] in deps
|
||||||
|
assert objects[3] in deps
|
||||||
|
|
||||||
|
|
||||||
class TestDependentObjectsWithManyReferences(TestCase):
|
class TestDependentObjectsWithManyReferences(TestCase):
|
||||||
def create_models(self):
|
def create_models(self):
|
||||||
class User(self.Base):
|
class User(self.Base):
|
||||||
@@ -192,7 +271,6 @@ class TestDependentObjectsWithSingleTableInheritance(TestCase):
|
|||||||
'polymorphic_identity': u'blog_post'
|
'polymorphic_identity': u'blog_post'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
self.Category = Category
|
self.Category = Category
|
||||||
self.TextItem = TextItem
|
self.TextItem = TextItem
|
||||||
self.Article = Article
|
self.Article = Article
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
from copy import copy
|
||||||
from pytest import raises
|
from pytest import raises
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
@@ -15,7 +16,12 @@ class TestGetColumnKey(object):
|
|||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column('_name', sa.Unicode(255))
|
name = sa.Column('_name', sa.Unicode(255))
|
||||||
|
|
||||||
|
class Movie(Base):
|
||||||
|
__tablename__ = 'movie'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
self.Building = Building
|
self.Building = Building
|
||||||
|
self.Movie = Movie
|
||||||
|
|
||||||
def test_supports_aliases(self):
|
def test_supports_aliases(self):
|
||||||
assert (
|
assert (
|
||||||
@@ -29,6 +35,10 @@ class TestGetColumnKey(object):
|
|||||||
'name'
|
'name'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_supports_vague_matching_of_column_objects(self):
|
||||||
|
column = copy(self.Building.__table__.c._name)
|
||||||
|
assert get_column_key(self.Building, column) == 'name'
|
||||||
|
|
||||||
def test_throws_value_error_for_unknown_column(self):
|
def test_throws_value_error_for_unknown_column(self):
|
||||||
with raises(ValueError):
|
with raises(sa.orm.exc.UnmappedColumnError):
|
||||||
get_column_key(self.Building, 'unknown')
|
get_column_key(self.Building, self.Movie.__table__.c.id)
|
||||||
|
@@ -56,6 +56,18 @@ class TestGetMapper(object):
|
|||||||
sa.inspect(self.Building)
|
sa.inspect(self.Building)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_column(self):
|
||||||
|
assert (
|
||||||
|
get_mapper(self.Building.__table__.c.id) ==
|
||||||
|
sa.inspect(self.Building)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_column_of_an_alias(self):
|
||||||
|
assert (
|
||||||
|
get_mapper(sa.orm.aliased(self.Building.__table__).c.id) ==
|
||||||
|
sa.inspect(self.Building)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestGetMapperWithQueryEntities(TestCase):
|
class TestGetMapperWithQueryEntities(TestCase):
|
||||||
def create_models(self):
|
def create_models(self):
|
||||||
@@ -79,6 +91,10 @@ class TestGetMapperWithQueryEntities(TestCase):
|
|||||||
sa.inspect(self.Building)
|
sa.inspect(self.Building)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_column_entity(self):
|
||||||
|
query = self.session.query(self.Building.id)
|
||||||
|
assert get_mapper(query._entities[0]) == sa.inspect(self.Building)
|
||||||
|
|
||||||
|
|
||||||
class TestGetMapperWithMultipleMappersFound(object):
|
class TestGetMapperWithMultipleMappersFound(object):
|
||||||
def setup_method(self, method):
|
def setup_method(self, method):
|
||||||
|
@@ -67,11 +67,12 @@ class TestGetDotAttr(TestCase):
|
|||||||
subsection = self.SubSection(section=section)
|
subsection = self.SubSection(section=section)
|
||||||
subsubsection = self.SubSubSection(subsection=subsection)
|
subsubsection = self.SubSubSection(subsection=subsection)
|
||||||
|
|
||||||
|
assert getdotattr(document, 'sections') == [section]
|
||||||
assert getdotattr(document, 'sections.subsections') == [
|
assert getdotattr(document, 'sections.subsections') == [
|
||||||
[subsection]
|
subsection
|
||||||
]
|
]
|
||||||
assert getdotattr(document, 'sections.subsections.subsubsections') == [
|
assert getdotattr(document, 'sections.subsections.subsubsections') == [
|
||||||
[subsubsection]
|
subsubsection
|
||||||
]
|
]
|
||||||
|
|
||||||
def test_class_paths(self):
|
def test_class_paths(self):
|
||||||
|
@@ -1,24 +0,0 @@
|
|||||||
import sqlalchemy as sa
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
|
|
||||||
from sqlalchemy_utils import has_any_changes
|
|
||||||
|
|
||||||
|
|
||||||
class TestHasAnyChanges(object):
|
|
||||||
def setup_method(self, method):
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
class Article(Base):
|
|
||||||
__tablename__ = 'article_translation'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
title = sa.Column(sa.String(100))
|
|
||||||
|
|
||||||
self.Article = Article
|
|
||||||
|
|
||||||
def test_without_changed_attr(self):
|
|
||||||
article = self.Article()
|
|
||||||
assert not has_any_changes(article, ['title'])
|
|
||||||
|
|
||||||
def test_with_changed_attr(self):
|
|
||||||
article = self.Article(title='Some title')
|
|
||||||
assert has_any_changes(article, ['title', 'id'])
|
|
@@ -1,4 +1,5 @@
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
from pytest import raises
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
from sqlalchemy_utils import has_index
|
from sqlalchemy_utils import has_index
|
||||||
@@ -23,6 +24,11 @@ class TestHasIndex(object):
|
|||||||
|
|
||||||
self.table = ArticleTranslation.__table__
|
self.table = ArticleTranslation.__table__
|
||||||
|
|
||||||
|
def test_column_that_belongs_to_an_alias(self):
|
||||||
|
alias = sa.orm.aliased(self.table)
|
||||||
|
with raises(TypeError):
|
||||||
|
assert has_index(alias.c.id)
|
||||||
|
|
||||||
def test_compound_primary_key(self):
|
def test_compound_primary_key(self):
|
||||||
assert has_index(self.table.c.id)
|
assert has_index(self.table.c.id)
|
||||||
assert not has_index(self.table.c.locale)
|
assert not has_index(self.table.c.locale)
|
||||||
|
@@ -1,10 +1,11 @@
|
|||||||
|
from pytest import raises
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
from sqlalchemy_utils import has_unique_index
|
from sqlalchemy_utils import has_unique_index
|
||||||
|
|
||||||
|
|
||||||
class TestHasIndex(object):
|
class TestHasUniqueIndex(object):
|
||||||
def setup_method(self, method):
|
def setup_method(self, method):
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
@@ -31,6 +32,11 @@ class TestHasIndex(object):
|
|||||||
def test_primary_key(self):
|
def test_primary_key(self):
|
||||||
assert has_unique_index(self.articles.c.id)
|
assert has_unique_index(self.articles.c.id)
|
||||||
|
|
||||||
|
def test_column_of_aliased_table(self):
|
||||||
|
alias = sa.orm.aliased(self.articles)
|
||||||
|
with raises(TypeError):
|
||||||
|
assert has_unique_index(alias.c.id)
|
||||||
|
|
||||||
def test_unique_index(self):
|
def test_unique_index(self):
|
||||||
assert has_unique_index(self.article_translations.c.is_deleted)
|
assert has_unique_index(self.article_translations.c.is_deleted)
|
||||||
|
|
||||||
|
25
tests/functions/test_is_loaded.py
Normal file
25
tests/functions/test_is_loaded.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
|
from sqlalchemy_utils import is_loaded
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsLoaded(object):
|
||||||
|
def setup_method(self, method):
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
class Article(Base):
|
||||||
|
__tablename__ = 'article_translation'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
title = sa.orm.deferred(sa.Column(sa.String(100)))
|
||||||
|
|
||||||
|
self.Article = Article
|
||||||
|
|
||||||
|
def test_loaded_property(self):
|
||||||
|
article = self.Article(id=1)
|
||||||
|
assert is_loaded(article, 'id')
|
||||||
|
|
||||||
|
def test_unloaded_property(self):
|
||||||
|
article = self.Article(id=4)
|
||||||
|
assert not is_loaded(article, 'title')
|
||||||
|
|
@@ -17,7 +17,6 @@ class TestMakeOrderByDeterministic(TestCase):
|
|||||||
sa.func.lower(name)
|
sa.func.lower(name)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Article(self.Base):
|
class Article(self.Base):
|
||||||
__tablename__ = 'article'
|
__tablename__ = 'article'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
@@ -31,6 +30,7 @@ class TestMakeOrderByDeterministic(TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.User = User
|
self.User = User
|
||||||
|
self.Article = Article
|
||||||
|
|
||||||
def test_column_property(self):
|
def test_column_property(self):
|
||||||
query = self.session.query(self.User).order_by(self.User.email_lower)
|
query = self.session.query(self.User).order_by(self.User.email_lower)
|
||||||
@@ -82,4 +82,10 @@ class TestMakeOrderByDeterministic(TestCase):
|
|||||||
def test_query_without_order_by(self):
|
def test_query_without_order_by(self):
|
||||||
query = self.session.query(self.User)
|
query = self.session.query(self.User)
|
||||||
query = make_order_by_deterministic(query)
|
query = make_order_by_deterministic(query)
|
||||||
assert 'ORDER BY' not in str(query)
|
assert 'ORDER BY "user".id' in str(query)
|
||||||
|
|
||||||
|
def test_alias(self):
|
||||||
|
alias = sa.orm.aliased(self.User.__table__)
|
||||||
|
query = self.session.query(alias).order_by(alias.c.name)
|
||||||
|
query = make_order_by_deterministic(query)
|
||||||
|
assert str(query).endswith('ORDER BY user_1.name, "user".id ASC')
|
||||||
|
90
tests/test_asserts.py
Normal file
90
tests/test_asserts.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy_utils import (
|
||||||
|
assert_nullable,
|
||||||
|
assert_non_nullable,
|
||||||
|
assert_max_length
|
||||||
|
)
|
||||||
|
from sqlalchemy_utils.asserts import raises
|
||||||
|
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestRaises(object):
|
||||||
|
def test_matching_exception(self):
|
||||||
|
with raises(Exception):
|
||||||
|
raise Exception()
|
||||||
|
assert True
|
||||||
|
|
||||||
|
def test_non_matchin_exception(self):
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
with raises(ValueError):
|
||||||
|
raise Exception()
|
||||||
|
|
||||||
|
|
||||||
|
class AssertionTestCase(TestCase):
|
||||||
|
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||||
|
|
||||||
|
def create_models(self):
|
||||||
|
class User(self.Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.String(20))
|
||||||
|
age = sa.Column(sa.Integer, nullable=False)
|
||||||
|
email = sa.Column(sa.String(200), nullable=False, unique=True)
|
||||||
|
|
||||||
|
self.User = User
|
||||||
|
|
||||||
|
def setup_method(self, method):
|
||||||
|
TestCase.setup_method(self, method)
|
||||||
|
user = self.User(name='Someone', email='someone@example.com', age=15)
|
||||||
|
self.session.add(user)
|
||||||
|
self.session.commit()
|
||||||
|
self.user = user
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssertNonNullable(AssertionTestCase):
|
||||||
|
def test_non_nullable_column(self):
|
||||||
|
# Test everything twice so that session gets rolled back properly
|
||||||
|
assert_non_nullable(self.user, 'age')
|
||||||
|
assert_non_nullable(self.user, 'age')
|
||||||
|
|
||||||
|
def test_nullable_column(self):
|
||||||
|
with raises(AssertionError):
|
||||||
|
assert_non_nullable(self.user, 'name')
|
||||||
|
with raises(AssertionError):
|
||||||
|
assert_non_nullable(self.user, 'name')
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssertNullable(AssertionTestCase):
|
||||||
|
def test_nullable_column(self):
|
||||||
|
assert_nullable(self.user, 'name')
|
||||||
|
assert_nullable(self.user, 'name')
|
||||||
|
|
||||||
|
def test_non_nullable_column(self):
|
||||||
|
with raises(AssertionError):
|
||||||
|
assert_nullable(self.user, 'age')
|
||||||
|
with raises(AssertionError):
|
||||||
|
assert_nullable(self.user, 'age')
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssertMaxLength(AssertionTestCase):
|
||||||
|
def test_with_max_length(self):
|
||||||
|
assert_max_length(self.user, 'name', 20)
|
||||||
|
assert_max_length(self.user, 'name', 20)
|
||||||
|
|
||||||
|
def test_with_non_nullable_column(self):
|
||||||
|
assert_max_length(self.user, 'email', 200)
|
||||||
|
assert_max_length(self.user, 'email', 200)
|
||||||
|
|
||||||
|
def test_smaller_than_max_length(self):
|
||||||
|
with raises(AssertionError):
|
||||||
|
assert_max_length(self.user, 'name', 19)
|
||||||
|
with raises(AssertionError):
|
||||||
|
assert_max_length(self.user, 'name', 19)
|
||||||
|
|
||||||
|
def test_bigger_than_max_length(self):
|
||||||
|
with raises(AssertionError):
|
||||||
|
assert_max_length(self.user, 'name', 21)
|
||||||
|
with raises(AssertionError):
|
||||||
|
assert_max_length(self.user, 'name', 21)
|
68
tests/test_translation_hybrid.py
Normal file
68
tests/test_translation_hybrid.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import JSON
|
||||||
|
from sqlalchemy_utils import TranslationHybrid
|
||||||
|
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestTranslationHybrid(TestCase):
|
||||||
|
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||||
|
|
||||||
|
def create_models(self):
|
||||||
|
class City(self.Base):
|
||||||
|
__tablename__ = 'city'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name_translations = sa.Column(JSON())
|
||||||
|
name = self.translation_hybrid(name_translations)
|
||||||
|
locale = 'en'
|
||||||
|
|
||||||
|
self.City = City
|
||||||
|
|
||||||
|
def setup_method(self, method):
|
||||||
|
self.translation_hybrid = TranslationHybrid('fi', 'en')
|
||||||
|
TestCase.setup_method(self, method)
|
||||||
|
|
||||||
|
def test_using_hybrid_as_constructor(self):
|
||||||
|
city = self.City(name='Helsinki')
|
||||||
|
assert city.name_translations['fi'] == 'Helsinki'
|
||||||
|
|
||||||
|
def test_hybrid_as_expression(self):
|
||||||
|
assert self.City.name == self.City.name_translations
|
||||||
|
|
||||||
|
def test_if_no_translation_exists_returns_none(self):
|
||||||
|
city = self.City()
|
||||||
|
assert city.name is None
|
||||||
|
|
||||||
|
def test_fall_back_to_default_translation(self):
|
||||||
|
city = self.City(name_translations={'en': 'Helsinki'})
|
||||||
|
self.translation_hybrid.current_locale = 'sv'
|
||||||
|
assert city.name == 'Helsinki'
|
||||||
|
|
||||||
|
|
||||||
|
class TestTranslationHybridWithDynamicDefaultLocale(TestCase):
|
||||||
|
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||||
|
|
||||||
|
def create_models(self):
|
||||||
|
class City(self.Base):
|
||||||
|
__tablename__ = 'city'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name_translations = sa.Column(JSON)
|
||||||
|
name = self.translation_hybrid(name_translations)
|
||||||
|
locale = sa.Column(sa.String(10))
|
||||||
|
|
||||||
|
self.City = City
|
||||||
|
|
||||||
|
def setup_method(self, method):
|
||||||
|
self.translation_hybrid = TranslationHybrid(
|
||||||
|
'fi',
|
||||||
|
lambda self: self.locale
|
||||||
|
)
|
||||||
|
TestCase.setup_method(self, method)
|
||||||
|
|
||||||
|
def test_fallback_to_dynamic_locale(self):
|
||||||
|
self.translation_hybrid.current_locale = 'en'
|
||||||
|
city = self.City(name_translations={})
|
||||||
|
city.locale = 'fi'
|
||||||
|
city.name_translations['fi'] = 'Helsinki'
|
||||||
|
|
||||||
|
assert city.name == 'Helsinki'
|
@@ -1,6 +1,8 @@
|
|||||||
from pytest import mark
|
import six
|
||||||
from tests import TestCase
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
from pytest import mark
|
||||||
|
|
||||||
|
from tests import TestCase
|
||||||
from sqlalchemy_utils import PhoneNumberType, PhoneNumber
|
from sqlalchemy_utils import PhoneNumberType, PhoneNumber
|
||||||
from sqlalchemy_utils.types import phone_number
|
from sqlalchemy_utils.types import phone_number
|
||||||
|
|
||||||
@@ -45,8 +47,11 @@ class TestPhoneNumber(object):
|
|||||||
|
|
||||||
def test_phone_number_str_repr(self):
|
def test_phone_number_str_repr(self):
|
||||||
number = PhoneNumber('+358401234567')
|
number = PhoneNumber('+358401234567')
|
||||||
assert number.__unicode__() == number.national
|
if six.PY2:
|
||||||
assert number.__str__() == number.national.encode('utf-8')
|
assert unicode(number) == number.national
|
||||||
|
assert str(number) == number.national.encode('utf-8')
|
||||||
|
else:
|
||||||
|
assert str(number) == number.national
|
||||||
|
|
||||||
|
|
||||||
@mark.skipif('phone_number.phonenumbers is None')
|
@mark.skipif('phone_number.phonenumbers is None')
|
||||||
|
Reference in New Issue
Block a user