diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 5101395..6a7e4c6 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -54,6 +54,7 @@ from .listeners import ( # noqa ) from .models import Timestamp # noqa from .observer import observes # noqa +from .primitives import Currency, WeekDay, WeekDays # noqa from .proxy_dict import proxy_dict, ProxyDict # noqa from .query_chain import QueryChain # noqa from .types import ( # noqa diff --git a/sqlalchemy_utils/primitives/__init__.py b/sqlalchemy_utils/primitives/__init__.py index 6831294..f09d888 100644 --- a/sqlalchemy_utils/primitives/__init__.py +++ b/sqlalchemy_utils/primitives/__init__.py @@ -1,7 +1,3 @@ -from .weekday import WeekDay -from .weekdays import WeekDays - -__all__ = ( - WeekDay, - WeekDays -) +from .currency import Currency # noqa +from .weekday import WeekDay # noqa +from .weekdays import WeekDays # noqa diff --git a/sqlalchemy_utils/primitives/currency.py b/sqlalchemy_utils/primitives/currency.py new file mode 100644 index 0000000..46e3848 --- /dev/null +++ b/sqlalchemy_utils/primitives/currency.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +import six +from babel.numbers import get_currency_symbol + +from sqlalchemy_utils import i18n +from sqlalchemy_utils.utils import str_coercible + + +@str_coercible +class Currency(object): + def __init__(self, code): + if isinstance(code, Currency): + self.code = code + elif isinstance(code, six.string_types): + self.validate(code) + self.code = code + else: + raise TypeError( + 'First argument given to Currency constructor should be ' + 'either an instance of Currency or valid three letter ' + 'currency code.' + ) + + @classmethod + def validate(self, code): + try: + i18n.get_locale().currencies[code] + except KeyError: + raise ValueError("{0}' is not valid currency code.") + + @property + def symbol(self): + return get_currency_symbol(self.code, i18n.get_locale()) + + @property + def name(self): + return i18n.get_locale().currencies[self.code] + + def __eq__(self, other): + if isinstance(other, Currency): + return self.code == other.code + elif isinstance(other, six.string_types): + return self.code == other + else: + return NotImplemented + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return hash(self.code) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, self.code) + + def __unicode__(self): + return self.name diff --git a/tests/primitives/test_currency.py b/tests/primitives/test_currency.py new file mode 100644 index 0000000..c8e6d4d --- /dev/null +++ b/tests/primitives/test_currency.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +import six +from babel import Locale +from pytest import mark, raises + +from sqlalchemy_utils import Currency, i18n + + +class TestCurrency(object): + def setup_method(self, method): + i18n.get_locale = lambda: Locale('en') + + def test_init(self): + assert Currency('USD') == Currency(Currency('USD')) + + def test_hashability(self): + assert len(set([Currency('USD'), Currency('USD')])) == 1 + + def test_invalid_currency_code(self): + with raises(ValueError): + Currency('Unknown code') + + def test_invalid_currency_code_type(self): + with raises(TypeError): + Currency(None) + + @mark.parametrize( + ('code', 'name'), + ( + ('USD', 'US Dollar'), + ('EUR', 'Euro') + ) + ) + def test_name_property(self, code, name): + assert Currency(code).name == name + + @mark.parametrize( + ('code', 'symbol'), + ( + ('USD', u'$'), + ('EUR', u'€') + ) + ) + def test_symbol_property(self, code, symbol): + assert Currency(code).symbol == symbol + + def test_equality_operator(self): + assert Currency('USD') == 'USD' + assert 'USD' == Currency('USD') + assert Currency('USD') == Currency('USD') + + def test_non_equality_operator(self): + assert Currency('USD') != 'EUR' + assert not (Currency('USD') != 'USD') + + def test_unicode(self): + currency = Currency('USD') + assert six.text_type(currency) == u'US Dollar' + + def test_str(self): + currency = Currency('USD') + assert str(currency) == 'US Dollar' + + def test_representation(self): + currency = Currency('USD') + assert repr(currency) == "Currency('USD')"