ChoiceType: non-None should not return None
This commit is contained in:
@@ -214,10 +214,18 @@ class EnumTypeImpl(object):
|
||||
self.enum_class = enum_class
|
||||
|
||||
def _coerce(self, value):
|
||||
return self.enum_class(value) if value else None
|
||||
if value is None:
|
||||
return None
|
||||
if value in self.enum_class:
|
||||
return value
|
||||
return self.enum_class(value)
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
return self.enum_class(value).value if value else None
|
||||
if value is None:
|
||||
return None
|
||||
return self.enum_class(value).value
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
return self.enum_class(value) if value else None
|
||||
if value is None:
|
||||
return None
|
||||
return self.enum_class(value)
|
||||
|
@@ -84,27 +84,33 @@ class TestChoiceTypeWithCustomUnderlyingType(TestCase):
|
||||
class TestEnumType(TestCase):
|
||||
def create_models(self):
|
||||
class OrderStatus(Enum):
|
||||
unpaid = 1
|
||||
paid = 2
|
||||
unpaid = 0
|
||||
paid = 1
|
||||
|
||||
class Order(self.Base):
|
||||
__tablename__ = 'order'
|
||||
id_ = sa.Column(sa.Integer, primary_key=True)
|
||||
status = sa.Column(
|
||||
ChoiceType(OrderStatus, impl=sa.Integer()),
|
||||
default=OrderStatus.unpaid
|
||||
default=OrderStatus.unpaid,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return 'Order(%r, %r)' % (self.id_, self.status)
|
||||
|
||||
def pay(self):
|
||||
self.status = OrderStatus.paid
|
||||
class OrderNullable(self.Base):
|
||||
__tablename__ = 'order_nullable'
|
||||
id_ = sa.Column(sa.Integer, primary_key=True)
|
||||
status = sa.Column(
|
||||
ChoiceType(OrderStatus, impl=sa.Integer()),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
self.OrderStatus = OrderStatus
|
||||
self.Order = Order
|
||||
self.OrderNullable = OrderNullable
|
||||
|
||||
def test_parameter_processing(self):
|
||||
def test_parameter_initialization(self):
|
||||
order = self.Order()
|
||||
|
||||
self.session.add(order)
|
||||
@@ -112,20 +118,64 @@ class TestEnumType(TestCase):
|
||||
|
||||
order = self.session.query(self.Order).first()
|
||||
assert order.status is self.OrderStatus.unpaid
|
||||
assert order.status.value == 1
|
||||
assert order.status.value == 0
|
||||
|
||||
order.pay()
|
||||
self.session.commit()
|
||||
|
||||
order = self.session.query(self.Order).first()
|
||||
assert order.status is self.OrderStatus.paid
|
||||
assert order.status.value == 2
|
||||
|
||||
def test_parameter_coercing(self):
|
||||
def test_setting_by_value(self):
|
||||
order = self.Order()
|
||||
order.status = 2
|
||||
order.status = 1
|
||||
|
||||
self.session.add(order)
|
||||
self.session.commit()
|
||||
|
||||
order = self.session.query(self.Order).first()
|
||||
assert order.status is self.OrderStatus.paid
|
||||
|
||||
def test_setting_by_enum(self):
|
||||
order = self.Order()
|
||||
order.status = self.OrderStatus.paid
|
||||
|
||||
self.session.add(order)
|
||||
self.session.commit()
|
||||
|
||||
order = self.session.query(self.Order).first()
|
||||
assert order.status is self.OrderStatus.paid
|
||||
|
||||
def test_setting_value_that_resolves_to_none(self):
|
||||
order = self.Order()
|
||||
order.status = 0
|
||||
|
||||
self.session.add(order)
|
||||
self.session.commit()
|
||||
|
||||
order = self.session.query(self.Order).first()
|
||||
assert order.status is self.OrderStatus.unpaid
|
||||
|
||||
def test_setting_to_wrong_enum_raises_valueerror(self):
|
||||
class WrongEnum(Enum):
|
||||
foo = 0
|
||||
bar = 1
|
||||
|
||||
order = self.Order()
|
||||
|
||||
with raises(ValueError):
|
||||
order.status = WrongEnum.foo
|
||||
|
||||
def test_setting_to_uncoerceable_type_raises_valueerror(self):
|
||||
order = self.Order()
|
||||
with raises(ValueError):
|
||||
order.status = 'Bad value'
|
||||
|
||||
def test_order_nullable_stores_none(self):
|
||||
# With nullable=False as in `Order`, a `None` value is always
|
||||
# converted to the default value, unless we explicitly set it to
|
||||
# sqlalchemy.sql.null(), so we use this class to test our ability
|
||||
# to set and retrive `None`.
|
||||
order_nullable = self.OrderNullable()
|
||||
assert order_nullable.status is None
|
||||
|
||||
order_nullable.status = None
|
||||
|
||||
self.session.add(order_nullable)
|
||||
self.session.commit()
|
||||
|
||||
assert order_nullable.status is None
|
||||
|
Reference in New Issue
Block a user