Add support for in operator for range types
This commit is contained in:
@@ -18,6 +18,7 @@ http://wiki.postgresql.org/images/f/f0/Range-types.pdf
|
||||
|
||||
.. _intervals: https://github.com/kvesteri/intervals
|
||||
"""
|
||||
from collections import Iterable
|
||||
intervals = None
|
||||
try:
|
||||
import intervals
|
||||
@@ -74,21 +75,42 @@ ischema_names['tstzrange'] = TSTZRANGE
|
||||
|
||||
class RangeComparator(types.TypeEngine.Comparator):
|
||||
@classmethod
|
||||
def coerce_arg(cls, func):
|
||||
def coerced_func(cls, func):
|
||||
def operation(self, other, **kwargs):
|
||||
coerced_types = (
|
||||
self.type.interval_class.type,
|
||||
tuple,
|
||||
list,
|
||||
) + six.string_types
|
||||
|
||||
if isinstance(other, coerced_types):
|
||||
other = self.type.interval_class(other)
|
||||
other = self.coerce_arg(other)
|
||||
return getattr(types.TypeEngine.Comparator, func)(
|
||||
self, other, **kwargs
|
||||
)
|
||||
return operation
|
||||
|
||||
def coerce_arg(self, other):
|
||||
coerced_types = (
|
||||
self.type.interval_class.type,
|
||||
tuple,
|
||||
list,
|
||||
) + six.string_types
|
||||
|
||||
if isinstance(other, coerced_types):
|
||||
return self.type.interval_class(other)
|
||||
return other
|
||||
|
||||
def in_(self, other):
|
||||
if (
|
||||
isinstance(other, Iterable) and
|
||||
not isinstance(other, six.string_types)
|
||||
):
|
||||
other = map(self.coerce_arg, other)
|
||||
return super(RangeComparator, self).in_(other)
|
||||
|
||||
def notin_(self, other):
|
||||
if (
|
||||
isinstance(other, Iterable) and
|
||||
not isinstance(other, six.string_types)
|
||||
):
|
||||
other = map(self.coerce_arg, other)
|
||||
return super(RangeComparator, self).notin_(other)
|
||||
|
||||
|
||||
|
||||
funcs = [
|
||||
'__eq__',
|
||||
@@ -104,7 +126,7 @@ for func in funcs:
|
||||
setattr(
|
||||
RangeComparator,
|
||||
func,
|
||||
RangeComparator.coerce_arg(func)
|
||||
RangeComparator.coerced_func(func)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -98,6 +98,22 @@ class TestIntRangeTypeOnPostgres(NumberRangeTestCase):
|
||||
)
|
||||
assert query.count()
|
||||
|
||||
@mark.parametrize(
|
||||
'number_range',
|
||||
(
|
||||
[1, 3],
|
||||
'1 - 3',
|
||||
(0, 4)
|
||||
)
|
||||
)
|
||||
def test_in_operator(self, number_range):
|
||||
self.create_building([1, 3])
|
||||
query = (
|
||||
self.session.query(self.Building)
|
||||
.filter(self.Building.persons_at_night.in_([number_range]))
|
||||
)
|
||||
assert query.count()
|
||||
|
||||
def test_eq_with_query_arg(self):
|
||||
self.create_building([1, 3])
|
||||
query = (
|
||||
|
||||
Reference in New Issue
Block a user