Add support for in operator for range types

This commit is contained in:
Konsta Vesterinen
2014-01-14 16:33:59 +02:00
parent 0073cc50a2
commit 5e7187a2c7
2 changed files with 48 additions and 10 deletions

View File

@@ -18,6 +18,7 @@ http://wiki.postgresql.org/images/f/f0/Range-types.pdf
.. _intervals: https://github.com/kvesteri/intervals
"""
from collections import Iterable
intervals = None
try:
import intervals
@@ -74,21 +75,42 @@ ischema_names['tstzrange'] = TSTZRANGE
class RangeComparator(types.TypeEngine.Comparator):
@classmethod
def coerce_arg(cls, func):
def coerced_func(cls, func):
def operation(self, other, **kwargs):
coerced_types = (
self.type.interval_class.type,
tuple,
list,
) + six.string_types
if isinstance(other, coerced_types):
other = self.type.interval_class(other)
other = self.coerce_arg(other)
return getattr(types.TypeEngine.Comparator, func)(
self, other, **kwargs
)
return operation
def coerce_arg(self, other):
coerced_types = (
self.type.interval_class.type,
tuple,
list,
) + six.string_types
if isinstance(other, coerced_types):
return self.type.interval_class(other)
return other
def in_(self, other):
if (
isinstance(other, Iterable) and
not isinstance(other, six.string_types)
):
other = map(self.coerce_arg, other)
return super(RangeComparator, self).in_(other)
def notin_(self, other):
if (
isinstance(other, Iterable) and
not isinstance(other, six.string_types)
):
other = map(self.coerce_arg, other)
return super(RangeComparator, self).notin_(other)
funcs = [
'__eq__',
@@ -104,7 +126,7 @@ for func in funcs:
setattr(
RangeComparator,
func,
RangeComparator.coerce_arg(func)
RangeComparator.coerced_func(func)
)

View File

@@ -98,6 +98,22 @@ class TestIntRangeTypeOnPostgres(NumberRangeTestCase):
)
assert query.count()
@mark.parametrize(
'number_range',
(
[1, 3],
'1 - 3',
(0, 4)
)
)
def test_in_operator(self, number_range):
self.create_building([1, 3])
query = (
self.session.query(self.Building)
.filter(self.Building.persons_at_night.in_([number_range]))
)
assert query.count()
def test_eq_with_query_arg(self):
self.create_building([1, 3])
query = (