Merge branch 'master' into updates

This commit is contained in:
Ryan Leckey
2014-12-11 12:19:01 -08:00
34 changed files with 1068 additions and 234 deletions

View File

@@ -4,6 +4,69 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.27.12 (2014-12-xx)
^^^^^^^^^^^^^^^^^^^^
- Fixed PhoneNumber string coercion (#93)
0.27.11 (2014-12-06)
^^^^^^^^^^^^^^^^^^^^
- Added loose typed column checking support for get_column_key
- Made get_column_key throw UnmappedColumnError to be consistent with SQLAlchemy
0.27.10 (2014-12-03)
^^^^^^^^^^^^^^^^^^^^
- Fixed column alias handling in dependent_objects
0.27.9 (2014-12-01)
^^^^^^^^^^^^^^^^^^^
- Fixed aggregated decorator many-to-many relationship handling
- Fixed aggregated column alias handling
0.27.8 (2014-11-13)
^^^^^^^^^^^^^^^^^^^
- Added is_loaded utility function
- Removed deprecated has_any_changes
0.27.7 (2014-11-03)
^^^^^^^^^^^^^^^^^^^
- Added support for Column and ColumnEntity objects in get_mapper
- Made make_order_by_deterministic add deterministic column more aggressively
0.27.6 (2014-10-29)
^^^^^^^^^^^^^^^^^^^
- Fixed assert_max_length not working with non nullable columns
- Add PostgreSQL < 9.2 support for drop_database
0.27.5 (2014-10-24)
^^^^^^^^^^^^^^^^^^^
- Made assert_* functions automatically rollback session
- Changed make_order_by_deterministic attach order by primary key for queries without order by
- Fixed alias handling in has_unique_index
- Fixed alias handling in has_index
- Fixed alias handling in make_order_by_deterministic
0.27.4 (2014-10-23)
^^^^^^^^^^^^^^^^^^^
- Added assert_non_nullable, assert_nullable and assert_max_length testing functions
0.27.3 (2014-10-22)
^^^^^^^^^^^^^^^^^^^

View File

@@ -20,4 +20,5 @@ SQLAlchemy-Utils provides custom data types and various utility functions for SQ
orm_helpers
utility_classes
models
testing
license

View File

@@ -76,6 +76,12 @@ identity
.. autofunction:: identity
is_loaded
^^^^^^^^^
.. autofunction:: is_loaded
make_order_by_deterministic
^^^^^^^^^^^^^^^^^^^^^^^^^^^

20
docs/testing.rst Normal file
View File

@@ -0,0 +1,20 @@
Testing
=======
.. automodule:: sqlalchemy_utils.asserts
assert_nullable
---------------
.. autofunction:: assert_nullable
assert_non_nullable
-------------------
.. autofunction:: assert_non_nullable
assert_max_length
-----------------
.. autofunction:: assert_max_length

View File

@@ -47,7 +47,7 @@ extras_require = {
'ipaddress': ['ipaddr'] if not PY3 else [],
'timezone': ['python-dateutil'],
'url': ['furl >= 0.3.5'] if not PY3 else [],
'encrypted': ['cryptography==0.6']
'encrypted': ['cryptography>=0.6']
}

View File

@@ -1,4 +1,5 @@
from .aggregates import aggregated
from .asserts import assert_nullable, assert_non_nullable, assert_max_length
from .batch import batch_fetch, with_backrefs
from .decorators import generates
from .exceptions import ImproperlyConfigured
@@ -23,11 +24,11 @@ from .functions import (
get_referencing_foreign_keys,
get_tables,
group_foreign_keys,
has_any_changes,
has_changes,
has_index,
has_unique_index,
identity,
is_loaded,
merge_references,
mock_engine,
naturally_equivalent,
@@ -36,6 +37,7 @@ from .functions import (
sort_query,
table_name,
)
from .i18n import TranslationHybrid
from .listeners import (
auto_delete_orphans,
coercion_listener,
@@ -78,12 +80,15 @@ from .types import (
from .models import Timestamp
__version__ = '0.27.3'
__version__ = '0.27.11'
__all__ = (
aggregated,
analyze,
assert_max_length,
assert_non_nullable,
assert_nullable,
auto_delete_orphans,
batch_fetch,
coercion_listener,
@@ -109,11 +114,11 @@ __all__ = (
get_referencing_foreign_keys,
get_tables,
group_foreign_keys,
has_any_changes,
has_changes,
has_index,
identity,
instrumented_list,
is_loaded,
merge_references,
mock_engine,
naturally_equivalent,

View File

@@ -394,13 +394,64 @@ class AggregatedAttribute(declared_attr):
self.relationship = relationship
def __get__(desc, self, cls):
value = (desc.fget, desc.relationship, desc.column)
if cls not in aggregated_attrs:
aggregated_attrs[cls] = [(desc.fget, desc.relationship)]
aggregated_attrs[cls] = [value]
else:
aggregated_attrs[cls].append((desc.fget, desc.relationship))
aggregated_attrs[cls].append(value)
return desc.column
def get_aggregate_query(agg_expr, relationships):
"""
Return a subquery for fetching an aggregate value of given aggregate
expression and given sequence of relationships.
The returned aggregate query can be used when updating denormalized column
value with query such as:
UPDATE table SET column = {aggregate_query}
WHERE {condition}
:param agg_expr:
an expression to be selected, for example sa.func.count('1')
:param relationships:
Sequence of relationships to be used for building the aggregate
query.
"""
from_ = relationships[0].mapper.class_.__table__
for relationship in relationships[0:-1]:
property_ = relationship.property
if property_.secondary is not None:
from_ = from_.join(
property_.secondary,
property_.secondaryjoin
)
from_ = (
from_
.join(
property_.parent.class_,
property_.primaryjoin
)
)
prop = relationships[-1].property
condition = prop.primaryjoin
if prop.secondary is not None:
from_ = from_.join(
prop.secondary,
prop.secondaryjoin
)
query = sa.select(
[agg_expr],
from_obj=[from_]
)
return query.where(condition)
class AggregatedValue(object):
def __init__(self, class_, attr, relationships, expr):
self.class_ = class_
@@ -418,23 +469,7 @@ class AggregatedValue(object):
@property
def aggregate_query(self):
from_ = self.relationships[0].mapper.class_.__table__
for relationship in self.relationships[0:-1]:
property_ = relationship.property
from_ = (
from_
.join(
property_.parent.class_,
property_.primaryjoin
)
)
query = sa.select(
[self.expr],
from_obj=[from_]
)
query = query.where(self.relationships[-1])
query = get_aggregate_query(self.expr, self.relationships)
return query.correlate(self.class_).as_scalar()
@@ -484,11 +519,22 @@ class AggregatedValue(object):
property_ = self.relationships[-1].property
from_ = property_.mapper.class_.__table__
for relationship in reversed(self.relationships[1:-1]):
for relationship in reversed(self.relationships[0:-1]):
property_ = relationship.property
from_ = (
from_.join(property_.mapper.class_, property_.primaryjoin)
)
if property_.secondary is not None:
from_ = from_.join(
property_.secondary,
property_.primaryjoin
)
from_ = from_.join(
property_.mapper.class_,
property_.secondaryjoin
)
else:
from_ = from_.join(
property_.mapper.class_,
property_.primaryjoin
)
return from_
def local_condition(self, prop, objects):
@@ -532,7 +578,7 @@ class AggregationManager(object):
def update_generator_registry(self):
for class_, attrs in six.iteritems(aggregated_attrs):
for expr, relationship in attrs:
for expr, relationship, column in attrs:
relationships = []
rel_class = class_
@@ -544,7 +590,7 @@ class AggregationManager(object):
self.generator_registry[rel_class].append(
AggregatedValue(
class_=class_,
attr=expr.__name__,
attr=column,
relationships=list(reversed(relationships)),
expr=expr(class_)
)

106
sqlalchemy_utils/asserts.py Normal file
View File

@@ -0,0 +1,106 @@
"""
The functions in this module can be used for testing that the constraints of
your models. Each assert function runs SQL UPDATEs that check for the existence
of given constraint. Consider the following model::
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(200), nullable=True)
email = sa.Column(sa.String(255), nullable=False)
user = User(name='John Doe', email='john@example.com')
session.add(user)
session.commit()
We can easily test the constraints by assert_* functions::
from sqlalchemy_utils import (
assert_nullable,
assert_non_nullable,
assert_max_length
)
assert_nullable(user, 'name')
assert_non_nullable(user, 'email')
assert_max_length(user, 'name', 200)
# raises AssertionError because the max length of email is 255
assert_max_length(user, 'email', 300)
"""
import sqlalchemy as sa
from sqlalchemy.exc import DataError, IntegrityError
class raises(object):
def __init__(self, expected_exc):
self.expected_exc = expected_exc
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type != self.expected_exc:
return False
return True
def _update_field(obj, field, value):
session = sa.orm.object_session(obj)
table = sa.inspect(obj.__class__).columns[field].table
query = table.update().values(**{field: value})
session.execute(query)
session.flush()
def _expect_successful_update(obj, field, value, reraise_exc):
try:
_update_field(obj, field, value)
except (reraise_exc) as e:
session = sa.orm.object_session(obj)
session.rollback()
assert False, str(e)
def _expect_failing_update(obj, field, value, expected_exc):
with raises(expected_exc):
_update_field(obj, field, value)
session = sa.orm.object_session(obj)
session.rollback()
def assert_nullable(obj, column):
"""
Assert that given column is nullable. This is checked by running an SQL
update that assigns given column as None.
:param obj: SQLAlchemy declarative model object
:param column: Name of the column
"""
_expect_successful_update(obj, column, None, IntegrityError)
def assert_non_nullable(obj, column):
"""
Assert that given column is not nullable. This is checked by running an SQL
update that assigns given column as None.
:param obj: SQLAlchemy declarative model object
:param column: Name of the column
"""
_expect_failing_update(obj, column, None, IntegrityError)
def assert_max_length(obj, column, max_length):
"""
Assert that the given column is of given max length.
:param obj: SQLAlchemy declarative model object
:param column: Name of the column
"""
_expect_successful_update(obj, column, u'a' * max_length, DataError)
_expect_failing_update(obj, column, u'a' * (max_length + 1), DataError)

View File

@@ -35,9 +35,9 @@ from .orm import (
get_query_entities,
get_tables,
getdotattr,
has_any_changes,
has_changes,
identity,
is_loaded,
naturally_equivalent,
quote,
table_name,
@@ -62,9 +62,9 @@ __all__ = (
'get_tables',
'getdotattr',
'group_foreign_keys',
'has_any_changes',
'has_changes',
'identity',
'is_loaded',
'is_auto_assigned_date_column',
'is_indexed_foreign_key',
'make_order_by_deterministic',

View File

@@ -159,12 +159,18 @@ def has_index(column):
has_index(table.c.locale) # False
has_index(table.c.id) # True
"""
table = column.table
if not isinstance(table, sa.Table):
raise TypeError(
'Only columns belonging to Table objects are supported. Given '
'column belongs to %r.' % table
)
return (
column is column.table.primary_key.columns.values()[0]
column is table.primary_key.columns.values()[0]
or
any(
index.columns.values()[0] is column
for index in column.table.indexes
for index in table.indexes
)
)
@@ -198,8 +204,17 @@ def has_unique_index(column):
has_unique_index(table.c.is_published) # True
has_unique_index(table.c.is_deleted) # False
has_unique_index(table.c.id) # True
:raises TypeError: if given column does not belong to a Table object
"""
pks = column.table.primary_key.columns
table = column.table
if not isinstance(table, sa.Table):
raise TypeError(
'Only columns belonging to Table objects are supported. Given '
'column belongs to %r.' % table
)
pks = table.primary_key.columns
return (
(column is pks.values()[0] and len(pks) == 1)
or
@@ -384,12 +399,21 @@ def drop_database(url):
engine.raw_connection().set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
# Disconnect all users from the database we are dropping.
version = list(
map(
int,
engine.execute('SHOW server_version;').first()[0].split('.')
)
)
pid_column = (
'pid' if (version[0] >= 9 and version[1] >= 2) else 'procpid'
)
text = '''
SELECT pg_terminate_backend(pg_stat_activity.pid)
SELECT pg_terminate_backend(pg_stat_activity.%(pid_column)s)
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '%s'
AND pid <> pg_backend_pid()
''' % database
WHERE pg_stat_activity.datname = '%(database)s'
AND %(pid_column)s <> pg_backend_pid();
''' % {'pid_column': pid_column, 'database': database}
engine.execute(text)
# Drop the database.

View File

@@ -7,7 +7,7 @@ from sqlalchemy.exc import NoInspectionAvailable
from sqlalchemy.orm import object_session
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
from .orm import get_mapper, get_tables
from .orm import get_column_key, get_mapper, get_tables
from ..query_chain import QueryChain
@@ -279,31 +279,40 @@ def dependent_objects(obj, foreign_keys=None):
table in mapper.tables and
not (parent_mapper and table in parent_mapper.tables)
):
criteria = []
visited_constraints = []
for key in keys:
if key.constraint not in visited_constraints:
visited_constraints.append(key.constraint)
subcriteria = [
getattr(class_, column.key) ==
getattr(
obj,
key.constraint.elements[index].column.key
)
for index, column
in enumerate(key.constraint.columns)
]
criteria.append(sa.and_(*subcriteria))
query = session.query(class_).filter(
sa.or_(
*criteria
)
sa.or_(*_get_criteria(keys, class_, obj))
)
chain.queries.append(query)
return chain
def _get_criteria(keys, class_, obj):
criteria = []
visited_constraints = []
for key in keys:
if key.constraint in visited_constraints:
continue
visited_constraints.append(key.constraint)
subcriteria = []
for index, column in enumerate(key.constraint.columns):
foreign_column = (
key.constraint.elements[index].column
)
subcriteria.append(
getattr(class_, get_column_key(class_, column)) ==
getattr(
obj,
sa.inspect(type(obj))
.get_property_by_column(
foreign_column
).key
)
)
criteria.append(sa.and_(*subcriteria))
return criteria
def non_indexed_foreign_keys(metadata, engine=None):
"""
Finds all non indexed foreign keys from all tables of given MetaData.
@@ -347,10 +356,10 @@ def is_indexed_foreign_key(constraint):
:param constraint: ForeignKeyConstraint object to check the indexes
"""
for index in constraint.table.indexes:
index_column_names = set(
column.name for column in index.columns
)
if index_column_names == set(constraint.columns):
return True
return False
return any(
set(column.name for column in index.columns)
==
set(constraint.columns)
for index
in constraint.table.indexes
)

View File

@@ -35,14 +35,21 @@ def get_column_key(model, column):
get_column_key(User, User.__table__.c.name) # 'name'
.. versionadded: 0.26.5
.. versionchanged: 0.27.11
Throws UnmappedColumnError instead of ValueError when no property was
found for given column. This is consistent with how SQLAlchemy works.
"""
for key, c in sa.inspect(model).columns.items():
if c is column:
return key
raise ValueError(
"Class %s doesn't have a column '%s'",
model.__name__,
column
mapper = sa.inspect(model)
try:
return mapper.get_property_by_column(column).key
except sa.orm.exc.UnmappedColumnError:
for key, c in mapper.columns.items():
if c.name == column.name and c.table is column.table:
return key
raise sa.orm.exc.UnmappedColumnError(
'No column %s is configured on mapper %s...' %
(column, mapper)
)
@@ -77,6 +84,10 @@ def get_mapper(mixed):
"""
if isinstance(mixed, sa.orm.query._MapperEntity):
mixed = mixed.expr
elif isinstance(mixed, sa.Column):
mixed = mixed.table
elif isinstance(mixed, sa.orm.query._ColumnEntity):
mixed = mixed.expr
if isinstance(mixed, sa.orm.Mapper):
return mixed
@@ -227,8 +238,6 @@ def get_tables(mixed):
tables = sum((m.tables for m in polymorphic_mappers), [])
else:
tables = mapper.tables
return tables
@@ -635,13 +644,7 @@ def getdotattr(obj_or_class, dot_path):
for path in dot_path.split('.'):
getter = attrgetter(path)
if isinstance(last, list):
tmp = []
for el in last:
if isinstance(el, list):
tmp.extend(map(getter, el))
else:
tmp.append(getter(el))
last = tmp
last = sum((getter(el) for el in last), [])
elif isinstance(last, InstrumentedAttribute):
last = getter(last.property.mapper.class_)
elif last is None:
@@ -722,35 +725,37 @@ def has_changes(obj, attrs=None, exclude=None):
)
def has_any_changes(obj, columns):
def is_loaded(obj, prop):
"""
Simple shortcut function for checking if any of the given attributes of
given declarative model object have changes.
Return whether or not given property of given object has been loaded.
::
from sqlalchemy_utils import has_any_changes
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
content = sa.orm.deferred(sa.Column(sa.String))
user = User()
article = session.query(Article).get(5)
has_any_changes(user, ('name', )) # False
# name gets loaded since its not a deferred property
assert is_loaded(article, 'name')
user.name = u'someone'
has_any_changes(user, ('name', 'age')) # True
# content has not yet been loaded since its a deferred property
assert not is_loaded(article, 'content')
.. versionadded: 0.26.3
.. deprecated:: 0.26.6
User :func:`has_changes` instead.
.. versionadded: 0.27.8
:param obj: SQLAlchemy declarative model object
:param attrs: Names of the attributes
:param prop: Name of the property or InstrumentedAttribute
"""
return any(has_changes(obj, column) for column in columns)
return not isinstance(
getattr(sa.inspect(obj).attrs, prop).loaded_value,
sa.util.langhelpers._symbol
)
def identity(obj_or_class):

View File

@@ -169,23 +169,28 @@ def make_order_by_deterministic(query):
.. versionadded: 0.27.1
"""
if not query._order_by:
return query
order_by_func = sa.asc
order_by = query._order_by[0]
if isinstance(order_by, sa.sql.expression.UnaryExpression):
if order_by.modifier == sa.sql.operators.desc_op:
order_by_func = sa.desc
else:
order_by_func = sa.asc
column = order_by.get_children()[0]
if not query._order_by:
column = None
else:
column = order_by
order_by_func = sa.asc
order_by = query._order_by[0]
if isinstance(order_by, sa.sql.expression.UnaryExpression):
if order_by.modifier == sa.sql.operators.desc_op:
order_by_func = sa.desc
else:
order_by_func = sa.asc
column = order_by.get_children()[0]
else:
column = order_by
# Queries that are ordered by an already
if isinstance(column, sa.Column) and has_unique_index(column):
return query
if isinstance(column, sa.Column):
try:
if has_unique_index(column):
return query
except TypeError:
pass
base_table = get_tables(query._entities[0])[0]
query = query.order_by(

View File

@@ -1,3 +1,5 @@
from sqlalchemy.ext.hybrid import hybrid_property
from .exceptions import ImproperlyConfigured
@@ -21,3 +23,60 @@ except ImportError:
'install babel or make a similar function and override it '
'in this module.'
)
class TranslationHybrid(object):
def __init__(self, current_locale, default_locale):
self.current_locale = current_locale
self.default_locale = default_locale
def cast_locale(self, obj, locale):
"""
Cast given locale to string. Supports also callbacks that return
locales.
"""
if callable(locale):
try:
return str(locale())
except TypeError:
return str(locale(obj))
return str(locale)
def getter_factory(self, attr):
"""
Return a hybrid_property getter function for given attribute. The
returned getter first checks if object has translation for current
locale. If not it tries to get translation for default locale. If there
is no translation found for default locale it returns None.
"""
def getter(obj):
current_locale = self.cast_locale(obj, self.current_locale)
try:
return getattr(obj, attr.key)[current_locale]
except (TypeError, KeyError):
default_locale = self.cast_locale(
obj, self.default_locale
)
try:
return getattr(obj, attr.key)[default_locale]
except (TypeError, KeyError):
return None
return getter
def setter_factory(self, attr):
def setter(obj, value):
if getattr(obj, attr.key) is None:
setattr(obj, attr.key, {})
locale = self.cast_locale(obj, self.current_locale)
getattr(obj, attr.key)[locale] = value
return setter
def expr_factory(self, attr):
return lambda cls: attr
def __call__(self, attr):
return hybrid_property(
fget=self.getter_factory(attr),
fset=self.setter_factory(attr),
expr=self.expr_factory(attr)
)

View File

@@ -1,6 +1,7 @@
import six
from sqlalchemy import types
from sqlalchemy_utils.exceptions import ImproperlyConfigured
from sqlalchemy_utils.utils import str_coercible
from .scalar_coercible import ScalarCoercible
@@ -13,6 +14,7 @@ except ImportError:
BasePhoneNumber = object
@str_coercible
class PhoneNumber(BasePhoneNumber):
'''
Extends a PhoneNumber class from `Python phonenumbers library`_. Adds
@@ -66,9 +68,6 @@ class PhoneNumber(BasePhoneNumber):
def __unicode__(self):
return self.national
def __str__(self):
return six.text_type(self.national).encode('utf-8')
class PhoneNumberType(types.TypeDecorator, ScalarCoercible):
"""

View File

@@ -3,7 +3,7 @@ from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
class TestAggregateValueGenerationWithBackrefs(TestCase):
def create_models(self):
class Thread(self.Base):
__tablename__ = 'thread'

View File

@@ -0,0 +1,80 @@
import sqlalchemy as sa
from sqlalchemy_utils import aggregated
from tests import TestCase
class TestAggregateManyToManyAndManyToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
catalog_products = sa.Table(
'catalog_product',
self.Base.metadata,
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
)
product_categories = sa.Table(
'category_product',
self.Base.metadata,
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
)
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'products.categories',
sa.Column(sa.Integer, default=0)
)
def category_count(self):
return sa.func.count(sa.distinct(Category.id))
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
catalog_id = sa.Column(
sa.Integer, sa.ForeignKey('catalog.id')
)
catalogs = sa.orm.relationship(
Catalog,
backref='products',
secondary=catalog_products
)
categories = sa.orm.relationship(
Category,
backref='products',
secondary=product_categories
)
self.Catalog = Catalog
self.Category = Category
self.Product = Product
def test_insert(self):
category = self.Category()
products = [
self.Product(categories=[category]),
self.Product(categories=[category])
]
catalog = self.Catalog(products=products)
self.session.add(catalog)
catalog2 = self.Catalog(products=products)
self.session.add(catalog)
self.session.commit()
assert catalog.category_count == 1
assert catalog2.category_count == 1

View File

@@ -0,0 +1,76 @@
import sqlalchemy as sa
from sqlalchemy_utils import aggregated
from tests import TestCase
class TestAggregateOneToManyAndManyToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
product_categories = sa.Table(
'category_product',
self.Base.metadata,
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
)
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'products.categories',
sa.Column(sa.Integer, default=0)
)
def category_count(self):
return sa.func.count(sa.distinct(Category.id))
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
catalog_id = sa.Column(
sa.Integer, sa.ForeignKey('catalog.id')
)
catalog = sa.orm.relationship(
Catalog,
backref='products'
)
categories = sa.orm.relationship(
Category,
backref='products',
secondary=product_categories
)
self.Catalog = Catalog
self.Category = Category
self.Product = Product
def test_insert(self):
category = self.Category()
products = [
self.Product(categories=[category]),
self.Product(categories=[category])
]
catalog = self.Catalog(products=products)
self.session.add(catalog)
products2 = [
self.Product(categories=[category]),
self.Product(categories=[category])
]
catalog2 = self.Catalog(products=products2)
self.session.add(catalog)
self.session.commit()
assert catalog.category_count == 1
assert catalog2.category_count == 1

View File

@@ -0,0 +1,62 @@
from decimal import Decimal
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregateOneToManyAndOneToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'categories.products',
sa.Column(sa.Integer, default=0)
)
def product_count(self):
return sa.func.count('1')
categories = sa.orm.relationship('Category', backref='catalog')
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
products = sa.orm.relationship('Product', backref='category')
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
self.Catalog = Catalog
self.Category = Category
self.Product = Product
def test_assigns_aggregates(self):
category = self.Category(name=u'Some category')
catalog = self.Catalog(
categories=[category]
)
catalog.name = u'Some catalog'
self.session.add(catalog)
self.session.commit()
product = self.Product(
name=u'Some product',
price=Decimal('1000'),
category=category
)
self.session.add(product)
self.session.commit()
self.session.refresh(catalog)
assert catalog.product_count == 1

View File

@@ -1,76 +1,15 @@
from decimal import Decimal
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated
from sqlalchemy_utils import aggregated
from tests import TestCase
class TestDeepModelPathsForAggregates(TestCase):
class Test3LevelDeepOneToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'categories.products',
sa.Column(sa.Integer, default=0)
)
def product_count(self):
return sa.func.count('1')
categories = sa.orm.relationship('Category', backref='catalog')
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
products = sa.orm.relationship('Product', backref='category')
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
self.Catalog = Catalog
self.Category = Category
self.Product = Product
def test_assigns_aggregates(self):
category = self.Category(name=u'Some category')
catalog = self.Catalog(
categories=[category]
)
catalog.name = u'Some catalog'
self.session.add(catalog)
self.session.commit()
product = self.Product(
name=u'Some product',
price=Decimal('1000'),
category=category
)
self.session.add(product)
self.session.commit()
self.session.refresh(catalog)
assert catalog.product_count == 1
class Test3LevelDeepModelPathsForAggregates(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
n = 1
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'categories.sub_categories.products',
@@ -84,8 +23,6 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
sub_categories = sa.orm.relationship(
@@ -95,16 +32,12 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
class SubCategory(self.Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
products = sa.orm.relationship('Product', backref='sub_category')
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
sub_category_id = sa.Column(
@@ -123,23 +56,13 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
assert catalog.product_count == 1
def catalog_factory(self):
product = self.Product(
name=u'Product %d' % self.n
)
product = self.Product()
sub_category = self.SubCategory(
name=u'SubCategory %d' % self.n,
products=[product]
)
category = self.Category(
name=u'Category %d' % self.n,
sub_categories=[sub_category]
)
catalog = self.Catalog(
categories=[category]
)
catalog.name = u'Catalog %d' % self.n
category = self.Category(sub_categories=[sub_category])
catalog = self.Catalog(categories=[category])
self.session.add(catalog)
self.n += 1
return catalog
def test_only_updates_affected_aggregates(self):
@@ -155,7 +78,7 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
)
catalog.categories[0].sub_categories[0].products.append(
self.Product(name=u'Product 3')
self.Product()
)
self.session.commit()
self.session.refresh(catalog)

View File

@@ -0,0 +1,58 @@
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregatedWithColumnAlias(TestCase):
def create_models(self):
class Thread(self.Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
@aggregated(
'comments',
sa.Column('_comment_count', sa.Integer, default=0)
)
def comment_count(self):
return sa.func.count('1')
comments = sa.orm.relationship('Comment', backref='thread')
class Comment(self.Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
self.Thread = Thread
self.Comment = Comment
def test_assigns_aggregates_on_insert(self):
thread = self.Thread()
self.session.add(thread)
comment = self.Comment(thread=thread)
self.session.add(comment)
self.session.commit()
self.session.refresh(thread)
assert thread.comment_count == 1
def test_assigns_aggregates_on_separate_insert(self):
thread = self.Thread()
self.session.add(thread)
self.session.commit()
comment = self.Comment(thread=thread)
self.session.add(comment)
self.session.commit()
self.session.refresh(thread)
assert thread.comment_count == 1
def test_assigns_aggregates_on_delete(self):
thread = self.Thread()
self.session.add(thread)
self.session.commit()
comment = self.Comment(thread=thread)
self.session.add(comment)
self.session.commit()
self.session.delete(comment)
self.session.commit()
self.session.refresh(thread)
assert thread.comment_count == 0

View File

@@ -78,6 +78,85 @@ class TestDependentObjects(TestCase):
assert objects[3] in deps
class TestDependentObjectsWithColumnAliases(TestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
first_name = sa.Column(sa.Unicode(255))
last_name = sa.Column(sa.Unicode(255))
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
author_id = sa.Column(
'_author_id', sa.Integer, sa.ForeignKey('user.id')
)
owner_id = sa.Column(
'_owner_id',
sa.Integer, sa.ForeignKey('user.id', ondelete='SET NULL')
)
author = sa.orm.relationship(User, foreign_keys=[author_id])
owner = sa.orm.relationship(User, foreign_keys=[owner_id])
class BlogPost(self.Base):
__tablename__ = 'blog_post'
id = sa.Column(sa.Integer, primary_key=True)
owner_id = sa.Column(
'_owner_id',
sa.Integer, sa.ForeignKey('user.id', ondelete='CASCADE')
)
owner = sa.orm.relationship(User)
self.User = User
self.Article = Article
self.BlogPost = BlogPost
def test_returns_all_dependent_objects(self):
user = self.User(first_name=u'John')
articles = [
self.Article(author=user),
self.Article(),
self.Article(owner=user),
self.Article(author=user, owner=user)
]
self.session.add_all(articles)
self.session.commit()
deps = list(dependent_objects(user))
assert len(deps) == 3
assert articles[0] in deps
assert articles[2] in deps
assert articles[3] in deps
def test_with_foreign_keys_parameter(self):
user = self.User(first_name=u'John')
objects = [
self.Article(author=user),
self.Article(),
self.Article(owner=user),
self.Article(author=user, owner=user),
self.BlogPost(owner=user)
]
self.session.add_all(objects)
self.session.commit()
deps = list(
dependent_objects(
user,
(
fk for fk in get_referencing_foreign_keys(self.User)
if fk.ondelete == 'RESTRICT' or fk.ondelete is None
)
).limit(5)
)
assert len(deps) == 2
assert objects[0] in deps
assert objects[3] in deps
class TestDependentObjectsWithManyReferences(TestCase):
def create_models(self):
class User(self.Base):
@@ -192,7 +271,6 @@ class TestDependentObjectsWithSingleTableInheritance(TestCase):
'polymorphic_identity': u'blog_post'
}
self.Category = Category
self.TextItem = TextItem
self.Article = Article

View File

@@ -1,3 +1,4 @@
from copy import copy
from pytest import raises
import sqlalchemy as sa
@@ -15,7 +16,12 @@ class TestGetColumnKey(object):
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column('_name', sa.Unicode(255))
class Movie(Base):
__tablename__ = 'movie'
id = sa.Column(sa.Integer, primary_key=True)
self.Building = Building
self.Movie = Movie
def test_supports_aliases(self):
assert (
@@ -29,6 +35,10 @@ class TestGetColumnKey(object):
'name'
)
def test_supports_vague_matching_of_column_objects(self):
column = copy(self.Building.__table__.c._name)
assert get_column_key(self.Building, column) == 'name'
def test_throws_value_error_for_unknown_column(self):
with raises(ValueError):
get_column_key(self.Building, 'unknown')
with raises(sa.orm.exc.UnmappedColumnError):
get_column_key(self.Building, self.Movie.__table__.c.id)

View File

@@ -56,6 +56,18 @@ class TestGetMapper(object):
sa.inspect(self.Building)
)
def test_column(self):
assert (
get_mapper(self.Building.__table__.c.id) ==
sa.inspect(self.Building)
)
def test_column_of_an_alias(self):
assert (
get_mapper(sa.orm.aliased(self.Building.__table__).c.id) ==
sa.inspect(self.Building)
)
class TestGetMapperWithQueryEntities(TestCase):
def create_models(self):
@@ -79,6 +91,10 @@ class TestGetMapperWithQueryEntities(TestCase):
sa.inspect(self.Building)
)
def test_column_entity(self):
query = self.session.query(self.Building.id)
assert get_mapper(query._entities[0]) == sa.inspect(self.Building)
class TestGetMapperWithMultipleMappersFound(object):
def setup_method(self, method):

View File

@@ -67,11 +67,12 @@ class TestGetDotAttr(TestCase):
subsection = self.SubSection(section=section)
subsubsection = self.SubSubSection(subsection=subsection)
assert getdotattr(document, 'sections') == [section]
assert getdotattr(document, 'sections.subsections') == [
[subsection]
subsection
]
assert getdotattr(document, 'sections.subsections.subsubsections') == [
[subsubsection]
subsubsection
]
def test_class_paths(self):

View File

@@ -1,24 +0,0 @@
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import has_any_changes
class TestHasAnyChanges(object):
def setup_method(self, method):
Base = declarative_base()
class Article(Base):
__tablename__ = 'article_translation'
id = sa.Column(sa.Integer, primary_key=True)
title = sa.Column(sa.String(100))
self.Article = Article
def test_without_changed_attr(self):
article = self.Article()
assert not has_any_changes(article, ['title'])
def test_with_changed_attr(self):
article = self.Article(title='Some title')
assert has_any_changes(article, ['title', 'id'])

View File

@@ -1,4 +1,5 @@
import sqlalchemy as sa
from pytest import raises
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import has_index
@@ -23,6 +24,11 @@ class TestHasIndex(object):
self.table = ArticleTranslation.__table__
def test_column_that_belongs_to_an_alias(self):
alias = sa.orm.aliased(self.table)
with raises(TypeError):
assert has_index(alias.c.id)
def test_compound_primary_key(self):
assert has_index(self.table.c.id)
assert not has_index(self.table.c.locale)

View File

@@ -1,10 +1,11 @@
from pytest import raises
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import has_unique_index
class TestHasIndex(object):
class TestHasUniqueIndex(object):
def setup_method(self, method):
Base = declarative_base()
@@ -31,6 +32,11 @@ class TestHasIndex(object):
def test_primary_key(self):
assert has_unique_index(self.articles.c.id)
def test_column_of_aliased_table(self):
alias = sa.orm.aliased(self.articles)
with raises(TypeError):
assert has_unique_index(alias.c.id)
def test_unique_index(self):
assert has_unique_index(self.article_translations.c.is_deleted)

View File

@@ -0,0 +1,25 @@
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import is_loaded
class TestIsLoaded(object):
def setup_method(self, method):
Base = declarative_base()
class Article(Base):
__tablename__ = 'article_translation'
id = sa.Column(sa.Integer, primary_key=True)
title = sa.orm.deferred(sa.Column(sa.String(100)))
self.Article = Article
def test_loaded_property(self):
article = self.Article(id=1)
assert is_loaded(article, 'id')
def test_unloaded_property(self):
article = self.Article(id=4)
assert not is_loaded(article, 'title')

View File

@@ -17,7 +17,6 @@ class TestMakeOrderByDeterministic(TestCase):
sa.func.lower(name)
)
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
@@ -31,6 +30,7 @@ class TestMakeOrderByDeterministic(TestCase):
)
self.User = User
self.Article = Article
def test_column_property(self):
query = self.session.query(self.User).order_by(self.User.email_lower)
@@ -82,4 +82,10 @@ class TestMakeOrderByDeterministic(TestCase):
def test_query_without_order_by(self):
query = self.session.query(self.User)
query = make_order_by_deterministic(query)
assert 'ORDER BY' not in str(query)
assert 'ORDER BY "user".id' in str(query)
def test_alias(self):
alias = sa.orm.aliased(self.User.__table__)
query = self.session.query(alias).order_by(alias.c.name)
query = make_order_by_deterministic(query)
assert str(query).endswith('ORDER BY user_1.name, "user".id ASC')

90
tests/test_asserts.py Normal file
View File

@@ -0,0 +1,90 @@
import sqlalchemy as sa
import pytest
from sqlalchemy_utils import (
assert_nullable,
assert_non_nullable,
assert_max_length
)
from sqlalchemy_utils.asserts import raises
from tests import TestCase
class TestRaises(object):
def test_matching_exception(self):
with raises(Exception):
raise Exception()
assert True
def test_non_matchin_exception(self):
with pytest.raises(Exception):
with raises(ValueError):
raise Exception()
class AssertionTestCase(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(20))
age = sa.Column(sa.Integer, nullable=False)
email = sa.Column(sa.String(200), nullable=False, unique=True)
self.User = User
def setup_method(self, method):
TestCase.setup_method(self, method)
user = self.User(name='Someone', email='someone@example.com', age=15)
self.session.add(user)
self.session.commit()
self.user = user
class TestAssertNonNullable(AssertionTestCase):
def test_non_nullable_column(self):
# Test everything twice so that session gets rolled back properly
assert_non_nullable(self.user, 'age')
assert_non_nullable(self.user, 'age')
def test_nullable_column(self):
with raises(AssertionError):
assert_non_nullable(self.user, 'name')
with raises(AssertionError):
assert_non_nullable(self.user, 'name')
class TestAssertNullable(AssertionTestCase):
def test_nullable_column(self):
assert_nullable(self.user, 'name')
assert_nullable(self.user, 'name')
def test_non_nullable_column(self):
with raises(AssertionError):
assert_nullable(self.user, 'age')
with raises(AssertionError):
assert_nullable(self.user, 'age')
class TestAssertMaxLength(AssertionTestCase):
def test_with_max_length(self):
assert_max_length(self.user, 'name', 20)
assert_max_length(self.user, 'name', 20)
def test_with_non_nullable_column(self):
assert_max_length(self.user, 'email', 200)
assert_max_length(self.user, 'email', 200)
def test_smaller_than_max_length(self):
with raises(AssertionError):
assert_max_length(self.user, 'name', 19)
with raises(AssertionError):
assert_max_length(self.user, 'name', 19)
def test_bigger_than_max_length(self):
with raises(AssertionError):
assert_max_length(self.user, 'name', 21)
with raises(AssertionError):
assert_max_length(self.user, 'name', 21)

View File

@@ -0,0 +1,68 @@
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSON
from sqlalchemy_utils import TranslationHybrid
from tests import TestCase
class TestTranslationHybrid(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
class City(self.Base):
__tablename__ = 'city'
id = sa.Column(sa.Integer, primary_key=True)
name_translations = sa.Column(JSON())
name = self.translation_hybrid(name_translations)
locale = 'en'
self.City = City
def setup_method(self, method):
self.translation_hybrid = TranslationHybrid('fi', 'en')
TestCase.setup_method(self, method)
def test_using_hybrid_as_constructor(self):
city = self.City(name='Helsinki')
assert city.name_translations['fi'] == 'Helsinki'
def test_hybrid_as_expression(self):
assert self.City.name == self.City.name_translations
def test_if_no_translation_exists_returns_none(self):
city = self.City()
assert city.name is None
def test_fall_back_to_default_translation(self):
city = self.City(name_translations={'en': 'Helsinki'})
self.translation_hybrid.current_locale = 'sv'
assert city.name == 'Helsinki'
class TestTranslationHybridWithDynamicDefaultLocale(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
class City(self.Base):
__tablename__ = 'city'
id = sa.Column(sa.Integer, primary_key=True)
name_translations = sa.Column(JSON)
name = self.translation_hybrid(name_translations)
locale = sa.Column(sa.String(10))
self.City = City
def setup_method(self, method):
self.translation_hybrid = TranslationHybrid(
'fi',
lambda self: self.locale
)
TestCase.setup_method(self, method)
def test_fallback_to_dynamic_locale(self):
self.translation_hybrid.current_locale = 'en'
city = self.City(name_translations={})
city.locale = 'fi'
city.name_translations['fi'] = 'Helsinki'
assert city.name == 'Helsinki'

View File

@@ -1,6 +1,8 @@
from pytest import mark
from tests import TestCase
import six
import sqlalchemy as sa
from pytest import mark
from tests import TestCase
from sqlalchemy_utils import PhoneNumberType, PhoneNumber
from sqlalchemy_utils.types import phone_number
@@ -45,8 +47,11 @@ class TestPhoneNumber(object):
def test_phone_number_str_repr(self):
number = PhoneNumber('+358401234567')
assert number.__unicode__() == number.national
assert number.__str__() == number.national.encode('utf-8')
if six.PY2:
assert unicode(number) == number.national
assert str(number) == number.national.encode('utf-8')
else:
assert str(number) == number.national
@mark.skipif('phone_number.phonenumbers is None')