diff --git a/CHANGES.rst b/CHANGES.rst index b9a70f4..69227f6 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,13 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.30.11 (2015-06-18) +^^^^^^^^^^^^^^^^^^^^ + +- Make locale casting for translation hybrid expressions cast locales on compilation phase. This extra lazy locale casting is needed in some cases where translation hybrid expressions are used before get_locale +function is available. + + 0.30.10 (2015-06-17) ^^^^^^^^^^^^^^^^^^^^ diff --git a/sqlalchemy_utils/i18n.py b/sqlalchemy_utils/i18n.py index b36e47f..8c8fe7f 100644 --- a/sqlalchemy_utils/i18n.py +++ b/sqlalchemy_utils/i18n.py @@ -1,4 +1,7 @@ +import six import sqlalchemy as sa +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.expression import ColumnElement from sqlalchemy.ext.hybrid import hybrid_property from .exceptions import ImproperlyConfigured @@ -14,11 +17,44 @@ except ImportError: def get_locale(): raise ImproperlyConfigured( 'Could not load get_locale function from Flask-Babel. Either ' - 'install babel or make a similar function and override it ' + 'install Flask-Babel or make a similar function and override it ' 'in this module.' ) +def cast_locale(obj, locale): + """ + Cast given locale to string. Supports also callbacks that return locales. + + :param obj: + Object or class to use as a possible parameter to locale callable + :param locale: + Locale object or string or callable that returns a locale. + """ + if callable(locale): + try: + locale = locale() + except TypeError: + locale = locale(obj) + if isinstance(locale, babel.Locale): + return str(locale) + return locale + + +class cast_locale_expr(ColumnElement): + def __init__(self, cls, locale): + self.cls = cls + self.locale = locale + + +@compiles(cast_locale_expr) +def compile_cast_locale_expr(element, compiler, **kw): + locale = cast_locale(element.cls, element.locale) + if isinstance(locale, six.string_types): + return "'{0}'".format(locale) + return compiler.process(locale) + + class TranslationHybrid(object): def __init__(self, current_locale, default_locale, default_value=None): if babel is None: @@ -29,21 +65,6 @@ class TranslationHybrid(object): self.default_locale = default_locale self.default_value = default_value - def cast_locale(self, obj, locale): - """ - Cast given locale to string. Supports also callbacks that return - locales. - """ - if callable(locale): - try: - locale = locale() - except TypeError: - locale = locale(obj) - if isinstance(locale, babel.Locale): - return str(locale) - - return locale - def getter_factory(self, attr): """ Return a hybrid_property getter function for given attribute. The @@ -52,11 +73,11 @@ class TranslationHybrid(object): is no translation found for default locale it returns None. """ def getter(obj): - current_locale = self.cast_locale(obj, self.current_locale) + current_locale = cast_locale(obj, self.current_locale) try: return getattr(obj, attr.key)[current_locale] except (TypeError, KeyError): - default_locale = self.cast_locale( + default_locale = cast_locale( obj, self.default_locale ) try: @@ -69,14 +90,14 @@ class TranslationHybrid(object): def setter(obj, value): if getattr(obj, attr.key) is None: setattr(obj, attr.key, {}) - locale = self.cast_locale(obj, self.current_locale) + locale = cast_locale(obj, self.current_locale) getattr(obj, attr.key)[locale] = value return setter def expr_factory(self, attr): def expr(cls): - current_locale = self.cast_locale(cls, self.current_locale) - default_locale = self.cast_locale(cls, self.default_locale) + current_locale = cast_locale_expr(cls, self.current_locale) + default_locale = cast_locale_expr(cls, self.default_locale) return sa.func.coalesce(attr[current_locale], attr[default_locale]) return expr diff --git a/tests/test_translation_hybrid.py b/tests/test_translation_hybrid.py index 91b6572..e4c533a 100644 --- a/tests/test_translation_hybrid.py +++ b/tests/test_translation_hybrid.py @@ -1,4 +1,5 @@ import sqlalchemy as sa +from flexmock import flexmock from pytest import mark from sqlalchemy.dialects.postgresql import HSTORE @@ -68,7 +69,7 @@ class TestTranslationHybrid(TestCase): assert self.session.query(self.City.name).scalar() == name def test_dynamic_locale(self): - self.translation_hybrid = TranslationHybrid( + translation_hybrid = TranslationHybrid( lambda obj: obj.locale, 'fi' ) @@ -77,10 +78,29 @@ class TestTranslationHybrid(TestCase): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) name_translations = sa.Column(HSTORE) - name = self.translation_hybrid(name_translations) + name = translation_hybrid(name_translations) locale = sa.Column(sa.String) assert ( 'coalesce(article.name_translations -> article.locale' in str(Article.name) ) + + def test_locales_casted_only_in_compilation_phase(self): + class LocaleGetter(object): + def current_locale(self): + return lambda obj: obj.locale + + flexmock(LocaleGetter).should_receive('current_locale').never() + translation_hybrid = TranslationHybrid( + LocaleGetter().current_locale, + 'fi' + ) + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name_translations = sa.Column(HSTORE) + name = translation_hybrid(name_translations) + locale = sa.Column(sa.String) + + Article.name