Make Country class hashable

This commit is contained in:
Konsta Vesterinen
2015-04-08 19:10:53 +03:00
parent 7f253ab9d9
commit 1eb1f83847
8 changed files with 59 additions and 49 deletions

View File

@@ -4,6 +4,12 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.29.10 (2015-04-xx)
^^^^^^^^^^^^^^^^^^^^
- Added __hash__ method to Country class
0.29.9 (2015-04-07)
^^^^^^^^^^^^^^^^^^^

View File

@@ -54,7 +54,7 @@ from .listeners import ( # noqa
)
from .models import Timestamp # noqa
from .observer import observes # noqa
from .primitives import Currency, WeekDay, WeekDays # noqa
from .primitives import Country, Currency, WeekDay, WeekDays # noqa
from .proxy_dict import proxy_dict, ProxyDict # noqa
from .query_chain import QueryChain # noqa
from .types import ( # noqa
@@ -62,7 +62,6 @@ from .types import ( # noqa
Choice,
ChoiceType,
ColorType,
Country,
CountryType,
CurrencyType,
DateRangeType,

View File

@@ -1,3 +1,4 @@
from .country import Country # noqa
from .currency import Currency # noqa
from .weekday import WeekDay # noqa
from .weekdays import WeekDays # noqa

View File

@@ -0,0 +1,31 @@
import six
from sqlalchemy_utils import i18n
class Country(object):
def __init__(self, code_or_country):
if isinstance(code_or_country, Country):
self.code = code_or_country.code
else:
self.code = code_or_country
@property
def name(self):
return i18n.get_locale().territories[self.code]
def __eq__(self, other):
if isinstance(other, Country):
return self.code == other.code
elif isinstance(other, six.string_types):
return self.code == other
else:
return NotImplemented
def __ne__(self, other):
return not (self == other)
def __repr__(self):
return '%s(%r)' % (self.__class__.__name__, self.code)
def __unicode__(self):
return self.name

View File

@@ -5,7 +5,7 @@ from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
from .arrow import ArrowType
from .choice import Choice, ChoiceType
from .color import ColorType
from .country import Country, CountryType
from .country import CountryType
from .currency import CurrencyType
from .email import EmailType
from .encrypted import EncryptedType
@@ -32,7 +32,6 @@ __all__ = (
Choice,
ChoiceType,
ColorType,
Country,
CountryType,
CurrencyType,
DateRangeType,

View File

@@ -1,40 +1,10 @@
import six
from sqlalchemy import types
from sqlalchemy_utils import i18n
from sqlalchemy_utils.primitives import Country
from .scalar_coercible import ScalarCoercible
class Country(object):
def __init__(self, code_or_country):
if isinstance(code_or_country, Country):
self.code = code_or_country.code
else:
self.code = code_or_country
@property
def name(self):
return i18n.get_locale().territories[self.code]
def __eq__(self, other):
if isinstance(other, Country):
return self.code == other.code
elif isinstance(other, six.string_types):
return self.code == other
else:
return NotImplemented
def __ne__(self, other):
return not (self == other)
def __repr__(self):
return '%s(%r)' % (self.__class__.__name__, self.code)
def __unicode__(self):
return self.name
class CountryType(types.TypeDecorator, ScalarCoercible):
"""
Changes Country objects to a string representation on the way in and

View File

@@ -0,0 +1,18 @@
from sqlalchemy_utils import Country
class TestCountry(object):
def test_init(self):
assert Country(u'fi') == Country(Country(u'fi'))
def test_equality_operator(self):
assert Country(u'fi') == u'fi'
assert u'fi' == Country(u'fi')
assert Country(u'fi') == Country(u'fi')
def test_non_equality_operator(self):
assert Country(u'fi') != u'sv'
assert not (Country(u'fi') != u'fi')
def test_hash(self):
return hash(Country('fi')) == hash('fi')

View File

@@ -4,20 +4,6 @@ from sqlalchemy_utils import Country, CountryType
from tests import TestCase
class TestCountry(object):
def test_init(self):
assert Country(u'fi') == Country(Country(u'fi'))
def test_equality_operator(self):
assert Country(u'fi') == u'fi'
assert u'fi' == Country(u'fi')
assert Country(u'fi') == Country(u'fi')
def test_non_equality_operator(self):
assert Country(u'fi') != u'sv'
assert not (Country(u'fi') != u'fi')
class TestCountryType(TestCase):
def create_models(self):
class User(self.Base):