Make assert_max_length support array type

This commit is contained in:
Konsta Vesterinen
2015-02-03 11:23:44 +02:00
parent 9e85029620
commit 270098403b
4 changed files with 99 additions and 5 deletions

View File

@@ -4,6 +4,12 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.29.5 (2015-02-03)
^^^^^^^^^^^^^^^^^^^
- Made assert_max_length support PostgreSQL array type
0.29.4 (2015-01-31)
^^^^^^^^^^^^^^^^^^^

View File

@@ -87,7 +87,7 @@ from .types import (
from .models import Timestamp
__version__ = '0.29.4'
__version__ = '0.29.5'
__all__ = (

View File

@@ -32,7 +32,9 @@ We can easily test the constraints by assert_* functions::
# raises AssertionError because the max length of email is 255
assert_max_length(user, 'email', 300)
"""
from decimal import Decimal
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.exc import DataError, IntegrityError
@@ -71,6 +73,27 @@ def _expect_failing_update(obj, field, value, expected_exc):
session.rollback()
def _repeated_value(type_):
if isinstance(type_, ARRAY):
if isinstance(type_.item_type, sa.Integer):
return [0]
elif isinstance(type_.item_type, sa.String):
return [u'a']
elif isinstance(type_.item_type, sa.Numeric):
return [Decimal('0')]
else:
raise TypeError('Unknown array item type')
else:
return u'a'
def _expected_exception(type_):
if isinstance(type_, ARRAY):
return IntegrityError
else:
return DataError
def assert_nullable(obj, column):
"""
Assert that given column is nullable. This is checked by running an SQL
@@ -95,14 +118,49 @@ def assert_non_nullable(obj, column):
def assert_max_length(obj, column, max_length):
"""
Assert that the given column is of given max length.
Assert that the given column is of given max length. This function supports
string typed columns as well as PostgreSQL array typed columns.
In the following example we add a check constraint that user can have a
maximum of 5 favorite colors and then test this.::
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
favorite_colors = sa.Column(ARRAY(sa.String), nullable=False)
__table_args__ = (
sa.CheckConstraint(
sa.func.array_length(favorite_colors, 1) <= 5
)
)
user = User(name='John Doe', favorite_colors=['red', 'blue'])
session.add(user)
session.commit()
assert_max_length(user, 'favorite_colors', 5)
:param obj: SQLAlchemy declarative model object
:param column: Name of the column
:param max_length: Maximum length of given column
"""
_expect_successful_update(obj, column, u'a' * max_length, DataError)
_expect_failing_update(obj, column, u'a' * (max_length + 1), DataError)
type_ = sa.inspect(obj.__class__).columns[column].type
_expect_successful_update(
obj,
column,
_repeated_value(type_) * max_length,
_expected_exception(type_)
)
_expect_failing_update(
obj,
column,
_repeated_value(type_) * (max_length + 1),
_expected_exception(type_)
)
def assert_min_value(obj, column, min_value):

View File

@@ -1,5 +1,6 @@
import sqlalchemy as sa
import pytest
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy_utils import (
assert_min_value,
assert_max_length,
@@ -34,21 +35,50 @@ class AssertionTestCase(TestCase):
name = sa.Column(sa.String(20))
age = sa.Column(sa.Integer, nullable=False)
email = sa.Column(sa.String(200), nullable=False, unique=True)
fav_numbers = sa.Column(ARRAY(sa.Integer))
__table_args__ = (
sa.CheckConstraint(sa.and_(age >= 0, age <= 150)),
sa.CheckConstraint(
sa.and_(
sa.func.array_length(fav_numbers, 1) <= 8
)
)
)
self.User = User
def setup_method(self, method):
TestCase.setup_method(self, method)
user = self.User(name='Someone', email='someone@example.com', age=15)
user = self.User(
name='Someone',
email='someone@example.com',
age=15,
fav_numbers=[1, 2, 3]
)
self.session.add(user)
self.session.commit()
self.user = user
class TestAssertMaxLengthWithArray(AssertionTestCase):
def test_with_max_length(self):
assert_max_length(self.user, 'fav_numbers', 8)
assert_max_length(self.user, 'fav_numbers', 8)
def test_smaller_than_max_length(self):
with raises(AssertionError):
assert_max_length(self.user, 'fav_numbers', 7)
with raises(AssertionError):
assert_max_length(self.user, 'fav_numbers', 7)
def test_bigger_than_max_length(self):
with raises(AssertionError):
assert_max_length(self.user, 'fav_numbers', 9)
with raises(AssertionError):
assert_max_length(self.user, 'fav_numbers', 9)
class TestAssertNonNullable(AssertionTestCase):
def test_non_nullable_column(self):
# Test everything twice so that session gets rolled back properly