From 1a1aa5cc16c073df2e8de8721331c69896a7accf Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Tue, 13 Aug 2013 11:43:02 +0300 Subject: [PATCH] Added some tests for country type, added ScalarCoercible type mixing --- sqlalchemy_utils/__init__.py | 4 +++ sqlalchemy_utils/types/__init__.py | 11 ++++++ sqlalchemy_utils/types/color.py | 7 ++-- sqlalchemy_utils/types/country.py | 31 +++++++++++++--- sqlalchemy_utils/types/password.py | 6 ++-- sqlalchemy_utils/types/scalar_coercible.py | 6 ++++ tests/types/test_country.py | 42 ++++++++++++++++++++++ 7 files changed, 95 insertions(+), 12 deletions(-) create mode 100644 sqlalchemy_utils/types/scalar_coercible.py create mode 100644 tests/types/test_country.py diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 1c0727a..8cf3d49 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -15,6 +15,8 @@ from .proxy_dict import ProxyDict, proxy_dict from .types import ( ArrowType, ColorType, + CountryType, + Country, EmailType, instrumented_list, InstrumentedList, @@ -53,6 +55,8 @@ __all__ = ( with_backrefs, ArrowType, ColorType, + CountryType, + Country, EmailType, ImproperlyConfigured, InstrumentedList, diff --git a/sqlalchemy_utils/types/__init__.py b/sqlalchemy_utils/types/__init__.py index 6d5f35b..3259692 100644 --- a/sqlalchemy_utils/types/__init__.py +++ b/sqlalchemy_utils/types/__init__.py @@ -2,6 +2,7 @@ from functools import wraps from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList from .arrow import ArrowType from .color import ColorType +from .country import CountryType, Country from .email import EmailType from .ip_address import IPAddressType from .number_range import ( @@ -21,6 +22,8 @@ from .uuid import UUIDType __all__ = ( ArrowType, ColorType, + CountryType, + Country, EmailType, IPAddressType, NumberRange, @@ -39,6 +42,14 @@ __all__ = ( ) +class ScalarCoercedType(object): + def _coerce(self, value): + raise NotImplemented + + def coercion_listener(self, target, value, oldvalue, initiator): + return self._coerce(value) + + class InstrumentedList(_InstrumentedList): """Enhanced version of SQLAlchemy InstrumentedList. Provides some additional functionality.""" diff --git a/sqlalchemy_utils/types/color.py b/sqlalchemy_utils/types/color.py index 689ce9a..2e2e450 100644 --- a/sqlalchemy_utils/types/color.py +++ b/sqlalchemy_utils/types/color.py @@ -1,6 +1,7 @@ import six from sqlalchemy import types from sqlalchemy_utils import ImproperlyConfigured +from .scalar_coercible import ScalarCoercible try: @@ -12,7 +13,7 @@ except ImportError: Color = None -class ColorType(types.TypeDecorator): +class ColorType(types.TypeDecorator, ScalarCoercible): """ Changes Color objects to a string representation on the way in and changes them back to Color objects on the way out. @@ -40,7 +41,7 @@ class ColorType(types.TypeDecorator): return Color(value) return value - def coercion_listener(self, target, value, oldvalue, initiator): + def _coerce(self, value): if value is not None and not isinstance(value, Color): - value = Color(value) + return Color(value) return value diff --git a/sqlalchemy_utils/types/country.py b/sqlalchemy_utils/types/country.py index 5cc442b..5994f19 100644 --- a/sqlalchemy_utils/types/country.py +++ b/sqlalchemy_utils/types/country.py @@ -1,20 +1,25 @@ from sqlalchemy import types from sqlalchemy_utils import ImproperlyConfigured +import six +from .scalar_coercible import ScalarCoercible class Country(object): get_locale = None - def __init__(self, code): + def __init__(self, code, get_locale=None): self.code = code + if get_locale is not None: + self.get_locale = get_locale + if self.get_locale is None: - ImproperlyConfigured( + raise ImproperlyConfigured( "Country class needs define get_locale." ) @property def name(self): - return self.get_locale().territories[self.code] + return self.get_locale.im_func().territories[self.code] def __eq__(self, other): if isinstance(other, Country): @@ -29,16 +34,32 @@ class Country(object): return self.name -class CountryType(types.TypeDecorator): +class CountryType(types.TypeDecorator, ScalarCoercible): + """ + Changes Country objects to a string representation on the way in and + changes them back to Country objects on the way out. + """ + impl = types.String(2) + get_locale = None + + def __init__(self, get_locale=None, *args, **kwargs): + if get_locale is not None: + self.get_locale = get_locale + types.TypeDecorator.__init__(self, *args, **kwargs) def process_bind_param(self, value, dialect): if isinstance(value, Country): return value.code - if isinstance(value, basestring): + if isinstance(value, six.string_types): return value def process_result_value(self, value, dialect): if value is not None: + return Country(value, get_locale=self.get_locale) + + def _coerce(self, value): + if value is not None and not isinstance(value, Country): return Country(value) + return value diff --git a/sqlalchemy_utils/types/password.py b/sqlalchemy_utils/types/password.py index 01bc9c9..ade3524 100644 --- a/sqlalchemy_utils/types/password.py +++ b/sqlalchemy_utils/types/password.py @@ -2,6 +2,7 @@ import six import weakref from sqlalchemy_utils import ImproperlyConfigured from sqlalchemy import types +from .scalar_coercible import ScalarCoercible try: import passlib @@ -33,7 +34,7 @@ class Password(object): return not (self == value) -class PasswordType(types.TypeDecorator): +class PasswordType(types.TypeDecorator, ScalarCoercible): """ Hashes passwords as they come into the database and allows verifying them using a pythonic interface :: @@ -107,6 +108,3 @@ class PasswordType(types.TypeDecorator): value.context = weakref.proxy(self.context) return value - - def coercion_listener(self, target, value, oldvalue, initiator): - return self._coerce(value) diff --git a/sqlalchemy_utils/types/scalar_coercible.py b/sqlalchemy_utils/types/scalar_coercible.py new file mode 100644 index 0000000..ec436cc --- /dev/null +++ b/sqlalchemy_utils/types/scalar_coercible.py @@ -0,0 +1,6 @@ +class ScalarCoercible(object): + def _coerce(self, value): + raise NotImplemented + + def coercion_listener(self, target, value, oldvalue, initiator): + return self._coerce(value) diff --git a/tests/types/test_country.py b/tests/types/test_country.py new file mode 100644 index 0000000..9bd5d13 --- /dev/null +++ b/tests/types/test_country.py @@ -0,0 +1,42 @@ +import sqlalchemy as sa +from sqlalchemy_utils import CountryType, Country +from tests import TestCase + + +def get_locale(): + class Locale(): + territories = {'fi': 'Finland'} + + return Locale() + + +Country.get_locale = get_locale + + +class TestCountryType(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + country = sa.Column(CountryType) + + def __repr__(self): + return 'User(%r)' % self.id + + self.User = User + + def test_color_parameter_processing(self): + user = self.User( + country=Country(u'fi') + ) + + self.session.add(user) + self.session.commit() + + user = self.session.query(self.User).first() + assert user.country.name == u'Finland' + + def test_scalar_attributes_get_coerced_to_objects(self): + user = self.User(country='fi') + + assert isinstance(user.country, Country)