add support for built-in enum or backported enum34.
This commit is contained in:
@@ -45,6 +45,13 @@ EncryptedType
|
|||||||
|
|
||||||
.. autoclass:: EncryptedType
|
.. autoclass:: EncryptedType
|
||||||
|
|
||||||
|
EnumType
|
||||||
|
^^^^^^^^
|
||||||
|
|
||||||
|
.. module:: sqlalchemy_utils.types.enum
|
||||||
|
|
||||||
|
.. autoclass:: EnumType
|
||||||
|
|
||||||
JSONType
|
JSONType
|
||||||
^^^^^^^^
|
^^^^^^^^
|
||||||
|
|
||||||
|
1
setup.py
1
setup.py
@@ -45,6 +45,7 @@ extras_require = {
|
|||||||
'password': ['passlib >= 1.6, < 2.0'],
|
'password': ['passlib >= 1.6, < 2.0'],
|
||||||
'color': ['colour>=0.0.4'],
|
'color': ['colour>=0.0.4'],
|
||||||
'ipaddress': ['ipaddr'] if not PY3 else [],
|
'ipaddress': ['ipaddr'] if not PY3 else [],
|
||||||
|
'enum': ['enum34'] if sys.version_info < (3, 4) else [],
|
||||||
'timezone': ['python-dateutil'],
|
'timezone': ['python-dateutil'],
|
||||||
'url': ['furl >= 0.4.1'],
|
'url': ['furl >= 0.4.1'],
|
||||||
'encrypted': ['cryptography>=0.6']
|
'encrypted': ['cryptography>=0.6']
|
||||||
|
@@ -65,6 +65,7 @@ from .types import (
|
|||||||
DateTimeRangeType,
|
DateTimeRangeType,
|
||||||
EmailType,
|
EmailType,
|
||||||
EncryptedType,
|
EncryptedType,
|
||||||
|
EnumType,
|
||||||
instrumented_list,
|
instrumented_list,
|
||||||
InstrumentedList,
|
InstrumentedList,
|
||||||
IntRangeType,
|
IntRangeType,
|
||||||
@@ -144,6 +145,7 @@ __all__ = (
|
|||||||
DateTimeRangeType,
|
DateTimeRangeType,
|
||||||
EmailType,
|
EmailType,
|
||||||
EncryptedType,
|
EncryptedType,
|
||||||
|
EnumType,
|
||||||
ExpressionParser,
|
ExpressionParser,
|
||||||
ImproperlyConfigured,
|
ImproperlyConfigured,
|
||||||
InstrumentedList,
|
InstrumentedList,
|
||||||
|
@@ -6,6 +6,7 @@ from .color import ColorType
|
|||||||
from .country import CountryType, Country
|
from .country import CountryType, Country
|
||||||
from .email import EmailType
|
from .email import EmailType
|
||||||
from .encrypted import EncryptedType
|
from .encrypted import EncryptedType
|
||||||
|
from .enum import EnumType
|
||||||
from .ip_address import IPAddressType
|
from .ip_address import IPAddressType
|
||||||
from .json import JSONType
|
from .json import JSONType
|
||||||
from .locale import LocaleType
|
from .locale import LocaleType
|
||||||
@@ -36,6 +37,7 @@ __all__ = (
|
|||||||
DateTimeRangeType,
|
DateTimeRangeType,
|
||||||
EmailType,
|
EmailType,
|
||||||
EncryptedType,
|
EncryptedType,
|
||||||
|
EnumType,
|
||||||
IntRangeType,
|
IntRangeType,
|
||||||
IPAddressType,
|
IPAddressType,
|
||||||
JSONType,
|
JSONType,
|
||||||
|
70
sqlalchemy_utils/types/enum.py
Normal file
70
sqlalchemy_utils/types/enum.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
from __future__ import absolute_import
|
||||||
|
|
||||||
|
try:
|
||||||
|
from enum import Enum
|
||||||
|
except ImportError:
|
||||||
|
Enum = None
|
||||||
|
|
||||||
|
from sqlalchemy import types
|
||||||
|
from sqlalchemy_utils.exceptions import ImproperlyConfigured
|
||||||
|
from .scalar_coercible import ScalarCoercible
|
||||||
|
|
||||||
|
|
||||||
|
class EnumType(types.TypeDecorator, ScalarCoercible):
|
||||||
|
"""
|
||||||
|
EnumType offers way of integrating with :mod:`enum` in the standard
|
||||||
|
library of Python 3.4+ or the enum34_ backported package on PyPI.
|
||||||
|
|
||||||
|
.. _enum34: https://pypi.python.org/pypi/enum34
|
||||||
|
|
||||||
|
::
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from sqlalchemy_utils import EnumType
|
||||||
|
|
||||||
|
|
||||||
|
class OrderStatus(Enum):
|
||||||
|
unpaid = 1
|
||||||
|
paid = 2
|
||||||
|
|
||||||
|
|
||||||
|
class Order(Base):
|
||||||
|
__tablename__ = 'order'
|
||||||
|
id = sa.Column(sa.Integer, autoincrement=True)
|
||||||
|
status = sa.Column(EnumType(OrderStatus))
|
||||||
|
|
||||||
|
|
||||||
|
order = Order()
|
||||||
|
order.status = OrderStatus.unpaid
|
||||||
|
session.add(order)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
assert user.status is OrderStatus.unpaid
|
||||||
|
assert user.status.value == 1
|
||||||
|
assert user.status.name == 'paid'
|
||||||
|
"""
|
||||||
|
|
||||||
|
impl = types.Integer
|
||||||
|
|
||||||
|
def __init__(self, enum_class, impl=None, *args, **kwargs):
|
||||||
|
if Enum is None:
|
||||||
|
raise ImproperlyConfigured(
|
||||||
|
"'enum34' package is required to use 'EnumType' in Python "
|
||||||
|
"< 3.4")
|
||||||
|
if not issubclass(enum_class, Enum):
|
||||||
|
raise ImproperlyConfigured(
|
||||||
|
"EnumType needs a class of enum defined.")
|
||||||
|
|
||||||
|
super(EnumType, self).__init__(*args, **kwargs)
|
||||||
|
self.enum_class = enum_class
|
||||||
|
if impl is not None:
|
||||||
|
self.impl = types.Integer
|
||||||
|
|
||||||
|
def process_bind_param(self, value, dialect):
|
||||||
|
return self.enum_class(value).value if value else None
|
||||||
|
|
||||||
|
def process_result_value(self, value, dialect):
|
||||||
|
return self.enum_class(value) if value else None
|
||||||
|
|
||||||
|
def _coerce(self, value):
|
||||||
|
return self.enum_class(value) if value else None
|
53
tests/types/test_enum.py
Normal file
53
tests/types/test_enum.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
from pytest import mark
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils.types import enum
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
@mark.skipif('enum.Enum is None')
|
||||||
|
class TestEnumType(TestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class OrderStatus(enum.Enum):
|
||||||
|
unpaid = 1
|
||||||
|
paid = 2
|
||||||
|
|
||||||
|
class Order(self.Base):
|
||||||
|
__tablename__ = 'document'
|
||||||
|
id_ = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
status = sa.Column(
|
||||||
|
enum.EnumType(OrderStatus), default=OrderStatus.unpaid)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'Order(%r, %r)' % (self.id_, self.status)
|
||||||
|
|
||||||
|
def pay(self):
|
||||||
|
self.status = OrderStatus.paid
|
||||||
|
|
||||||
|
self.OrderStatus = OrderStatus
|
||||||
|
self.Order = Order
|
||||||
|
|
||||||
|
def test_parameter_processing(self):
|
||||||
|
order = self.Order()
|
||||||
|
|
||||||
|
self.session.add(order)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
order = self.session.query(self.Order).first()
|
||||||
|
assert order.status is self.OrderStatus.unpaid
|
||||||
|
assert order.status.value == 1
|
||||||
|
|
||||||
|
order.pay()
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
order = self.session.query(self.Order).first()
|
||||||
|
assert order.status is self.OrderStatus.paid
|
||||||
|
assert order.status.value == 2
|
||||||
|
|
||||||
|
def test_parameter_coercing(self):
|
||||||
|
order = self.Order()
|
||||||
|
order.status = 2
|
||||||
|
|
||||||
|
self.session.add(order)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
assert order.status is self.OrderStatus.paid
|
Reference in New Issue
Block a user