diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 7e75563..a2cf024 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -37,6 +37,7 @@ from .functions import ( sort_query, table_name, ) +from .i18n import TranslationHybrid from .listeners import ( auto_delete_orphans, coercion_listener, diff --git a/sqlalchemy_utils/i18n.py b/sqlalchemy_utils/i18n.py index 1f40b31..0d77186 100644 --- a/sqlalchemy_utils/i18n.py +++ b/sqlalchemy_utils/i18n.py @@ -1,3 +1,5 @@ +from sqlalchemy.ext.hybrid import hybrid_property + from .exceptions import ImproperlyConfigured @@ -21,3 +23,60 @@ except ImportError: 'install babel or make a similar function and override it ' 'in this module.' ) + + +class TranslationHybrid(object): + def __init__(self, current_locale, default_locale): + self.current_locale = current_locale + self.default_locale = default_locale + + def cast_locale(self, obj, locale): + """ + Cast given locale to string. Supports also callbacks that return + locales. + """ + if callable(locale): + try: + return str(locale()) + except TypeError: + return str(locale(obj)) + return str(locale) + + def getter_factory(self, attr): + """ + Return a hybrid_property getter function for given attribute. The + returned getter first checks if object has translation for current + locale. If not it tries to get translation for default locale. If there + is no translation found for default locale it returns None. + """ + def getter(obj): + current_locale = self.cast_locale(obj, self.current_locale) + try: + return getattr(obj, attr.key)[current_locale] + except (TypeError, KeyError): + default_locale = self.cast_locale( + obj, self.default_locale + ) + try: + return getattr(obj, attr.key)[default_locale] + except (TypeError, KeyError): + return None + return getter + + def setter_factory(self, attr): + def setter(obj, value): + if getattr(obj, attr.key) is None: + setattr(obj, attr.key, {}) + locale = self.cast_locale(obj, self.current_locale) + getattr(obj, attr.key)[locale] = value + return setter + + def expr_factory(self, attr): + return lambda cls: attr + + def __call__(self, attr): + return hybrid_property( + fget=self.getter_factory(attr), + fset=self.setter_factory(attr), + expr=self.expr_factory(attr) + ) diff --git a/tests/test_translation_hybrid.py b/tests/test_translation_hybrid.py new file mode 100644 index 0000000..3d22205 --- /dev/null +++ b/tests/test_translation_hybrid.py @@ -0,0 +1,68 @@ +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSON +from sqlalchemy_utils import TranslationHybrid + +from tests import TestCase + + +class TestTranslationHybrid(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class City(self.Base): + __tablename__ = 'city' + id = sa.Column(sa.Integer, primary_key=True) + name_translations = sa.Column(JSON()) + name = self.translation_hybrid(name_translations) + locale = 'en' + + self.City = City + + def setup_method(self, method): + self.translation_hybrid = TranslationHybrid('fi', 'en') + TestCase.setup_method(self, method) + + def test_using_hybrid_as_constructor(self): + city = self.City(name='Helsinki') + assert city.name_translations['fi'] == 'Helsinki' + + def test_hybrid_as_expression(self): + assert self.City.name == self.City.name_translations + + def test_if_no_translation_exists_returns_none(self): + city = self.City() + assert city.name is None + + def test_fall_back_to_default_translation(self): + city = self.City(name_translations={'en': 'Helsinki'}) + self.translation_hybrid.current_locale = 'sv' + assert city.name == 'Helsinki' + + +class TestTranslationHybridWithDynamicDefaultLocale(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class City(self.Base): + __tablename__ = 'city' + id = sa.Column(sa.Integer, primary_key=True) + name_translations = sa.Column(JSON) + name = self.translation_hybrid(name_translations) + locale = sa.Column(sa.String(10)) + + self.City = City + + def setup_method(self, method): + self.translation_hybrid = TranslationHybrid( + 'fi', + lambda self: self.locale + ) + TestCase.setup_method(self, method) + + def test_fallback_to_dynamic_locale(self): + self.translation_hybrid.current_locale = 'en' + city = self.City(name_translations={}) + city.locale = 'fi' + city.name_translations['fi'] = 'Helsinki' + + assert city.name == 'Helsinki'