diff --git a/sqlalchemy_utils/types/range.py b/sqlalchemy_utils/types/range.py index 62b932a..b19b5b3 100644 --- a/sqlalchemy_utils/types/range.py +++ b/sqlalchemy_utils/types/range.py @@ -194,9 +194,12 @@ class RangeType(types.TypeDecorator, ScalarCoercible): def process_result_value(self, value, dialect): if value: - return self.canonicalize_result_value( - self.interval_class(value) - ) + if self.interval_class.step is not None: + return self.canonicalize_result_value( + self.interval_class(value) + ) + else: + return self.interval_class(value) return value def canonicalize_result_value(self, value): @@ -289,7 +292,7 @@ class NumericRangeType(RangeType): impl = NUMRANGE def __init__(self, *args, **kwargs): - super(DateRangeType, self).__init__(*args, **kwargs) + super(NumericRangeType, self).__init__(*args, **kwargs) self.interval_class = intervals.DecimalInterval @@ -297,5 +300,5 @@ class DateTimeRangeType(RangeType): impl = TSRANGE def __init__(self, *args, **kwargs): - super(DateRangeType, self).__init__(*args, **kwargs) + super(DateTimeRangeType, self).__init__(*args, **kwargs) self.interval_class = intervals.DateTimeInterval diff --git a/tests/types/test_numeric_range.py b/tests/types/test_numeric_range.py new file mode 100644 index 0000000..6458858 --- /dev/null +++ b/tests/types/test_numeric_range.py @@ -0,0 +1,80 @@ +from pytest import mark +import sqlalchemy as sa +intervals = None +try: + import intervals +except ImportError: + pass +from tests import TestCase +from infinity import inf +from sqlalchemy_utils import NumericRangeType + + +@mark.skipif('intervals is None') +class NumericRangeTestCase(TestCase): + def create_models(self): + class Car(self.Base): + __tablename__ = 'car' + id = sa.Column(sa.Integer, primary_key=True) + price_range = sa.Column(NumericRangeType) + + self.Car = Car + + def create_car(self, number_range): + car = self.Car( + price_range=number_range + ) + + self.session.add(car) + self.session.commit() + return self.session.query(self.Car).first() + + @mark.parametrize( + 'number_range', + ( + [1, 3], + '1 - 3', + ) + ) + def test_save_number_range(self, number_range): + car = self.create_car(number_range) + assert car.price_range.lower == 1 + assert car.price_range.upper == 3 + + def test_infinite_upper_bound(self): + car = self.create_car([1, inf]) + assert car.price_range.lower == 1 + assert car.price_range.upper == inf + + def test_infinite_lower_bound(self): + car = self.create_car([-inf, 1]) + assert car.price_range.lower == -inf + assert car.price_range.upper == 1 + + def test_nullify_number_range(self): + car = self.Car( + price_range=intervals.IntInterval([1, 3]) + ) + + self.session.add(car) + self.session.commit() + + car = self.session.query(self.Car).first() + car.price_range = None + self.session.commit() + + car = self.session.query(self.Car).first() + assert car.price_range is None + + def test_string_coercion(self): + car = self.Car(price_range='[12, 18]') + assert isinstance(car.price_range, intervals.DecimalInterval) + + def test_integer_coercion(self): + car = self.Car(price_range=15) + assert car.price_range.lower == 15 + assert car.price_range.upper == 15 + + +class TestNumericRangeOnPostgres(NumericRangeTestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'