add support for built-in enum or backported enum34.

This commit is contained in:
Jiangge Zhang
2015-01-15 16:32:32 +08:00
parent 0ef12b0a07
commit e6d0e680dd
6 changed files with 135 additions and 0 deletions

View File

@@ -45,6 +45,13 @@ EncryptedType
.. autoclass:: EncryptedType .. autoclass:: EncryptedType
EnumType
^^^^^^^^
.. module:: sqlalchemy_utils.types.enum
.. autoclass:: EnumType
JSONType JSONType
^^^^^^^^ ^^^^^^^^

View File

@@ -45,6 +45,7 @@ extras_require = {
'password': ['passlib >= 1.6, < 2.0'], 'password': ['passlib >= 1.6, < 2.0'],
'color': ['colour>=0.0.4'], 'color': ['colour>=0.0.4'],
'ipaddress': ['ipaddr'] if not PY3 else [], 'ipaddress': ['ipaddr'] if not PY3 else [],
'enum': ['enum34'] if sys.version_info < (3, 4) else [],
'timezone': ['python-dateutil'], 'timezone': ['python-dateutil'],
'url': ['furl >= 0.4.1'], 'url': ['furl >= 0.4.1'],
'encrypted': ['cryptography>=0.6'] 'encrypted': ['cryptography>=0.6']

View File

@@ -65,6 +65,7 @@ from .types import (
DateTimeRangeType, DateTimeRangeType,
EmailType, EmailType,
EncryptedType, EncryptedType,
EnumType,
instrumented_list, instrumented_list,
InstrumentedList, InstrumentedList,
IntRangeType, IntRangeType,
@@ -144,6 +145,7 @@ __all__ = (
DateTimeRangeType, DateTimeRangeType,
EmailType, EmailType,
EncryptedType, EncryptedType,
EnumType,
ExpressionParser, ExpressionParser,
ImproperlyConfigured, ImproperlyConfigured,
InstrumentedList, InstrumentedList,

View File

@@ -6,6 +6,7 @@ from .color import ColorType
from .country import CountryType, Country from .country import CountryType, Country
from .email import EmailType from .email import EmailType
from .encrypted import EncryptedType from .encrypted import EncryptedType
from .enum import EnumType
from .ip_address import IPAddressType from .ip_address import IPAddressType
from .json import JSONType from .json import JSONType
from .locale import LocaleType from .locale import LocaleType
@@ -36,6 +37,7 @@ __all__ = (
DateTimeRangeType, DateTimeRangeType,
EmailType, EmailType,
EncryptedType, EncryptedType,
EnumType,
IntRangeType, IntRangeType,
IPAddressType, IPAddressType,
JSONType, JSONType,

View File

@@ -0,0 +1,70 @@
from __future__ import absolute_import
try:
from enum import Enum
except ImportError:
Enum = None
from sqlalchemy import types
from sqlalchemy_utils.exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible
class EnumType(types.TypeDecorator, ScalarCoercible):
"""
EnumType offers way of integrating with :mod:`enum` in the standard
library of Python 3.4+ or the enum34_ backported package on PyPI.
.. _enum34: https://pypi.python.org/pypi/enum34
::
from enum import Enum
from sqlalchemy_utils import EnumType
class OrderStatus(Enum):
unpaid = 1
paid = 2
class Order(Base):
__tablename__ = 'order'
id = sa.Column(sa.Integer, autoincrement=True)
status = sa.Column(EnumType(OrderStatus))
order = Order()
order.status = OrderStatus.unpaid
session.add(order)
session.commit()
assert user.status is OrderStatus.unpaid
assert user.status.value == 1
assert user.status.name == 'paid'
"""
impl = types.Integer
def __init__(self, enum_class, impl=None, *args, **kwargs):
if Enum is None:
raise ImproperlyConfigured(
"'enum34' package is required to use 'EnumType' in Python "
"< 3.4")
if not issubclass(enum_class, Enum):
raise ImproperlyConfigured(
"EnumType needs a class of enum defined.")
super(EnumType, self).__init__(*args, **kwargs)
self.enum_class = enum_class
if impl is not None:
self.impl = types.Integer
def process_bind_param(self, value, dialect):
return self.enum_class(value).value if value else None
def process_result_value(self, value, dialect):
return self.enum_class(value) if value else None
def _coerce(self, value):
return self.enum_class(value) if value else None

53
tests/types/test_enum.py Normal file
View File

@@ -0,0 +1,53 @@
from pytest import mark
import sqlalchemy as sa
from sqlalchemy_utils.types import enum
from tests import TestCase
@mark.skipif('enum.Enum is None')
class TestEnumType(TestCase):
def create_models(self):
class OrderStatus(enum.Enum):
unpaid = 1
paid = 2
class Order(self.Base):
__tablename__ = 'document'
id_ = sa.Column(sa.Integer, primary_key=True)
status = sa.Column(
enum.EnumType(OrderStatus), default=OrderStatus.unpaid)
def __repr__(self):
return 'Order(%r, %r)' % (self.id_, self.status)
def pay(self):
self.status = OrderStatus.paid
self.OrderStatus = OrderStatus
self.Order = Order
def test_parameter_processing(self):
order = self.Order()
self.session.add(order)
self.session.commit()
order = self.session.query(self.Order).first()
assert order.status is self.OrderStatus.unpaid
assert order.status.value == 1
order.pay()
self.session.commit()
order = self.session.query(self.Order).first()
assert order.status is self.OrderStatus.paid
assert order.status.value == 2
def test_parameter_coercing(self):
order = self.Order()
order.status = 2
self.session.add(order)
self.session.commit()
assert order.status is self.OrderStatus.paid