From 5e7187a2c72f073a2f916279db10bb201f9012d6 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Tue, 14 Jan 2014 16:33:59 +0200 Subject: [PATCH] Add support for in operator for range types --- sqlalchemy_utils/types/range.py | 42 +++++++++++++++++++++++++-------- tests/types/test_int_range.py | 16 +++++++++++++ 2 files changed, 48 insertions(+), 10 deletions(-) diff --git a/sqlalchemy_utils/types/range.py b/sqlalchemy_utils/types/range.py index d299b1c..8108986 100644 --- a/sqlalchemy_utils/types/range.py +++ b/sqlalchemy_utils/types/range.py @@ -18,6 +18,7 @@ http://wiki.postgresql.org/images/f/f0/Range-types.pdf .. _intervals: https://github.com/kvesteri/intervals """ +from collections import Iterable intervals = None try: import intervals @@ -74,21 +75,42 @@ ischema_names['tstzrange'] = TSTZRANGE class RangeComparator(types.TypeEngine.Comparator): @classmethod - def coerce_arg(cls, func): + def coerced_func(cls, func): def operation(self, other, **kwargs): - coerced_types = ( - self.type.interval_class.type, - tuple, - list, - ) + six.string_types - - if isinstance(other, coerced_types): - other = self.type.interval_class(other) + other = self.coerce_arg(other) return getattr(types.TypeEngine.Comparator, func)( self, other, **kwargs ) return operation + def coerce_arg(self, other): + coerced_types = ( + self.type.interval_class.type, + tuple, + list, + ) + six.string_types + + if isinstance(other, coerced_types): + return self.type.interval_class(other) + return other + + def in_(self, other): + if ( + isinstance(other, Iterable) and + not isinstance(other, six.string_types) + ): + other = map(self.coerce_arg, other) + return super(RangeComparator, self).in_(other) + + def notin_(self, other): + if ( + isinstance(other, Iterable) and + not isinstance(other, six.string_types) + ): + other = map(self.coerce_arg, other) + return super(RangeComparator, self).notin_(other) + + funcs = [ '__eq__', @@ -104,7 +126,7 @@ for func in funcs: setattr( RangeComparator, func, - RangeComparator.coerce_arg(func) + RangeComparator.coerced_func(func) ) diff --git a/tests/types/test_int_range.py b/tests/types/test_int_range.py index 596a25e..bc6390b 100644 --- a/tests/types/test_int_range.py +++ b/tests/types/test_int_range.py @@ -98,6 +98,22 @@ class TestIntRangeTypeOnPostgres(NumberRangeTestCase): ) assert query.count() + @mark.parametrize( + 'number_range', + ( + [1, 3], + '1 - 3', + (0, 4) + ) + ) + def test_in_operator(self, number_range): + self.create_building([1, 3]) + query = ( + self.session.query(self.Building) + .filter(self.Building.persons_at_night.in_([number_range])) + ) + assert query.count() + def test_eq_with_query_arg(self): self.create_building([1, 3]) query = (