diff --git a/sqlalchemy_utils/types/range.py b/sqlalchemy_utils/types/range.py index 4f65870..87ad619 100644 --- a/sqlalchemy_utils/types/range.py +++ b/sqlalchemy_utils/types/range.py @@ -59,7 +59,11 @@ Membership operators :: - Car.price_range.in_([[300, 500]]) + Car.price_range.contains([300, 500]) + + Car.price_range.contained_by([300, 500]) + + Car.price_range.in_([[300, 500], [800, 900]]) ~ Car.price_range.in_([[300, 400], [700, 800]]) @@ -158,6 +162,13 @@ class RangeComparator(types.TypeEngine.Comparator): other = map(self.coerce_arg, other) return super(RangeComparator, self).notin_(other) + def contains(self, other, **kwargs): + other = self.coerce_arg(other) + return self.op('@>')(other) + + def contained_by(self, other, **kwargs): + other = self.coerce_arg(other) + return self.op('<@')(other) funcs = [ diff --git a/tests/types/test_int_range.py b/tests/types/test_int_range.py index d3d426d..ae7d50f 100644 --- a/tests/types/test_int_range.py +++ b/tests/types/test_int_range.py @@ -101,16 +101,50 @@ class TestIntRangeTypeOnPostgres(NumberRangeTestCase): @mark.parametrize( 'number_range', ( - [1, 3], - '1 - 3', - (0, 4) + [[1, 3]], + ['1 - 3'], + [(0, 4)], ) ) def test_in_operator(self, number_range): self.create_building([1, 3]) query = ( self.session.query(self.Building) - .filter(self.Building.persons_at_night.in_([number_range])) + .filter(self.Building.persons_at_night.in_(number_range)) + ) + assert query.count() + + @mark.parametrize( + 'number_range', + ( + [1, 3], + '1 - 3', + (1, 3), + 2 + ) + ) + def test_contains_operator(self, number_range): + self.create_building([1, 3]) + query = ( + self.session.query(self.Building) + .filter(self.Building.persons_at_night.contains(number_range)) + ) + assert query.count() + + @mark.parametrize( + 'number_range', + ( + [1, 3], + '1 - 3', + (0, 8), + (-inf, inf) + ) + ) + def test_contained_by_operator(self, number_range): + self.create_building([1, 3]) + query = ( + self.session.query(self.Building) + .filter(self.Building.persons_at_night.contained_by(number_range)) ) assert query.count()