add support for built-in enum or backported enum34.
This commit is contained in:
@@ -45,6 +45,13 @@ EncryptedType
|
||||
|
||||
.. autoclass:: EncryptedType
|
||||
|
||||
EnumType
|
||||
^^^^^^^^
|
||||
|
||||
.. module:: sqlalchemy_utils.types.enum
|
||||
|
||||
.. autoclass:: EnumType
|
||||
|
||||
JSONType
|
||||
^^^^^^^^
|
||||
|
||||
|
1
setup.py
1
setup.py
@@ -45,6 +45,7 @@ extras_require = {
|
||||
'password': ['passlib >= 1.6, < 2.0'],
|
||||
'color': ['colour>=0.0.4'],
|
||||
'ipaddress': ['ipaddr'] if not PY3 else [],
|
||||
'enum': ['enum34'] if sys.version_info < (3, 4) else [],
|
||||
'timezone': ['python-dateutil'],
|
||||
'url': ['furl >= 0.4.1'],
|
||||
'encrypted': ['cryptography>=0.6']
|
||||
|
@@ -65,6 +65,7 @@ from .types import (
|
||||
DateTimeRangeType,
|
||||
EmailType,
|
||||
EncryptedType,
|
||||
EnumType,
|
||||
instrumented_list,
|
||||
InstrumentedList,
|
||||
IntRangeType,
|
||||
@@ -144,6 +145,7 @@ __all__ = (
|
||||
DateTimeRangeType,
|
||||
EmailType,
|
||||
EncryptedType,
|
||||
EnumType,
|
||||
ExpressionParser,
|
||||
ImproperlyConfigured,
|
||||
InstrumentedList,
|
||||
|
@@ -6,6 +6,7 @@ from .color import ColorType
|
||||
from .country import CountryType, Country
|
||||
from .email import EmailType
|
||||
from .encrypted import EncryptedType
|
||||
from .enum import EnumType
|
||||
from .ip_address import IPAddressType
|
||||
from .json import JSONType
|
||||
from .locale import LocaleType
|
||||
@@ -36,6 +37,7 @@ __all__ = (
|
||||
DateTimeRangeType,
|
||||
EmailType,
|
||||
EncryptedType,
|
||||
EnumType,
|
||||
IntRangeType,
|
||||
IPAddressType,
|
||||
JSONType,
|
||||
|
70
sqlalchemy_utils/types/enum.py
Normal file
70
sqlalchemy_utils/types/enum.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
try:
|
||||
from enum import Enum
|
||||
except ImportError:
|
||||
Enum = None
|
||||
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy_utils.exceptions import ImproperlyConfigured
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
|
||||
class EnumType(types.TypeDecorator, ScalarCoercible):
|
||||
"""
|
||||
EnumType offers way of integrating with :mod:`enum` in the standard
|
||||
library of Python 3.4+ or the enum34_ backported package on PyPI.
|
||||
|
||||
.. _enum34: https://pypi.python.org/pypi/enum34
|
||||
|
||||
::
|
||||
|
||||
from enum import Enum
|
||||
from sqlalchemy_utils import EnumType
|
||||
|
||||
|
||||
class OrderStatus(Enum):
|
||||
unpaid = 1
|
||||
paid = 2
|
||||
|
||||
|
||||
class Order(Base):
|
||||
__tablename__ = 'order'
|
||||
id = sa.Column(sa.Integer, autoincrement=True)
|
||||
status = sa.Column(EnumType(OrderStatus))
|
||||
|
||||
|
||||
order = Order()
|
||||
order.status = OrderStatus.unpaid
|
||||
session.add(order)
|
||||
session.commit()
|
||||
|
||||
assert user.status is OrderStatus.unpaid
|
||||
assert user.status.value == 1
|
||||
assert user.status.name == 'paid'
|
||||
"""
|
||||
|
||||
impl = types.Integer
|
||||
|
||||
def __init__(self, enum_class, impl=None, *args, **kwargs):
|
||||
if Enum is None:
|
||||
raise ImproperlyConfigured(
|
||||
"'enum34' package is required to use 'EnumType' in Python "
|
||||
"< 3.4")
|
||||
if not issubclass(enum_class, Enum):
|
||||
raise ImproperlyConfigured(
|
||||
"EnumType needs a class of enum defined.")
|
||||
|
||||
super(EnumType, self).__init__(*args, **kwargs)
|
||||
self.enum_class = enum_class
|
||||
if impl is not None:
|
||||
self.impl = types.Integer
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
return self.enum_class(value).value if value else None
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
return self.enum_class(value) if value else None
|
||||
|
||||
def _coerce(self, value):
|
||||
return self.enum_class(value) if value else None
|
53
tests/types/test_enum.py
Normal file
53
tests/types/test_enum.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from pytest import mark
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils.types import enum
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
@mark.skipif('enum.Enum is None')
|
||||
class TestEnumType(TestCase):
|
||||
def create_models(self):
|
||||
class OrderStatus(enum.Enum):
|
||||
unpaid = 1
|
||||
paid = 2
|
||||
|
||||
class Order(self.Base):
|
||||
__tablename__ = 'document'
|
||||
id_ = sa.Column(sa.Integer, primary_key=True)
|
||||
status = sa.Column(
|
||||
enum.EnumType(OrderStatus), default=OrderStatus.unpaid)
|
||||
|
||||
def __repr__(self):
|
||||
return 'Order(%r, %r)' % (self.id_, self.status)
|
||||
|
||||
def pay(self):
|
||||
self.status = OrderStatus.paid
|
||||
|
||||
self.OrderStatus = OrderStatus
|
||||
self.Order = Order
|
||||
|
||||
def test_parameter_processing(self):
|
||||
order = self.Order()
|
||||
|
||||
self.session.add(order)
|
||||
self.session.commit()
|
||||
|
||||
order = self.session.query(self.Order).first()
|
||||
assert order.status is self.OrderStatus.unpaid
|
||||
assert order.status.value == 1
|
||||
|
||||
order.pay()
|
||||
self.session.commit()
|
||||
|
||||
order = self.session.query(self.Order).first()
|
||||
assert order.status is self.OrderStatus.paid
|
||||
assert order.status.value == 2
|
||||
|
||||
def test_parameter_coercing(self):
|
||||
order = self.Order()
|
||||
order.status = 2
|
||||
|
||||
self.session.add(order)
|
||||
self.session.commit()
|
||||
|
||||
assert order.status is self.OrderStatus.paid
|
Reference in New Issue
Block a user