From 1139141655f1c05e0ba604e31220b82e0a487404 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Wed, 8 Apr 2015 21:53:07 +0300 Subject: [PATCH] Country enhancements * Make Country validate itself during object initialization * Make Country string coercible --- CHANGES.rst | 2 + sqlalchemy_utils/primitives/country.py | 21 +++++++- tests/primitives/test_country.py | 66 ++++++++++++++++++++++---- tests/types/test_country.py | 4 +- 4 files changed, 82 insertions(+), 11 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index cfe9c6b..ad732f2 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -8,6 +8,8 @@ Here you can see the full list of changes between each SQLAlchemy-Utils release. ^^^^^^^^^^^^^^^^^^^^ - Added __hash__ method to Country class +- Made Country validate itself during object initialization +- Made Country string coercible 0.29.9 (2015-04-07) diff --git a/sqlalchemy_utils/primitives/country.py b/sqlalchemy_utils/primitives/country.py index 986452b..c768169 100644 --- a/sqlalchemy_utils/primitives/country.py +++ b/sqlalchemy_utils/primitives/country.py @@ -1,18 +1,37 @@ import six from sqlalchemy_utils import i18n +from sqlalchemy_utils.utils import str_coercible +@str_coercible class Country(object): def __init__(self, code_or_country): if isinstance(code_or_country, Country): self.code = code_or_country.code - else: + elif isinstance(code_or_country, six.string_types): + self.validate(code_or_country) self.code = code_or_country + else: + raise TypeError( + "Country() argument must be a string or a country, not '{0}'" + .format( + type(code_or_country).__name__ + ) + ) @property def name(self): return i18n.get_locale().territories[self.code] + @classmethod + def validate(self, code): + try: + i18n.get_locale().territories[code] + except KeyError: + raise ValueError( + 'Could not convert string to country code: {0}'.format(code) + ) + def __eq__(self, other): if isinstance(other, Country): return self.code == other.code diff --git a/tests/primitives/test_country.py b/tests/primitives/test_country.py index c2662a9..95463bc 100644 --- a/tests/primitives/test_country.py +++ b/tests/primitives/test_country.py @@ -1,18 +1,68 @@ -from sqlalchemy_utils import Country +import six +from pytest import mark, raises + +from sqlalchemy_utils import Country, i18n +from sqlalchemy_utils.primitives.currency import babel # noqa +@mark.skipif('babel is None') class TestCountry(object): + def setup_method(self, method): + i18n.get_locale = lambda: babel.Locale('en') + def test_init(self): - assert Country(u'fi') == Country(Country(u'fi')) + assert Country(u'FI') == Country(Country(u'FI')) + + def test_constructor_with_wrong_type(self): + with raises(TypeError) as e: + Country(None) + assert str(e.value) == ( + "Country() argument must be a string or a country, not 'NoneType'" + ) + + def test_constructor_with_invalid_code(self): + with raises(ValueError) as e: + Country('SomeUnknownCode') + assert str(e.value) == ( + 'Could not convert string to country code: SomeUnknownCode' + ) + + @mark.parametrize( + 'code', + ( + 'FI', + 'US', + ) + ) + def test_validate_with_valid_codes(self, code): + Country.validate(code) + + def test_validate_with_invalid_code(self): + with raises(ValueError) as e: + Country.validate('SomeUnknownCode') + assert str(e.value) == ( + 'Could not convert string to country code: SomeUnknownCode' + ) def test_equality_operator(self): - assert Country(u'fi') == u'fi' - assert u'fi' == Country(u'fi') - assert Country(u'fi') == Country(u'fi') + assert Country(u'FI') == u'FI' + assert u'FI' == Country(u'FI') + assert Country(u'FI') == Country(u'FI') def test_non_equality_operator(self): - assert Country(u'fi') != u'sv' - assert not (Country(u'fi') != u'fi') + assert Country(u'FI') != u'sv' + assert not (Country(u'FI') != u'FI') def test_hash(self): - return hash(Country('fi')) == hash('fi') + return hash(Country('FI')) == hash('FI') + + def test_repr(self): + return repr(Country('FI')) == "Country('FI')" + + def test_unicode(self): + country = Country('FI') + assert six.text_type(country) == u'Finland' + + def test_str(self): + country = Country('FI') + assert str(country) == 'Finland' diff --git a/tests/types/test_country.py b/tests/types/test_country.py index 4003f2c..a0387c6 100644 --- a/tests/types/test_country.py +++ b/tests/types/test_country.py @@ -18,7 +18,7 @@ class TestCountryType(TestCase): def test_parameter_processing(self): user = self.User( - country=Country(u'fi') + country=Country(u'FI') ) self.session.add(user) @@ -28,6 +28,6 @@ class TestCountryType(TestCase): assert user.country.name == u'Finland' def test_scalar_attributes_get_coerced_to_objects(self): - user = self.User(country='fi') + user = self.User(country='FI') assert isinstance(user.country, Country)