diff --git a/docs/data_types.rst b/docs/data_types.rst index 9501570..3b6546f 100644 --- a/docs/data_types.rst +++ b/docs/data_types.rst @@ -38,6 +38,19 @@ CountryType .. autoclass:: CountryType + +CurrencyType +^^^^^^^^^^^^ + +.. module:: sqlalchemy_utils.types.currency + +.. autoclass:: CurrencyType + +.. module:: sqlalchemy_utils.primitives.currency + +.. autoclass:: Currency + + EncryptedType ^^^^^^^^^^^^^ diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 5101395..7d59edc 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -54,6 +54,7 @@ from .listeners import ( # noqa ) from .models import Timestamp # noqa from .observer import observes # noqa +from .primitives import Currency, WeekDay, WeekDays # noqa from .proxy_dict import proxy_dict, ProxyDict # noqa from .query_chain import QueryChain # noqa from .types import ( # noqa @@ -63,6 +64,7 @@ from .types import ( # noqa ColorType, Country, CountryType, + CurrencyType, DateRangeType, DateTimeRangeType, EmailType, diff --git a/sqlalchemy_utils/primitives/__init__.py b/sqlalchemy_utils/primitives/__init__.py index 6831294..f09d888 100644 --- a/sqlalchemy_utils/primitives/__init__.py +++ b/sqlalchemy_utils/primitives/__init__.py @@ -1,7 +1,3 @@ -from .weekday import WeekDay -from .weekdays import WeekDays - -__all__ = ( - WeekDay, - WeekDays -) +from .currency import Currency # noqa +from .weekday import WeekDay # noqa +from .weekdays import WeekDays # noqa diff --git a/sqlalchemy_utils/primitives/currency.py b/sqlalchemy_utils/primitives/currency.py new file mode 100644 index 0000000..d9a04c8 --- /dev/null +++ b/sqlalchemy_utils/primitives/currency.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +babel = None +try: + import babel +except ImportError: + pass +import six + +from sqlalchemy_utils import i18n, ImproperlyConfigured +from sqlalchemy_utils.utils import str_coercible + + +@str_coercible +class Currency(object): + """ + Currency class wraps a 3-letter currency code. It provides various + convenience properties and methods. + + :: + + from babel import Locale + from sqlalchemy_utils import Currency, i18n + + + # First lets add a locale getter for testing purposes + i18n.get_locale = lambda: Locale('en') + + + Currency('USD').name # US Dollar + Currency('USD').symbol # $ + + Currency(Currency('USD')).code # 'USD' + + Currency always validates the given code. + + :: + + Currency(None) # raises TypeError + + Currency('UnknownCode') # raises ValueError + + + Currency supports equality operators. + + :: + + Currency('USD') == Currency('USD') + Currency('USD') != Currency('EUR') + + + Currencies are hashable. + + + :: + + len(set([Currency('USD'), Currency('USD')])) # 1 + + + """ + def __init__(self, code): + if babel is None: + raise ImproperlyConfigured( + "'babel' package is required in order to use Currency class." + ) + if isinstance(code, Currency): + self.code = code + elif isinstance(code, six.string_types): + self.validate(code) + self.code = code + else: + raise TypeError( + 'First argument given to Currency constructor should be ' + 'either an instance of Currency or valid three letter ' + 'currency code.' + ) + + @classmethod + def validate(self, code): + try: + i18n.get_locale().currencies[code] + except KeyError: + raise ValueError("{0}' is not valid currency code.") + + @property + def symbol(self): + return babel.numbers.get_currency_symbol(self.code, i18n.get_locale()) + + @property + def name(self): + return i18n.get_locale().currencies[self.code] + + def __eq__(self, other): + if isinstance(other, Currency): + return self.code == other.code + elif isinstance(other, six.string_types): + return self.code == other + else: + return NotImplemented + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return hash(self.code) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, self.code) + + def __unicode__(self): + return self.code diff --git a/sqlalchemy_utils/types/__init__.py b/sqlalchemy_utils/types/__init__.py index 4895582..8a44b55 100644 --- a/sqlalchemy_utils/types/__init__.py +++ b/sqlalchemy_utils/types/__init__.py @@ -6,6 +6,7 @@ from .arrow import ArrowType from .choice import Choice, ChoiceType from .color import ColorType from .country import Country, CountryType +from .currency import CurrencyType from .email import EmailType from .encrypted import EncryptedType from .ip_address import IPAddressType @@ -33,6 +34,7 @@ __all__ = ( ColorType, Country, CountryType, + CurrencyType, DateRangeType, DateTimeRangeType, EmailType, diff --git a/sqlalchemy_utils/types/currency.py b/sqlalchemy_utils/types/currency.py new file mode 100644 index 0000000..f3290fe --- /dev/null +++ b/sqlalchemy_utils/types/currency.py @@ -0,0 +1,80 @@ +babel = None +try: + import babel +except ImportError: + pass +import six +from sqlalchemy import types + +from sqlalchemy_utils import ImproperlyConfigured +from sqlalchemy_utils.primitives import Currency + +from .scalar_coercible import ScalarCoercible + + +class CurrencyType(types.TypeDecorator, ScalarCoercible): + """ + Changes :class:`.Currency` objects to a string representation on the way in + and changes them back to :class:`.Currency` objects on the way out. + + In order to use CurrencyType you need to install Babel_ first. + + .. _Babel: http://babel.pocoo.org/ + + :: + + + from sqlalchemy_utils import CurrencyType, Currency + + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(255)) + currency = sa.Column(CurrencyType) + + + user = User() + user.currency = Currency('USD') + session.add(user) + session.commit() + + user.currency # Currency('USD') + user.currency.name # US Dollar + + str(user.currency) # US Dollar + user.currency.symbol # $ + + + + CurrencyType is scalar coercible:: + + + user.currency = 'US' + user.currency # Currency('US') + """ + impl = types.String(3) + python_type = Currency + + def __init__(self, *args, **kwargs): + if babel is None: + raise ImproperlyConfigured( + "'babel' package is required in order to use CurrencyType." + ) + + super(CurrencyType, self).__init__(*args, **kwargs) + + def process_bind_param(self, value, dialect): + if isinstance(value, Currency): + return value.code + elif isinstance(value, six.string_types): + return value + + def process_result_value(self, value, dialect): + if value is not None: + return Currency(value) + + def _coerce(self, value): + if value is not None and not isinstance(value, Currency): + return Currency(value) + return value diff --git a/tests/primitives/test_currency.py b/tests/primitives/test_currency.py new file mode 100644 index 0000000..bec1498 --- /dev/null +++ b/tests/primitives/test_currency.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +import six +from pytest import mark, raises + +from sqlalchemy_utils import Currency, i18n +from sqlalchemy_utils.primitives.currency import babel # noqa + + +@mark.skipif('babel is None') +class TestCurrency(object): + def setup_method(self, method): + i18n.get_locale = lambda: babel.Locale('en') + + def test_init(self): + assert Currency('USD') == Currency(Currency('USD')) + + def test_hashability(self): + assert len(set([Currency('USD'), Currency('USD')])) == 1 + + def test_invalid_currency_code(self): + with raises(ValueError): + Currency('Unknown code') + + def test_invalid_currency_code_type(self): + with raises(TypeError): + Currency(None) + + @mark.parametrize( + ('code', 'name'), + ( + ('USD', 'US Dollar'), + ('EUR', 'Euro') + ) + ) + def test_name_property(self, code, name): + assert Currency(code).name == name + + @mark.parametrize( + ('code', 'symbol'), + ( + ('USD', u'$'), + ('EUR', u'€') + ) + ) + def test_symbol_property(self, code, symbol): + assert Currency(code).symbol == symbol + + def test_equality_operator(self): + assert Currency('USD') == 'USD' + assert 'USD' == Currency('USD') + assert Currency('USD') == Currency('USD') + + def test_non_equality_operator(self): + assert Currency('USD') != 'EUR' + assert not (Currency('USD') != 'USD') + + def test_unicode(self): + currency = Currency('USD') + assert six.text_type(currency) == u'USD' + + def test_str(self): + currency = Currency('USD') + assert str(currency) == 'USD' + + def test_representation(self): + currency = Currency('USD') + assert repr(currency) == "Currency('USD')" diff --git a/tests/types/test_currency.py b/tests/types/test_currency.py new file mode 100644 index 0000000..7a346be --- /dev/null +++ b/tests/types/test_currency.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +import sqlalchemy as sa +from pytest import mark + +from sqlalchemy_utils import Currency, CurrencyType, i18n +from sqlalchemy_utils.types.currency import babel +from tests import TestCase + + +@mark.skipif('babel is None') +class TestCurrencyType(TestCase): + def setup_method(self, method): + TestCase.setup_method(self, method) + i18n.get_locale = lambda: babel.Locale('en') + + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + currency = sa.Column(CurrencyType) + + def __repr__(self): + return 'User(%r)' % self.id + + self.User = User + + def test_parameter_processing(self): + user = self.User( + currency=Currency('USD') + ) + + self.session.add(user) + self.session.commit() + + user = self.session.query(self.User).first() + assert user.currency.name == u'US Dollar' + + def test_scalar_attributes_get_coerced_to_objects(self): + user = self.User(currency='USD') + assert isinstance(user.currency, Currency)