From e6d0e680dd1a8dc105efa3a0a4ac80e0b9e1bc4f Mon Sep 17 00:00:00 2001 From: Jiangge Zhang Date: Thu, 15 Jan 2015 16:32:32 +0800 Subject: [PATCH] add support for built-in enum or backported enum34. --- docs/data_types.rst | 7 +++ setup.py | 1 + sqlalchemy_utils/__init__.py | 2 + sqlalchemy_utils/types/__init__.py | 2 + sqlalchemy_utils/types/enum.py | 70 ++++++++++++++++++++++++++++++ tests/types/test_enum.py | 53 ++++++++++++++++++++++ 6 files changed, 135 insertions(+) create mode 100644 sqlalchemy_utils/types/enum.py create mode 100644 tests/types/test_enum.py diff --git a/docs/data_types.rst b/docs/data_types.rst index 9501570..28fc066 100644 --- a/docs/data_types.rst +++ b/docs/data_types.rst @@ -45,6 +45,13 @@ EncryptedType .. autoclass:: EncryptedType +EnumType +^^^^^^^^ + +.. module:: sqlalchemy_utils.types.enum + +.. autoclass:: EnumType + JSONType ^^^^^^^^ diff --git a/setup.py b/setup.py index fa822fa..bac51a7 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,7 @@ extras_require = { 'password': ['passlib >= 1.6, < 2.0'], 'color': ['colour>=0.0.4'], 'ipaddress': ['ipaddr'] if not PY3 else [], + 'enum': ['enum34'] if sys.version_info < (3, 4) else [], 'timezone': ['python-dateutil'], 'url': ['furl >= 0.4.1'], 'encrypted': ['cryptography>=0.6'] diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 0e97c92..368ee1c 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -65,6 +65,7 @@ from .types import ( DateTimeRangeType, EmailType, EncryptedType, + EnumType, instrumented_list, InstrumentedList, IntRangeType, @@ -144,6 +145,7 @@ __all__ = ( DateTimeRangeType, EmailType, EncryptedType, + EnumType, ExpressionParser, ImproperlyConfigured, InstrumentedList, diff --git a/sqlalchemy_utils/types/__init__.py b/sqlalchemy_utils/types/__init__.py index 9e37990..706c85f 100644 --- a/sqlalchemy_utils/types/__init__.py +++ b/sqlalchemy_utils/types/__init__.py @@ -6,6 +6,7 @@ from .color import ColorType from .country import CountryType, Country from .email import EmailType from .encrypted import EncryptedType +from .enum import EnumType from .ip_address import IPAddressType from .json import JSONType from .locale import LocaleType @@ -36,6 +37,7 @@ __all__ = ( DateTimeRangeType, EmailType, EncryptedType, + EnumType, IntRangeType, IPAddressType, JSONType, diff --git a/sqlalchemy_utils/types/enum.py b/sqlalchemy_utils/types/enum.py new file mode 100644 index 0000000..afb828d --- /dev/null +++ b/sqlalchemy_utils/types/enum.py @@ -0,0 +1,70 @@ +from __future__ import absolute_import + +try: + from enum import Enum +except ImportError: + Enum = None + +from sqlalchemy import types +from sqlalchemy_utils.exceptions import ImproperlyConfigured +from .scalar_coercible import ScalarCoercible + + +class EnumType(types.TypeDecorator, ScalarCoercible): + """ + EnumType offers way of integrating with :mod:`enum` in the standard + library of Python 3.4+ or the enum34_ backported package on PyPI. + + .. _enum34: https://pypi.python.org/pypi/enum34 + + :: + + from enum import Enum + from sqlalchemy_utils import EnumType + + + class OrderStatus(Enum): + unpaid = 1 + paid = 2 + + + class Order(Base): + __tablename__ = 'order' + id = sa.Column(sa.Integer, autoincrement=True) + status = sa.Column(EnumType(OrderStatus)) + + + order = Order() + order.status = OrderStatus.unpaid + session.add(order) + session.commit() + + assert user.status is OrderStatus.unpaid + assert user.status.value == 1 + assert user.status.name == 'paid' + """ + + impl = types.Integer + + def __init__(self, enum_class, impl=None, *args, **kwargs): + if Enum is None: + raise ImproperlyConfigured( + "'enum34' package is required to use 'EnumType' in Python " + "< 3.4") + if not issubclass(enum_class, Enum): + raise ImproperlyConfigured( + "EnumType needs a class of enum defined.") + + super(EnumType, self).__init__(*args, **kwargs) + self.enum_class = enum_class + if impl is not None: + self.impl = types.Integer + + def process_bind_param(self, value, dialect): + return self.enum_class(value).value if value else None + + def process_result_value(self, value, dialect): + return self.enum_class(value) if value else None + + def _coerce(self, value): + return self.enum_class(value) if value else None diff --git a/tests/types/test_enum.py b/tests/types/test_enum.py new file mode 100644 index 0000000..8d94251 --- /dev/null +++ b/tests/types/test_enum.py @@ -0,0 +1,53 @@ +from pytest import mark +import sqlalchemy as sa +from sqlalchemy_utils.types import enum +from tests import TestCase + + +@mark.skipif('enum.Enum is None') +class TestEnumType(TestCase): + def create_models(self): + class OrderStatus(enum.Enum): + unpaid = 1 + paid = 2 + + class Order(self.Base): + __tablename__ = 'document' + id_ = sa.Column(sa.Integer, primary_key=True) + status = sa.Column( + enum.EnumType(OrderStatus), default=OrderStatus.unpaid) + + def __repr__(self): + return 'Order(%r, %r)' % (self.id_, self.status) + + def pay(self): + self.status = OrderStatus.paid + + self.OrderStatus = OrderStatus + self.Order = Order + + def test_parameter_processing(self): + order = self.Order() + + self.session.add(order) + self.session.commit() + + order = self.session.query(self.Order).first() + assert order.status is self.OrderStatus.unpaid + assert order.status.value == 1 + + order.pay() + self.session.commit() + + order = self.session.query(self.Order).first() + assert order.status is self.OrderStatus.paid + assert order.status.value == 2 + + def test_parameter_coercing(self): + order = self.Order() + order.status = 2 + + self.session.add(order) + self.session.commit() + + assert order.status is self.OrderStatus.paid