diff --git a/CHANGES.rst b/CHANGES.rst index 2d8e8cf..de4bf8e 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.29.4 (2015-01-31) +^^^^^^^^^^^^^^^^^^^ + +- Made CaseInsensitiveComparator not cast already lowercased types to lowercase + + 0.29.3 (2015-01-24) ^^^^^^^^^^^^^^^^^^^ diff --git a/sqlalchemy_utils/operators.py b/sqlalchemy_utils/operators.py index 4fc801a..3092763 100644 --- a/sqlalchemy_utils/operators.py +++ b/sqlalchemy_utils/operators.py @@ -1,17 +1,41 @@ import sqlalchemy as sa +def inspect_type(mixed): + if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute): + return mixed.property.columns[0].type + elif isinstance(mixed, sa.orm.ColumnProperty): + return mixed.columns[0].type + elif isinstance(mixed, sa.Column): + return mixed.type + + +def is_case_insensitive(mixed): + try: + return isinstance( + inspect_type(mixed).comparator, + CaseInsensitiveComparator + ) + except AttributeError: + try: + return issubclass( + inspect_type(mixed).comparator_factory, + CaseInsensitiveComparator + ) + except AttributeError: + return False + + class CaseInsensitiveComparator(sa.Unicode.Comparator): @classmethod def lowercase_arg(cls, func): def operation(self, other, **kwargs): + operator = getattr(sa.Unicode.Comparator, func) if other is None: - return getattr(sa.Unicode.Comparator, func)( - self, other, **kwargs - ) - return getattr(sa.Unicode.Comparator, func)( - self, sa.func.lower(other), **kwargs - ) + return operator(self, other, **kwargs) + if not is_case_insensitive(other): + other = sa.func.lower(other) + return operator(self, other, **kwargs) return operation def in_(self, other): diff --git a/tests/test_case_insensitive_comparator.py b/tests/test_case_insensitive_comparator.py index e1f7e10..56dd79b 100644 --- a/tests/test_case_insensitive_comparator.py +++ b/tests/test_case_insensitive_comparator.py @@ -42,3 +42,8 @@ class TestCaseInsensitiveComparator(TestCase): '"user".email NOT IN (lower(:lower_1), lower(:lower_2))' in str(query) ) + + def test_does_not_apply_lower_to_types_that_are_already_lowercased(self): + assert str(self.User.email == self.User.email) == ( + '"user".email = "user".email' + )