Added some tests for country type, added ScalarCoercible type mixing

This commit is contained in:
Konsta Vesterinen
2013-08-13 11:43:02 +03:00
parent 39feea5a6f
commit 1a1aa5cc16
7 changed files with 95 additions and 12 deletions

View File

@@ -15,6 +15,8 @@ from .proxy_dict import ProxyDict, proxy_dict
from .types import ( from .types import (
ArrowType, ArrowType,
ColorType, ColorType,
CountryType,
Country,
EmailType, EmailType,
instrumented_list, instrumented_list,
InstrumentedList, InstrumentedList,
@@ -53,6 +55,8 @@ __all__ = (
with_backrefs, with_backrefs,
ArrowType, ArrowType,
ColorType, ColorType,
CountryType,
Country,
EmailType, EmailType,
ImproperlyConfigured, ImproperlyConfigured,
InstrumentedList, InstrumentedList,

View File

@@ -2,6 +2,7 @@ from functools import wraps
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
from .arrow import ArrowType from .arrow import ArrowType
from .color import ColorType from .color import ColorType
from .country import CountryType, Country
from .email import EmailType from .email import EmailType
from .ip_address import IPAddressType from .ip_address import IPAddressType
from .number_range import ( from .number_range import (
@@ -21,6 +22,8 @@ from .uuid import UUIDType
__all__ = ( __all__ = (
ArrowType, ArrowType,
ColorType, ColorType,
CountryType,
Country,
EmailType, EmailType,
IPAddressType, IPAddressType,
NumberRange, NumberRange,
@@ -39,6 +42,14 @@ __all__ = (
) )
class ScalarCoercedType(object):
def _coerce(self, value):
raise NotImplemented
def coercion_listener(self, target, value, oldvalue, initiator):
return self._coerce(value)
class InstrumentedList(_InstrumentedList): class InstrumentedList(_InstrumentedList):
"""Enhanced version of SQLAlchemy InstrumentedList. Provides some """Enhanced version of SQLAlchemy InstrumentedList. Provides some
additional functionality.""" additional functionality."""

View File

@@ -1,6 +1,7 @@
import six import six
from sqlalchemy import types from sqlalchemy import types
from sqlalchemy_utils import ImproperlyConfigured from sqlalchemy_utils import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible
try: try:
@@ -12,7 +13,7 @@ except ImportError:
Color = None Color = None
class ColorType(types.TypeDecorator): class ColorType(types.TypeDecorator, ScalarCoercible):
""" """
Changes Color objects to a string representation on the way in and Changes Color objects to a string representation on the way in and
changes them back to Color objects on the way out. changes them back to Color objects on the way out.
@@ -40,7 +41,7 @@ class ColorType(types.TypeDecorator):
return Color(value) return Color(value)
return value return value
def coercion_listener(self, target, value, oldvalue, initiator): def _coerce(self, value):
if value is not None and not isinstance(value, Color): if value is not None and not isinstance(value, Color):
value = Color(value) return Color(value)
return value return value

View File

@@ -1,20 +1,25 @@
from sqlalchemy import types from sqlalchemy import types
from sqlalchemy_utils import ImproperlyConfigured from sqlalchemy_utils import ImproperlyConfigured
import six
from .scalar_coercible import ScalarCoercible
class Country(object): class Country(object):
get_locale = None get_locale = None
def __init__(self, code): def __init__(self, code, get_locale=None):
self.code = code self.code = code
if get_locale is not None:
self.get_locale = get_locale
if self.get_locale is None: if self.get_locale is None:
ImproperlyConfigured( raise ImproperlyConfigured(
"Country class needs define get_locale." "Country class needs define get_locale."
) )
@property @property
def name(self): def name(self):
return self.get_locale().territories[self.code] return self.get_locale.im_func().territories[self.code]
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, Country): if isinstance(other, Country):
@@ -29,16 +34,32 @@ class Country(object):
return self.name return self.name
class CountryType(types.TypeDecorator): class CountryType(types.TypeDecorator, ScalarCoercible):
"""
Changes Country objects to a string representation on the way in and
changes them back to Country objects on the way out.
"""
impl = types.String(2) impl = types.String(2)
get_locale = None
def __init__(self, get_locale=None, *args, **kwargs):
if get_locale is not None:
self.get_locale = get_locale
types.TypeDecorator.__init__(self, *args, **kwargs)
def process_bind_param(self, value, dialect): def process_bind_param(self, value, dialect):
if isinstance(value, Country): if isinstance(value, Country):
return value.code return value.code
if isinstance(value, basestring): if isinstance(value, six.string_types):
return value return value
def process_result_value(self, value, dialect): def process_result_value(self, value, dialect):
if value is not None: if value is not None:
return Country(value, get_locale=self.get_locale)
def _coerce(self, value):
if value is not None and not isinstance(value, Country):
return Country(value) return Country(value)
return value

View File

@@ -2,6 +2,7 @@ import six
import weakref import weakref
from sqlalchemy_utils import ImproperlyConfigured from sqlalchemy_utils import ImproperlyConfigured
from sqlalchemy import types from sqlalchemy import types
from .scalar_coercible import ScalarCoercible
try: try:
import passlib import passlib
@@ -33,7 +34,7 @@ class Password(object):
return not (self == value) return not (self == value)
class PasswordType(types.TypeDecorator): class PasswordType(types.TypeDecorator, ScalarCoercible):
""" """
Hashes passwords as they come into the database and allows verifying Hashes passwords as they come into the database and allows verifying
them using a pythonic interface :: them using a pythonic interface ::
@@ -107,6 +108,3 @@ class PasswordType(types.TypeDecorator):
value.context = weakref.proxy(self.context) value.context = weakref.proxy(self.context)
return value return value
def coercion_listener(self, target, value, oldvalue, initiator):
return self._coerce(value)

View File

@@ -0,0 +1,6 @@
class ScalarCoercible(object):
def _coerce(self, value):
raise NotImplemented
def coercion_listener(self, target, value, oldvalue, initiator):
return self._coerce(value)

View File

@@ -0,0 +1,42 @@
import sqlalchemy as sa
from sqlalchemy_utils import CountryType, Country
from tests import TestCase
def get_locale():
class Locale():
territories = {'fi': 'Finland'}
return Locale()
Country.get_locale = get_locale
class TestCountryType(TestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
country = sa.Column(CountryType)
def __repr__(self):
return 'User(%r)' % self.id
self.User = User
def test_color_parameter_processing(self):
user = self.User(
country=Country(u'fi')
)
self.session.add(user)
self.session.commit()
user = self.session.query(self.User).first()
assert user.country.name == u'Finland'
def test_scalar_attributes_get_coerced_to_objects(self):
user = self.User(country='fi')
assert isinstance(user.country, Country)