Make locale casting more lazy

This commit is contained in:
Konsta Vesterinen
2015-06-18 10:00:02 +03:00
parent 204aba376d
commit d02a414a60
3 changed files with 71 additions and 23 deletions

View File

@@ -4,6 +4,13 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release. Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.30.11 (2015-06-18)
^^^^^^^^^^^^^^^^^^^^
- Make locale casting for translation hybrid expressions cast locales on compilation phase. This extra lazy locale casting is needed in some cases where translation hybrid expressions are used before get_locale
function is available.
0.30.10 (2015-06-17) 0.30.10 (2015-06-17)
^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^

View File

@@ -1,4 +1,7 @@
import six
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import ColumnElement
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from .exceptions import ImproperlyConfigured from .exceptions import ImproperlyConfigured
@@ -14,11 +17,44 @@ except ImportError:
def get_locale(): def get_locale():
raise ImproperlyConfigured( raise ImproperlyConfigured(
'Could not load get_locale function from Flask-Babel. Either ' 'Could not load get_locale function from Flask-Babel. Either '
'install babel or make a similar function and override it ' 'install Flask-Babel or make a similar function and override it '
'in this module.' 'in this module.'
) )
def cast_locale(obj, locale):
"""
Cast given locale to string. Supports also callbacks that return locales.
:param obj:
Object or class to use as a possible parameter to locale callable
:param locale:
Locale object or string or callable that returns a locale.
"""
if callable(locale):
try:
locale = locale()
except TypeError:
locale = locale(obj)
if isinstance(locale, babel.Locale):
return str(locale)
return locale
class cast_locale_expr(ColumnElement):
def __init__(self, cls, locale):
self.cls = cls
self.locale = locale
@compiles(cast_locale_expr)
def compile_cast_locale_expr(element, compiler, **kw):
locale = cast_locale(element.cls, element.locale)
if isinstance(locale, six.string_types):
return "'{0}'".format(locale)
return compiler.process(locale)
class TranslationHybrid(object): class TranslationHybrid(object):
def __init__(self, current_locale, default_locale, default_value=None): def __init__(self, current_locale, default_locale, default_value=None):
if babel is None: if babel is None:
@@ -29,21 +65,6 @@ class TranslationHybrid(object):
self.default_locale = default_locale self.default_locale = default_locale
self.default_value = default_value self.default_value = default_value
def cast_locale(self, obj, locale):
"""
Cast given locale to string. Supports also callbacks that return
locales.
"""
if callable(locale):
try:
locale = locale()
except TypeError:
locale = locale(obj)
if isinstance(locale, babel.Locale):
return str(locale)
return locale
def getter_factory(self, attr): def getter_factory(self, attr):
""" """
Return a hybrid_property getter function for given attribute. The Return a hybrid_property getter function for given attribute. The
@@ -52,11 +73,11 @@ class TranslationHybrid(object):
is no translation found for default locale it returns None. is no translation found for default locale it returns None.
""" """
def getter(obj): def getter(obj):
current_locale = self.cast_locale(obj, self.current_locale) current_locale = cast_locale(obj, self.current_locale)
try: try:
return getattr(obj, attr.key)[current_locale] return getattr(obj, attr.key)[current_locale]
except (TypeError, KeyError): except (TypeError, KeyError):
default_locale = self.cast_locale( default_locale = cast_locale(
obj, self.default_locale obj, self.default_locale
) )
try: try:
@@ -69,14 +90,14 @@ class TranslationHybrid(object):
def setter(obj, value): def setter(obj, value):
if getattr(obj, attr.key) is None: if getattr(obj, attr.key) is None:
setattr(obj, attr.key, {}) setattr(obj, attr.key, {})
locale = self.cast_locale(obj, self.current_locale) locale = cast_locale(obj, self.current_locale)
getattr(obj, attr.key)[locale] = value getattr(obj, attr.key)[locale] = value
return setter return setter
def expr_factory(self, attr): def expr_factory(self, attr):
def expr(cls): def expr(cls):
current_locale = self.cast_locale(cls, self.current_locale) current_locale = cast_locale_expr(cls, self.current_locale)
default_locale = self.cast_locale(cls, self.default_locale) default_locale = cast_locale_expr(cls, self.default_locale)
return sa.func.coalesce(attr[current_locale], attr[default_locale]) return sa.func.coalesce(attr[current_locale], attr[default_locale])
return expr return expr

View File

@@ -1,4 +1,5 @@
import sqlalchemy as sa import sqlalchemy as sa
from flexmock import flexmock
from pytest import mark from pytest import mark
from sqlalchemy.dialects.postgresql import HSTORE from sqlalchemy.dialects.postgresql import HSTORE
@@ -68,7 +69,7 @@ class TestTranslationHybrid(TestCase):
assert self.session.query(self.City.name).scalar() == name assert self.session.query(self.City.name).scalar() == name
def test_dynamic_locale(self): def test_dynamic_locale(self):
self.translation_hybrid = TranslationHybrid( translation_hybrid = TranslationHybrid(
lambda obj: obj.locale, lambda obj: obj.locale,
'fi' 'fi'
) )
@@ -77,10 +78,29 @@ class TestTranslationHybrid(TestCase):
__tablename__ = 'article' __tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name_translations = sa.Column(HSTORE) name_translations = sa.Column(HSTORE)
name = self.translation_hybrid(name_translations) name = translation_hybrid(name_translations)
locale = sa.Column(sa.String) locale = sa.Column(sa.String)
assert ( assert (
'coalesce(article.name_translations -> article.locale' 'coalesce(article.name_translations -> article.locale'
in str(Article.name) in str(Article.name)
) )
def test_locales_casted_only_in_compilation_phase(self):
class LocaleGetter(object):
def current_locale(self):
return lambda obj: obj.locale
flexmock(LocaleGetter).should_receive('current_locale').never()
translation_hybrid = TranslationHybrid(
LocaleGetter().current_locale,
'fi'
)
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name_translations = sa.Column(HSTORE)
name = translation_hybrid(name_translations)
locale = sa.Column(sa.String)
Article.name