diff --git a/docs/data_types.rst b/docs/data_types.rst index 9501570..a8839ea 100644 --- a/docs/data_types.rst +++ b/docs/data_types.rst @@ -38,6 +38,15 @@ CountryType .. autoclass:: CountryType + +CurrencyType +^^^^^^^^^^^^ + +.. module:: sqlalchemy_utils.types.currency + +.. autoclass:: CurrencyType + + EncryptedType ^^^^^^^^^^^^^ diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 6a7e4c6..7d59edc 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -64,6 +64,7 @@ from .types import ( # noqa ColorType, Country, CountryType, + CurrencyType, DateRangeType, DateTimeRangeType, EmailType, diff --git a/sqlalchemy_utils/types/__init__.py b/sqlalchemy_utils/types/__init__.py index 4895582..8a44b55 100644 --- a/sqlalchemy_utils/types/__init__.py +++ b/sqlalchemy_utils/types/__init__.py @@ -6,6 +6,7 @@ from .arrow import ArrowType from .choice import Choice, ChoiceType from .color import ColorType from .country import Country, CountryType +from .currency import CurrencyType from .email import EmailType from .encrypted import EncryptedType from .ip_address import IPAddressType @@ -33,6 +34,7 @@ __all__ = ( ColorType, Country, CountryType, + CurrencyType, DateRangeType, DateTimeRangeType, EmailType, diff --git a/sqlalchemy_utils/types/currency.py b/sqlalchemy_utils/types/currency.py new file mode 100644 index 0000000..f2e9940 --- /dev/null +++ b/sqlalchemy_utils/types/currency.py @@ -0,0 +1,66 @@ +import six +from sqlalchemy import types + +from sqlalchemy_utils.primitives import Currency + +from .scalar_coercible import ScalarCoercible + + +class CurrencyType(types.TypeDecorator, ScalarCoercible): + """ + Changes Currency objects to a string representation on the way in and + changes them back to Currency objects on the way out. + + In order to use CurrencyType you need to install Babel_ first. + + .. _Babel: http://babel.pocoo.org/ + + :: + + + from sqlalchemy_utils import CurrencyType, Currency + + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(255)) + currency = sa.Column(CurrencyType) + + + user = User() + user.currency = Currency('USD') + session.add(user) + session.commit() + + user.currency # Currency('USD') + user.currency.name # US Dollar + + str(user.currency) # US Dollar + user.currency.symbol # $ + + + + CurrencyType is scalar coercible:: + + + user.currency = 'US' + user.currency # Currency('US') + """ + impl = types.String(3) + python_type = Currency + + def process_bind_param(self, value, dialect): + if isinstance(value, Currency): + return value.code + elif isinstance(value, six.string_types): + return value + + def process_result_value(self, value, dialect): + if value is not None: + return Currency(value) + + def _coerce(self, value): + if value is not None and not isinstance(value, Currency): + return Currency(value) + return value diff --git a/tests/types/test_currency.py b/tests/types/test_currency.py new file mode 100644 index 0000000..88fd422 --- /dev/null +++ b/tests/types/test_currency.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +import sqlalchemy as sa +from babel import Locale + +from sqlalchemy_utils import Currency, CurrencyType, i18n +from tests import TestCase + + +class TestCurrencyType(TestCase): + def setup_method(self, method): + TestCase.setup_method(self, method) + i18n.get_locale = lambda: Locale('en') + + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + currency = sa.Column(CurrencyType) + + def __repr__(self): + return 'User(%r)' % self.id + + self.User = User + + def test_parameter_processing(self): + user = self.User( + currency=Currency('USD') + ) + + self.session.add(user) + self.session.commit() + + user = self.session.query(self.User).first() + assert user.currency.name == u'US Dollar' + + def test_scalar_attributes_get_coerced_to_objects(self): + user = self.User(currency='USD') + assert isinstance(user.currency, Currency)