diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 6dcd518..6e04147 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -19,6 +19,8 @@ from .generic import generic_relationship from .proxy_dict import ProxyDict, proxy_dict from .types import ( ArrowType, + Choice, + ChoiceType, ColorType, CountryType, Country, @@ -65,6 +67,8 @@ __all__ = ( table_name, with_backrefs, ArrowType, + Choice, + ChoiceType, ColorType, CountryType, Country, diff --git a/sqlalchemy_utils/types/__init__.py b/sqlalchemy_utils/types/__init__.py index 854f56a..1f634b7 100644 --- a/sqlalchemy_utils/types/__init__.py +++ b/sqlalchemy_utils/types/__init__.py @@ -1,6 +1,7 @@ from functools import wraps from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList from .arrow import ArrowType +from .choice import ChoiceType, Choice from .color import ColorType from .country import CountryType, Country from .email import EmailType @@ -24,6 +25,8 @@ from .weekdays import WeekDay, WeekDays, WeekDaysType __all__ = ( ArrowType, + Choice, + ChoiceType, ColorType, CountryType, Country, diff --git a/sqlalchemy_utils/types/choice.py b/sqlalchemy_utils/types/choice.py new file mode 100644 index 0000000..c32aacb --- /dev/null +++ b/sqlalchemy_utils/types/choice.py @@ -0,0 +1,56 @@ +from sqlalchemy import types +import six +from ..exceptions import ImproperlyConfigured +from .scalar_coercible import ScalarCoercible + + +class Choice(object): + def __init__(self, code, value): + self.code = code + self.value = value + + def __eq__(self, other): + if isinstance(other, Choice): + return self.code == other.code + return other == self.code + + def __ne__(self, other): + return not (self == other) + + def __unicode__(self): + return six.text_type(self.value) + + def __repr__(self): + return 'Choice(code={code}, value={value})'.format( + code=self.code, + value=self.value + ) + + +class ChoiceType(types.TypeDecorator, ScalarCoercible): + impl = types.Unicode(255) + + def __init__(self, choices): + if not choices: + raise ImproperlyConfigured( + 'ChoiceType needs list of choices defined.' + ) + self.choices = choices + self.choices_dict = dict(choices) + + def _coerce(self, value): + if value is None: + return value + if isinstance(value, Choice): + return value + return Choice(value, self.choices_dict[value]) + + def process_bind_param(self, value, dialect): + if value: + return value.code + return value + + def process_result_value(self, value, dialect): + if value: + return Choice(value, self.choices_dict[value]) + return value diff --git a/tests/types/test_choice_type.py b/tests/types/test_choice_type.py new file mode 100644 index 0000000..7c78dd5 --- /dev/null +++ b/tests/types/test_choice_type.py @@ -0,0 +1,56 @@ +from pytest import raises +import sqlalchemy as sa +from sqlalchemy_utils import ChoiceType, Choice, ImproperlyConfigured +from tests import TestCase + + +class TestChoice(object): + # def test_init(self): + # assert Choice(1, 1) == Choice(Choice(1, 1)) + + def test_equality_operator(self): + assert Choice(1, 1) == 1 + assert 1 == Choice(1, 1) + assert Choice(1, 1) == Choice(1, 1) + + def test_non_equality_operator(self): + assert Choice(1, 1) != 2 + assert not (Choice(1, 1) != 1) + + +class TestChoiceType(TestCase): + def create_models(self): + class User(self.Base): + TYPES = [ + ('admin', 'Admin'), + ('regular-user', 'Regular user') + ] + + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + type = sa.Column(ChoiceType(TYPES)) + + def __repr__(self): + return 'User(%r)' % self.id + + self.User = User + + def test_parameter_processing(self): + user = self.User( + type=u'admin' + ) + + self.session.add(user) + self.session.commit() + + user = self.session.query(self.User).first() + assert user.type.value == u'Admin' + + def test_scalar_attributes_get_coerced_to_objects(self): + user = self.User(type=u'admin') + + assert isinstance(user.type, Choice) + + def test_throws_exception_if_no_choices_given(self): + with raises(ImproperlyConfigured): + ChoiceType([])