Fix NumericRange
This commit is contained in:
@@ -194,9 +194,12 @@ class RangeType(types.TypeDecorator, ScalarCoercible):
|
|||||||
|
|
||||||
def process_result_value(self, value, dialect):
|
def process_result_value(self, value, dialect):
|
||||||
if value:
|
if value:
|
||||||
|
if self.interval_class.step is not None:
|
||||||
return self.canonicalize_result_value(
|
return self.canonicalize_result_value(
|
||||||
self.interval_class(value)
|
self.interval_class(value)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
return self.interval_class(value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def canonicalize_result_value(self, value):
|
def canonicalize_result_value(self, value):
|
||||||
@@ -289,7 +292,7 @@ class NumericRangeType(RangeType):
|
|||||||
impl = NUMRANGE
|
impl = NUMRANGE
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(DateRangeType, self).__init__(*args, **kwargs)
|
super(NumericRangeType, self).__init__(*args, **kwargs)
|
||||||
self.interval_class = intervals.DecimalInterval
|
self.interval_class = intervals.DecimalInterval
|
||||||
|
|
||||||
|
|
||||||
@@ -297,5 +300,5 @@ class DateTimeRangeType(RangeType):
|
|||||||
impl = TSRANGE
|
impl = TSRANGE
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(DateRangeType, self).__init__(*args, **kwargs)
|
super(DateTimeRangeType, self).__init__(*args, **kwargs)
|
||||||
self.interval_class = intervals.DateTimeInterval
|
self.interval_class = intervals.DateTimeInterval
|
||||||
|
80
tests/types/test_numeric_range.py
Normal file
80
tests/types/test_numeric_range.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
from pytest import mark
|
||||||
|
import sqlalchemy as sa
|
||||||
|
intervals = None
|
||||||
|
try:
|
||||||
|
import intervals
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
from tests import TestCase
|
||||||
|
from infinity import inf
|
||||||
|
from sqlalchemy_utils import NumericRangeType
|
||||||
|
|
||||||
|
|
||||||
|
@mark.skipif('intervals is None')
|
||||||
|
class NumericRangeTestCase(TestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class Car(self.Base):
|
||||||
|
__tablename__ = 'car'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
price_range = sa.Column(NumericRangeType)
|
||||||
|
|
||||||
|
self.Car = Car
|
||||||
|
|
||||||
|
def create_car(self, number_range):
|
||||||
|
car = self.Car(
|
||||||
|
price_range=number_range
|
||||||
|
)
|
||||||
|
|
||||||
|
self.session.add(car)
|
||||||
|
self.session.commit()
|
||||||
|
return self.session.query(self.Car).first()
|
||||||
|
|
||||||
|
@mark.parametrize(
|
||||||
|
'number_range',
|
||||||
|
(
|
||||||
|
[1, 3],
|
||||||
|
'1 - 3',
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def test_save_number_range(self, number_range):
|
||||||
|
car = self.create_car(number_range)
|
||||||
|
assert car.price_range.lower == 1
|
||||||
|
assert car.price_range.upper == 3
|
||||||
|
|
||||||
|
def test_infinite_upper_bound(self):
|
||||||
|
car = self.create_car([1, inf])
|
||||||
|
assert car.price_range.lower == 1
|
||||||
|
assert car.price_range.upper == inf
|
||||||
|
|
||||||
|
def test_infinite_lower_bound(self):
|
||||||
|
car = self.create_car([-inf, 1])
|
||||||
|
assert car.price_range.lower == -inf
|
||||||
|
assert car.price_range.upper == 1
|
||||||
|
|
||||||
|
def test_nullify_number_range(self):
|
||||||
|
car = self.Car(
|
||||||
|
price_range=intervals.IntInterval([1, 3])
|
||||||
|
)
|
||||||
|
|
||||||
|
self.session.add(car)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
car = self.session.query(self.Car).first()
|
||||||
|
car.price_range = None
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
car = self.session.query(self.Car).first()
|
||||||
|
assert car.price_range is None
|
||||||
|
|
||||||
|
def test_string_coercion(self):
|
||||||
|
car = self.Car(price_range='[12, 18]')
|
||||||
|
assert isinstance(car.price_range, intervals.DecimalInterval)
|
||||||
|
|
||||||
|
def test_integer_coercion(self):
|
||||||
|
car = self.Car(price_range=15)
|
||||||
|
assert car.price_range.lower == 15
|
||||||
|
assert car.price_range.upper == 15
|
||||||
|
|
||||||
|
|
||||||
|
class TestNumericRangeOnPostgres(NumericRangeTestCase):
|
||||||
|
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
Reference in New Issue
Block a user