diff --git a/sqlalchemy_utils/types/country.py b/sqlalchemy_utils/types/country.py index 434cb49..b4c1efd 100644 --- a/sqlalchemy_utils/types/country.py +++ b/sqlalchemy_utils/types/country.py @@ -5,8 +5,11 @@ from sqlalchemy_utils import i18n class Country(object): - def __init__(self, code): - self.code = code + def __init__(self, code_or_country): + if isinstance(code_or_country, Country): + self.code = code_or_country.code + else: + self.code = code_or_country @property def name(self): @@ -15,6 +18,8 @@ class Country(object): def __eq__(self, other): if isinstance(other, Country): return self.code == other.code + elif isinstance(other, six.string_types): + return self.code == other else: return NotImplemented diff --git a/tests/types/test_country.py b/tests/types/test_country.py index 58802a6..45ed25b 100644 --- a/tests/types/test_country.py +++ b/tests/types/test_country.py @@ -13,6 +13,16 @@ def get_locale(): i18n.get_locale = get_locale +class TestCountry(object): + def test_init(self): + assert Country(u'fi') == Country(Country(u'fi')) + + def test_equality_operator(self): + assert Country(u'fi') == u'fi' + assert u'fi' == Country(u'fi') + assert Country(u'fi') == Country(u'fi') + + class TestCountryType(TestCase): def create_models(self): class User(self.Base):