From b6e22bd08f8efd9b3f157021edc9fdfea8ec3923 Mon Sep 17 00:00:00 2001 From: Dimitrios Semitsoglou-Tsiapos Date: Fri, 12 Jun 2015 15:12:09 +0200 Subject: [PATCH] ChoiceType: non-None should not return None --- sqlalchemy_utils/types/choice.py | 14 ++++-- tests/types/test_choice.py | 82 +++++++++++++++++++++++++------- 2 files changed, 77 insertions(+), 19 deletions(-) diff --git a/sqlalchemy_utils/types/choice.py b/sqlalchemy_utils/types/choice.py index c901490..b63e46d 100644 --- a/sqlalchemy_utils/types/choice.py +++ b/sqlalchemy_utils/types/choice.py @@ -214,10 +214,18 @@ class EnumTypeImpl(object): self.enum_class = enum_class def _coerce(self, value): - return self.enum_class(value) if value else None + if value is None: + return None + if value in self.enum_class: + return value + return self.enum_class(value) def process_bind_param(self, value, dialect): - return self.enum_class(value).value if value else None + if value is None: + return None + return self.enum_class(value).value def process_result_value(self, value, dialect): - return self.enum_class(value) if value else None + if value is None: + return None + return self.enum_class(value) diff --git a/tests/types/test_choice.py b/tests/types/test_choice.py index c11bdfe..831f0d6 100644 --- a/tests/types/test_choice.py +++ b/tests/types/test_choice.py @@ -84,27 +84,33 @@ class TestChoiceTypeWithCustomUnderlyingType(TestCase): class TestEnumType(TestCase): def create_models(self): class OrderStatus(Enum): - unpaid = 1 - paid = 2 + unpaid = 0 + paid = 1 class Order(self.Base): __tablename__ = 'order' id_ = sa.Column(sa.Integer, primary_key=True) status = sa.Column( ChoiceType(OrderStatus, impl=sa.Integer()), - default=OrderStatus.unpaid + default=OrderStatus.unpaid, ) def __repr__(self): return 'Order(%r, %r)' % (self.id_, self.status) - def pay(self): - self.status = OrderStatus.paid + class OrderNullable(self.Base): + __tablename__ = 'order_nullable' + id_ = sa.Column(sa.Integer, primary_key=True) + status = sa.Column( + ChoiceType(OrderStatus, impl=sa.Integer()), + nullable=True, + ) self.OrderStatus = OrderStatus self.Order = Order + self.OrderNullable = OrderNullable - def test_parameter_processing(self): + def test_parameter_initialization(self): order = self.Order() self.session.add(order) @@ -112,20 +118,64 @@ class TestEnumType(TestCase): order = self.session.query(self.Order).first() assert order.status is self.OrderStatus.unpaid - assert order.status.value == 1 + assert order.status.value == 0 - order.pay() - self.session.commit() - - order = self.session.query(self.Order).first() - assert order.status is self.OrderStatus.paid - assert order.status.value == 2 - - def test_parameter_coercing(self): + def test_setting_by_value(self): order = self.Order() - order.status = 2 + order.status = 1 self.session.add(order) self.session.commit() + order = self.session.query(self.Order).first() assert order.status is self.OrderStatus.paid + + def test_setting_by_enum(self): + order = self.Order() + order.status = self.OrderStatus.paid + + self.session.add(order) + self.session.commit() + + order = self.session.query(self.Order).first() + assert order.status is self.OrderStatus.paid + + def test_setting_value_that_resolves_to_none(self): + order = self.Order() + order.status = 0 + + self.session.add(order) + self.session.commit() + + order = self.session.query(self.Order).first() + assert order.status is self.OrderStatus.unpaid + + def test_setting_to_wrong_enum_raises_valueerror(self): + class WrongEnum(Enum): + foo = 0 + bar = 1 + + order = self.Order() + + with raises(ValueError): + order.status = WrongEnum.foo + + def test_setting_to_uncoerceable_type_raises_valueerror(self): + order = self.Order() + with raises(ValueError): + order.status = 'Bad value' + + def test_order_nullable_stores_none(self): + # With nullable=False as in `Order`, a `None` value is always + # converted to the default value, unless we explicitly set it to + # sqlalchemy.sql.null(), so we use this class to test our ability + # to set and retrive `None`. + order_nullable = self.OrderNullable() + assert order_nullable.status is None + + order_nullable.status = None + + self.session.add(order_nullable) + self.session.commit() + + assert order_nullable.status is None