From 1eb1f83847928aadf650b2cb27c21108aa3806c9 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Wed, 8 Apr 2015 19:10:53 +0300 Subject: [PATCH] Make Country class hashable --- CHANGES.rst | 6 +++++ sqlalchemy_utils/__init__.py | 3 +-- sqlalchemy_utils/primitives/__init__.py | 1 + sqlalchemy_utils/primitives/country.py | 31 ++++++++++++++++++++++++ sqlalchemy_utils/types/__init__.py | 3 +-- sqlalchemy_utils/types/country.py | 32 +------------------------ tests/primitives/test_country.py | 18 ++++++++++++++ tests/types/test_country.py | 14 ----------- 8 files changed, 59 insertions(+), 49 deletions(-) create mode 100644 sqlalchemy_utils/primitives/country.py create mode 100644 tests/primitives/test_country.py diff --git a/CHANGES.rst b/CHANGES.rst index c1912ec..cfe9c6b 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.29.10 (2015-04-xx) +^^^^^^^^^^^^^^^^^^^^ + +- Added __hash__ method to Country class + + 0.29.9 (2015-04-07) ^^^^^^^^^^^^^^^^^^^ diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index c8344e9..5717885 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -54,7 +54,7 @@ from .listeners import ( # noqa ) from .models import Timestamp # noqa from .observer import observes # noqa -from .primitives import Currency, WeekDay, WeekDays # noqa +from .primitives import Country, Currency, WeekDay, WeekDays # noqa from .proxy_dict import proxy_dict, ProxyDict # noqa from .query_chain import QueryChain # noqa from .types import ( # noqa @@ -62,7 +62,6 @@ from .types import ( # noqa Choice, ChoiceType, ColorType, - Country, CountryType, CurrencyType, DateRangeType, diff --git a/sqlalchemy_utils/primitives/__init__.py b/sqlalchemy_utils/primitives/__init__.py index f09d888..71a5829 100644 --- a/sqlalchemy_utils/primitives/__init__.py +++ b/sqlalchemy_utils/primitives/__init__.py @@ -1,3 +1,4 @@ +from .country import Country # noqa from .currency import Currency # noqa from .weekday import WeekDay # noqa from .weekdays import WeekDays # noqa diff --git a/sqlalchemy_utils/primitives/country.py b/sqlalchemy_utils/primitives/country.py new file mode 100644 index 0000000..986452b --- /dev/null +++ b/sqlalchemy_utils/primitives/country.py @@ -0,0 +1,31 @@ +import six +from sqlalchemy_utils import i18n + + +class Country(object): + def __init__(self, code_or_country): + if isinstance(code_or_country, Country): + self.code = code_or_country.code + else: + self.code = code_or_country + + @property + def name(self): + return i18n.get_locale().territories[self.code] + + def __eq__(self, other): + if isinstance(other, Country): + return self.code == other.code + elif isinstance(other, six.string_types): + return self.code == other + else: + return NotImplemented + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, self.code) + + def __unicode__(self): + return self.name diff --git a/sqlalchemy_utils/types/__init__.py b/sqlalchemy_utils/types/__init__.py index 8a44b55..4aa4bcc 100644 --- a/sqlalchemy_utils/types/__init__.py +++ b/sqlalchemy_utils/types/__init__.py @@ -5,7 +5,7 @@ from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList from .arrow import ArrowType from .choice import Choice, ChoiceType from .color import ColorType -from .country import Country, CountryType +from .country import CountryType from .currency import CurrencyType from .email import EmailType from .encrypted import EncryptedType @@ -32,7 +32,6 @@ __all__ = ( Choice, ChoiceType, ColorType, - Country, CountryType, CurrencyType, DateRangeType, diff --git a/sqlalchemy_utils/types/country.py b/sqlalchemy_utils/types/country.py index 9baeaaa..7f591d3 100644 --- a/sqlalchemy_utils/types/country.py +++ b/sqlalchemy_utils/types/country.py @@ -1,40 +1,10 @@ import six from sqlalchemy import types -from sqlalchemy_utils import i18n - +from sqlalchemy_utils.primitives import Country from .scalar_coercible import ScalarCoercible -class Country(object): - def __init__(self, code_or_country): - if isinstance(code_or_country, Country): - self.code = code_or_country.code - else: - self.code = code_or_country - - @property - def name(self): - return i18n.get_locale().territories[self.code] - - def __eq__(self, other): - if isinstance(other, Country): - return self.code == other.code - elif isinstance(other, six.string_types): - return self.code == other - else: - return NotImplemented - - def __ne__(self, other): - return not (self == other) - - def __repr__(self): - return '%s(%r)' % (self.__class__.__name__, self.code) - - def __unicode__(self): - return self.name - - class CountryType(types.TypeDecorator, ScalarCoercible): """ Changes Country objects to a string representation on the way in and diff --git a/tests/primitives/test_country.py b/tests/primitives/test_country.py new file mode 100644 index 0000000..c2662a9 --- /dev/null +++ b/tests/primitives/test_country.py @@ -0,0 +1,18 @@ +from sqlalchemy_utils import Country + + +class TestCountry(object): + def test_init(self): + assert Country(u'fi') == Country(Country(u'fi')) + + def test_equality_operator(self): + assert Country(u'fi') == u'fi' + assert u'fi' == Country(u'fi') + assert Country(u'fi') == Country(u'fi') + + def test_non_equality_operator(self): + assert Country(u'fi') != u'sv' + assert not (Country(u'fi') != u'fi') + + def test_hash(self): + return hash(Country('fi')) == hash('fi') diff --git a/tests/types/test_country.py b/tests/types/test_country.py index 2df8ff3..4003f2c 100644 --- a/tests/types/test_country.py +++ b/tests/types/test_country.py @@ -4,20 +4,6 @@ from sqlalchemy_utils import Country, CountryType from tests import TestCase -class TestCountry(object): - def test_init(self): - assert Country(u'fi') == Country(Country(u'fi')) - - def test_equality_operator(self): - assert Country(u'fi') == u'fi' - assert u'fi' == Country(u'fi') - assert Country(u'fi') == Country(u'fi') - - def test_non_equality_operator(self): - assert Country(u'fi') != u'sv' - assert not (Country(u'fi') != u'fi') - - class TestCountryType(TestCase): def create_models(self): class User(self.Base):