diff --git a/CHANGES.rst b/CHANGES.rst index de4bf8e..3e63f6a 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.29.5 (2015-02-03) +^^^^^^^^^^^^^^^^^^^ + +- Made assert_max_length support PostgreSQL array type + + 0.29.4 (2015-01-31) ^^^^^^^^^^^^^^^^^^^ diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 115b90f..4d6a665 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -87,7 +87,7 @@ from .types import ( from .models import Timestamp -__version__ = '0.29.4' +__version__ = '0.29.5' __all__ = ( diff --git a/sqlalchemy_utils/asserts.py b/sqlalchemy_utils/asserts.py index 934bea9..507c8e6 100644 --- a/sqlalchemy_utils/asserts.py +++ b/sqlalchemy_utils/asserts.py @@ -32,7 +32,9 @@ We can easily test the constraints by assert_* functions:: # raises AssertionError because the max length of email is 255 assert_max_length(user, 'email', 300) """ +from decimal import Decimal import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.exc import DataError, IntegrityError @@ -71,6 +73,27 @@ def _expect_failing_update(obj, field, value, expected_exc): session.rollback() +def _repeated_value(type_): + if isinstance(type_, ARRAY): + if isinstance(type_.item_type, sa.Integer): + return [0] + elif isinstance(type_.item_type, sa.String): + return [u'a'] + elif isinstance(type_.item_type, sa.Numeric): + return [Decimal('0')] + else: + raise TypeError('Unknown array item type') + else: + return u'a' + + +def _expected_exception(type_): + if isinstance(type_, ARRAY): + return IntegrityError + else: + return DataError + + def assert_nullable(obj, column): """ Assert that given column is nullable. This is checked by running an SQL @@ -95,14 +118,49 @@ def assert_non_nullable(obj, column): def assert_max_length(obj, column, max_length): """ - Assert that the given column is of given max length. + Assert that the given column is of given max length. This function supports + string typed columns as well as PostgreSQL array typed columns. + + In the following example we add a check constraint that user can have a + maximum of 5 favorite colors and then test this.:: + + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + favorite_colors = sa.Column(ARRAY(sa.String), nullable=False) + __table_args__ = ( + sa.CheckConstraint( + sa.func.array_length(favorite_colors, 1) <= 5 + ) + ) + + + user = User(name='John Doe', favorite_colors=['red', 'blue']) + session.add(user) + session.commit() + + + assert_max_length(user, 'favorite_colors', 5) + :param obj: SQLAlchemy declarative model object :param column: Name of the column :param max_length: Maximum length of given column """ - _expect_successful_update(obj, column, u'a' * max_length, DataError) - _expect_failing_update(obj, column, u'a' * (max_length + 1), DataError) + type_ = sa.inspect(obj.__class__).columns[column].type + _expect_successful_update( + obj, + column, + _repeated_value(type_) * max_length, + _expected_exception(type_) + ) + _expect_failing_update( + obj, + column, + _repeated_value(type_) * (max_length + 1), + _expected_exception(type_) + ) def assert_min_value(obj, column, min_value): diff --git a/tests/test_asserts.py b/tests/test_asserts.py index 3384341..ea6d587 100644 --- a/tests/test_asserts.py +++ b/tests/test_asserts.py @@ -1,5 +1,6 @@ import sqlalchemy as sa import pytest +from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy_utils import ( assert_min_value, assert_max_length, @@ -34,21 +35,50 @@ class AssertionTestCase(TestCase): name = sa.Column(sa.String(20)) age = sa.Column(sa.Integer, nullable=False) email = sa.Column(sa.String(200), nullable=False, unique=True) + fav_numbers = sa.Column(ARRAY(sa.Integer)) __table_args__ = ( sa.CheckConstraint(sa.and_(age >= 0, age <= 150)), + sa.CheckConstraint( + sa.and_( + sa.func.array_length(fav_numbers, 1) <= 8 + ) + ) ) self.User = User def setup_method(self, method): TestCase.setup_method(self, method) - user = self.User(name='Someone', email='someone@example.com', age=15) + user = self.User( + name='Someone', + email='someone@example.com', + age=15, + fav_numbers=[1, 2, 3] + ) self.session.add(user) self.session.commit() self.user = user +class TestAssertMaxLengthWithArray(AssertionTestCase): + def test_with_max_length(self): + assert_max_length(self.user, 'fav_numbers', 8) + assert_max_length(self.user, 'fav_numbers', 8) + + def test_smaller_than_max_length(self): + with raises(AssertionError): + assert_max_length(self.user, 'fav_numbers', 7) + with raises(AssertionError): + assert_max_length(self.user, 'fav_numbers', 7) + + def test_bigger_than_max_length(self): + with raises(AssertionError): + assert_max_length(self.user, 'fav_numbers', 9) + with raises(AssertionError): + assert_max_length(self.user, 'fav_numbers', 9) + + class TestAssertNonNullable(AssertionTestCase): def test_non_nullable_column(self): # Test everything twice so that session gets rolled back properly