add support for built-in enum or backported enum34.

This commit is contained in:
Jiangge Zhang
2015-01-15 16:32:32 +08:00
parent 0ef12b0a07
commit e6d0e680dd
6 changed files with 135 additions and 0 deletions

View File

@@ -45,6 +45,13 @@ EncryptedType
.. autoclass:: EncryptedType
EnumType
^^^^^^^^
.. module:: sqlalchemy_utils.types.enum
.. autoclass:: EnumType
JSONType
^^^^^^^^

View File

@@ -45,6 +45,7 @@ extras_require = {
'password': ['passlib >= 1.6, < 2.0'],
'color': ['colour>=0.0.4'],
'ipaddress': ['ipaddr'] if not PY3 else [],
'enum': ['enum34'] if sys.version_info < (3, 4) else [],
'timezone': ['python-dateutil'],
'url': ['furl >= 0.4.1'],
'encrypted': ['cryptography>=0.6']

View File

@@ -65,6 +65,7 @@ from .types import (
DateTimeRangeType,
EmailType,
EncryptedType,
EnumType,
instrumented_list,
InstrumentedList,
IntRangeType,
@@ -144,6 +145,7 @@ __all__ = (
DateTimeRangeType,
EmailType,
EncryptedType,
EnumType,
ExpressionParser,
ImproperlyConfigured,
InstrumentedList,

View File

@@ -6,6 +6,7 @@ from .color import ColorType
from .country import CountryType, Country
from .email import EmailType
from .encrypted import EncryptedType
from .enum import EnumType
from .ip_address import IPAddressType
from .json import JSONType
from .locale import LocaleType
@@ -36,6 +37,7 @@ __all__ = (
DateTimeRangeType,
EmailType,
EncryptedType,
EnumType,
IntRangeType,
IPAddressType,
JSONType,

View File

@@ -0,0 +1,70 @@
from __future__ import absolute_import
try:
from enum import Enum
except ImportError:
Enum = None
from sqlalchemy import types
from sqlalchemy_utils.exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible
class EnumType(types.TypeDecorator, ScalarCoercible):
"""
EnumType offers way of integrating with :mod:`enum` in the standard
library of Python 3.4+ or the enum34_ backported package on PyPI.
.. _enum34: https://pypi.python.org/pypi/enum34
::
from enum import Enum
from sqlalchemy_utils import EnumType
class OrderStatus(Enum):
unpaid = 1
paid = 2
class Order(Base):
__tablename__ = 'order'
id = sa.Column(sa.Integer, autoincrement=True)
status = sa.Column(EnumType(OrderStatus))
order = Order()
order.status = OrderStatus.unpaid
session.add(order)
session.commit()
assert user.status is OrderStatus.unpaid
assert user.status.value == 1
assert user.status.name == 'paid'
"""
impl = types.Integer
def __init__(self, enum_class, impl=None, *args, **kwargs):
if Enum is None:
raise ImproperlyConfigured(
"'enum34' package is required to use 'EnumType' in Python "
"< 3.4")
if not issubclass(enum_class, Enum):
raise ImproperlyConfigured(
"EnumType needs a class of enum defined.")
super(EnumType, self).__init__(*args, **kwargs)
self.enum_class = enum_class
if impl is not None:
self.impl = types.Integer
def process_bind_param(self, value, dialect):
return self.enum_class(value).value if value else None
def process_result_value(self, value, dialect):
return self.enum_class(value) if value else None
def _coerce(self, value):
return self.enum_class(value) if value else None

53
tests/types/test_enum.py Normal file
View File

@@ -0,0 +1,53 @@
from pytest import mark
import sqlalchemy as sa
from sqlalchemy_utils.types import enum
from tests import TestCase
@mark.skipif('enum.Enum is None')
class TestEnumType(TestCase):
def create_models(self):
class OrderStatus(enum.Enum):
unpaid = 1
paid = 2
class Order(self.Base):
__tablename__ = 'document'
id_ = sa.Column(sa.Integer, primary_key=True)
status = sa.Column(
enum.EnumType(OrderStatus), default=OrderStatus.unpaid)
def __repr__(self):
return 'Order(%r, %r)' % (self.id_, self.status)
def pay(self):
self.status = OrderStatus.paid
self.OrderStatus = OrderStatus
self.Order = Order
def test_parameter_processing(self):
order = self.Order()
self.session.add(order)
self.session.commit()
order = self.session.query(self.Order).first()
assert order.status is self.OrderStatus.unpaid
assert order.status.value == 1
order.pay()
self.session.commit()
order = self.session.query(self.Order).first()
assert order.status is self.OrderStatus.paid
assert order.status.value == 2
def test_parameter_coercing(self):
order = self.Order()
order.status = 2
self.session.add(order)
self.session.commit()
assert order.status is self.OrderStatus.paid