From ee94d5c976c2af9dbc0a0ea555b48e46a9be3df2 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Fri, 24 Oct 2014 09:55:21 +0300 Subject: [PATCH] Add automatic session rollback for assert_* functions --- CHANGES.rst | 9 ++++++++- sqlalchemy_utils/asserts.py | 33 ++++++++++++++++++++------------- tests/test_asserts.py | 12 ++++++++++++ 3 files changed, 40 insertions(+), 14 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 6809f9e..a73e10a 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,10 +4,17 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.27.5 (2014-10-24) +^^^^^^^^^^^^^^^^^^^ + +- Made assert_* functions automatically rollback session + + + 0.27.4 (2014-10-23) ^^^^^^^^^^^^^^^^^^^ -- Added assert_non_nullable, assert_nullable and assert_max_length testing methods +- Added assert_non_nullable, assert_nullable and assert_max_length testing functions 0.27.3 (2014-10-22) diff --git a/sqlalchemy_utils/asserts.py b/sqlalchemy_utils/asserts.py index 85cd5fe..5f58359 100644 --- a/sqlalchemy_utils/asserts.py +++ b/sqlalchemy_utils/asserts.py @@ -57,6 +57,22 @@ def _update_field(obj, field, value): session.flush() +def _expect_successful_update(obj, field, value, reraise_exc): + try: + _update_field(obj, field, value) + except (reraise_exc) as e: + session = sa.orm.object_session(obj) + session.rollback() + assert False, str(e) + + +def _expect_failing_update(obj, field, value, expected_exc): + with raises(expected_exc): + _update_field(obj, field, None) + session = sa.orm.object_session(obj) + session.rollback() + + def assert_nullable(obj, column): """ Assert that given column is nullable. This is checked by running an SQL @@ -65,10 +81,7 @@ def assert_nullable(obj, column): :param obj: SQLAlchemy declarative model object :param column: Name of the column """ - try: - _update_field(obj, column, None) - except (IntegrityError) as e: - assert False, str(e) + _expect_successful_update(obj, column, None, IntegrityError) def assert_non_nullable(obj, column): @@ -79,8 +92,7 @@ def assert_non_nullable(obj, column): :param obj: SQLAlchemy declarative model object :param column: Name of the column """ - with raises(IntegrityError): - _update_field(obj, column, None) + _expect_failing_update(obj, column, None, IntegrityError) def assert_max_length(obj, column, max_length): @@ -90,10 +102,5 @@ def assert_max_length(obj, column, max_length): :param obj: SQLAlchemy declarative model object :param column: Name of the column """ - try: - _update_field(obj, column, u'a' * max_length) - except (DataError) as e: - assert False, str(e) - with raises(DataError): - _update_field(obj, column, u'a' * (max_length + 1)) - + _expect_successful_update(obj, column, u'a' * max_length, DataError) + _expect_failing_update(obj, column, u'a' * (max_length + 1), DataError) diff --git a/tests/test_asserts.py b/tests/test_asserts.py index 3986518..2e436cc 100644 --- a/tests/test_asserts.py +++ b/tests/test_asserts.py @@ -45,30 +45,42 @@ class AssertionTestCase(TestCase): class TestAssertNonNullable(AssertionTestCase): def test_non_nullable_column(self): + # Test everything twice so that session gets rolled back properly + assert_non_nullable(self.user, 'age') assert_non_nullable(self.user, 'age') def test_nullable_column(self): with raises(AssertionError): assert_non_nullable(self.user, 'name') + with raises(AssertionError): + assert_non_nullable(self.user, 'name') class TestAssertNullable(AssertionTestCase): def test_nullable_column(self): assert_nullable(self.user, 'name') + assert_nullable(self.user, 'name') def test_non_nullable_column(self): with raises(AssertionError): assert_nullable(self.user, 'age') + with raises(AssertionError): + assert_nullable(self.user, 'age') class TestAssertMaxLength(AssertionTestCase): def test_with_max_length(self): assert_max_length(self.user, 'name', 20) + assert_max_length(self.user, 'name', 20) def test_smaller_than_max_length(self): with raises(AssertionError): assert_max_length(self.user, 'name', 19) + with raises(AssertionError): + assert_max_length(self.user, 'name', 19) def test_bigger_than_max_length(self): with raises(AssertionError): assert_max_length(self.user, 'name', 21) + with raises(AssertionError): + assert_max_length(self.user, 'name', 21)