diff --git a/CHANGES.rst b/CHANGES.rst index 69227f6..9b614b5 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,6 +7,7 @@ Here you can see the full list of changes between each SQLAlchemy-Utils release. 0.30.11 (2015-06-18) ^^^^^^^^^^^^^^^^^^^^ +- Fix None type handling of ChoiceType - Make locale casting for translation hybrid expressions cast locales on compilation phase. This extra lazy locale casting is needed in some cases where translation hybrid expressions are used before get_locale function is available. diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 47b3ff7..7560f4f 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -92,4 +92,4 @@ from .types import ( # noqa WeekDaysType ) -__version__ = '0.30.10' +__version__ = '0.30.11' diff --git a/sqlalchemy_utils/types/choice.py b/sqlalchemy_utils/types/choice.py index c901490..5a7446c 100644 --- a/sqlalchemy_utils/types/choice.py +++ b/sqlalchemy_utils/types/choice.py @@ -214,10 +214,14 @@ class EnumTypeImpl(object): self.enum_class = enum_class def _coerce(self, value): - return self.enum_class(value) if value else None + if value is None: + return None + return self.enum_class(value) def process_bind_param(self, value, dialect): - return self.enum_class(value).value if value else None + if value is None: + return None + return self.enum_class(value).value def process_result_value(self, value, dialect): - return self.enum_class(value) if value else None + return self._coerce(value) diff --git a/tests/types/test_choice.py b/tests/types/test_choice.py index c11bdfe..831f0d6 100644 --- a/tests/types/test_choice.py +++ b/tests/types/test_choice.py @@ -84,27 +84,33 @@ class TestChoiceTypeWithCustomUnderlyingType(TestCase): class TestEnumType(TestCase): def create_models(self): class OrderStatus(Enum): - unpaid = 1 - paid = 2 + unpaid = 0 + paid = 1 class Order(self.Base): __tablename__ = 'order' id_ = sa.Column(sa.Integer, primary_key=True) status = sa.Column( ChoiceType(OrderStatus, impl=sa.Integer()), - default=OrderStatus.unpaid + default=OrderStatus.unpaid, ) def __repr__(self): return 'Order(%r, %r)' % (self.id_, self.status) - def pay(self): - self.status = OrderStatus.paid + class OrderNullable(self.Base): + __tablename__ = 'order_nullable' + id_ = sa.Column(sa.Integer, primary_key=True) + status = sa.Column( + ChoiceType(OrderStatus, impl=sa.Integer()), + nullable=True, + ) self.OrderStatus = OrderStatus self.Order = Order + self.OrderNullable = OrderNullable - def test_parameter_processing(self): + def test_parameter_initialization(self): order = self.Order() self.session.add(order) @@ -112,20 +118,64 @@ class TestEnumType(TestCase): order = self.session.query(self.Order).first() assert order.status is self.OrderStatus.unpaid - assert order.status.value == 1 + assert order.status.value == 0 - order.pay() - self.session.commit() - - order = self.session.query(self.Order).first() - assert order.status is self.OrderStatus.paid - assert order.status.value == 2 - - def test_parameter_coercing(self): + def test_setting_by_value(self): order = self.Order() - order.status = 2 + order.status = 1 self.session.add(order) self.session.commit() + order = self.session.query(self.Order).first() assert order.status is self.OrderStatus.paid + + def test_setting_by_enum(self): + order = self.Order() + order.status = self.OrderStatus.paid + + self.session.add(order) + self.session.commit() + + order = self.session.query(self.Order).first() + assert order.status is self.OrderStatus.paid + + def test_setting_value_that_resolves_to_none(self): + order = self.Order() + order.status = 0 + + self.session.add(order) + self.session.commit() + + order = self.session.query(self.Order).first() + assert order.status is self.OrderStatus.unpaid + + def test_setting_to_wrong_enum_raises_valueerror(self): + class WrongEnum(Enum): + foo = 0 + bar = 1 + + order = self.Order() + + with raises(ValueError): + order.status = WrongEnum.foo + + def test_setting_to_uncoerceable_type_raises_valueerror(self): + order = self.Order() + with raises(ValueError): + order.status = 'Bad value' + + def test_order_nullable_stores_none(self): + # With nullable=False as in `Order`, a `None` value is always + # converted to the default value, unless we explicitly set it to + # sqlalchemy.sql.null(), so we use this class to test our ability + # to set and retrive `None`. + order_nullable = self.OrderNullable() + assert order_nullable.status is None + + order_nullable.status = None + + self.session.add(order_nullable) + self.session.commit() + + assert order_nullable.status is None