Added some tests for country type, added ScalarCoercible type mixing
This commit is contained in:
@@ -15,6 +15,8 @@ from .proxy_dict import ProxyDict, proxy_dict
|
||||
from .types import (
|
||||
ArrowType,
|
||||
ColorType,
|
||||
CountryType,
|
||||
Country,
|
||||
EmailType,
|
||||
instrumented_list,
|
||||
InstrumentedList,
|
||||
@@ -53,6 +55,8 @@ __all__ = (
|
||||
with_backrefs,
|
||||
ArrowType,
|
||||
ColorType,
|
||||
CountryType,
|
||||
Country,
|
||||
EmailType,
|
||||
ImproperlyConfigured,
|
||||
InstrumentedList,
|
||||
|
@@ -2,6 +2,7 @@ from functools import wraps
|
||||
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
|
||||
from .arrow import ArrowType
|
||||
from .color import ColorType
|
||||
from .country import CountryType, Country
|
||||
from .email import EmailType
|
||||
from .ip_address import IPAddressType
|
||||
from .number_range import (
|
||||
@@ -21,6 +22,8 @@ from .uuid import UUIDType
|
||||
__all__ = (
|
||||
ArrowType,
|
||||
ColorType,
|
||||
CountryType,
|
||||
Country,
|
||||
EmailType,
|
||||
IPAddressType,
|
||||
NumberRange,
|
||||
@@ -39,6 +42,14 @@ __all__ = (
|
||||
)
|
||||
|
||||
|
||||
class ScalarCoercedType(object):
|
||||
def _coerce(self, value):
|
||||
raise NotImplemented
|
||||
|
||||
def coercion_listener(self, target, value, oldvalue, initiator):
|
||||
return self._coerce(value)
|
||||
|
||||
|
||||
class InstrumentedList(_InstrumentedList):
|
||||
"""Enhanced version of SQLAlchemy InstrumentedList. Provides some
|
||||
additional functionality."""
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import six
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy_utils import ImproperlyConfigured
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
|
||||
try:
|
||||
@@ -12,7 +13,7 @@ except ImportError:
|
||||
Color = None
|
||||
|
||||
|
||||
class ColorType(types.TypeDecorator):
|
||||
class ColorType(types.TypeDecorator, ScalarCoercible):
|
||||
"""
|
||||
Changes Color objects to a string representation on the way in and
|
||||
changes them back to Color objects on the way out.
|
||||
@@ -40,7 +41,7 @@ class ColorType(types.TypeDecorator):
|
||||
return Color(value)
|
||||
return value
|
||||
|
||||
def coercion_listener(self, target, value, oldvalue, initiator):
|
||||
def _coerce(self, value):
|
||||
if value is not None and not isinstance(value, Color):
|
||||
value = Color(value)
|
||||
return Color(value)
|
||||
return value
|
||||
|
@@ -1,20 +1,25 @@
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy_utils import ImproperlyConfigured
|
||||
import six
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
|
||||
class Country(object):
|
||||
get_locale = None
|
||||
|
||||
def __init__(self, code):
|
||||
def __init__(self, code, get_locale=None):
|
||||
self.code = code
|
||||
if get_locale is not None:
|
||||
self.get_locale = get_locale
|
||||
|
||||
if self.get_locale is None:
|
||||
ImproperlyConfigured(
|
||||
raise ImproperlyConfigured(
|
||||
"Country class needs define get_locale."
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.get_locale().territories[self.code]
|
||||
return self.get_locale.im_func().territories[self.code]
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, Country):
|
||||
@@ -29,16 +34,32 @@ class Country(object):
|
||||
return self.name
|
||||
|
||||
|
||||
class CountryType(types.TypeDecorator):
|
||||
class CountryType(types.TypeDecorator, ScalarCoercible):
|
||||
"""
|
||||
Changes Country objects to a string representation on the way in and
|
||||
changes them back to Country objects on the way out.
|
||||
"""
|
||||
|
||||
impl = types.String(2)
|
||||
get_locale = None
|
||||
|
||||
def __init__(self, get_locale=None, *args, **kwargs):
|
||||
if get_locale is not None:
|
||||
self.get_locale = get_locale
|
||||
types.TypeDecorator.__init__(self, *args, **kwargs)
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if isinstance(value, Country):
|
||||
return value.code
|
||||
|
||||
if isinstance(value, basestring):
|
||||
if isinstance(value, six.string_types):
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is not None:
|
||||
return Country(value, get_locale=self.get_locale)
|
||||
|
||||
def _coerce(self, value):
|
||||
if value is not None and not isinstance(value, Country):
|
||||
return Country(value)
|
||||
return value
|
||||
|
@@ -2,6 +2,7 @@ import six
|
||||
import weakref
|
||||
from sqlalchemy_utils import ImproperlyConfigured
|
||||
from sqlalchemy import types
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
try:
|
||||
import passlib
|
||||
@@ -33,7 +34,7 @@ class Password(object):
|
||||
return not (self == value)
|
||||
|
||||
|
||||
class PasswordType(types.TypeDecorator):
|
||||
class PasswordType(types.TypeDecorator, ScalarCoercible):
|
||||
"""
|
||||
Hashes passwords as they come into the database and allows verifying
|
||||
them using a pythonic interface ::
|
||||
@@ -107,6 +108,3 @@ class PasswordType(types.TypeDecorator):
|
||||
value.context = weakref.proxy(self.context)
|
||||
|
||||
return value
|
||||
|
||||
def coercion_listener(self, target, value, oldvalue, initiator):
|
||||
return self._coerce(value)
|
||||
|
6
sqlalchemy_utils/types/scalar_coercible.py
Normal file
6
sqlalchemy_utils/types/scalar_coercible.py
Normal file
@@ -0,0 +1,6 @@
|
||||
class ScalarCoercible(object):
|
||||
def _coerce(self, value):
|
||||
raise NotImplemented
|
||||
|
||||
def coercion_listener(self, target, value, oldvalue, initiator):
|
||||
return self._coerce(value)
|
42
tests/types/test_country.py
Normal file
42
tests/types/test_country.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import CountryType, Country
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
def get_locale():
|
||||
class Locale():
|
||||
territories = {'fi': 'Finland'}
|
||||
|
||||
return Locale()
|
||||
|
||||
|
||||
Country.get_locale = get_locale
|
||||
|
||||
|
||||
class TestCountryType(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
country = sa.Column(CountryType)
|
||||
|
||||
def __repr__(self):
|
||||
return 'User(%r)' % self.id
|
||||
|
||||
self.User = User
|
||||
|
||||
def test_color_parameter_processing(self):
|
||||
user = self.User(
|
||||
country=Country(u'fi')
|
||||
)
|
||||
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
|
||||
user = self.session.query(self.User).first()
|
||||
assert user.country.name == u'Finland'
|
||||
|
||||
def test_scalar_attributes_get_coerced_to_objects(self):
|
||||
user = self.User(country='fi')
|
||||
|
||||
assert isinstance(user.country, Country)
|
Reference in New Issue
Block a user