diff --git a/setup.py b/setup.py index 825a84a..76c0b29 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ extras_require = { 'anyjson': ['anyjson>=0.3.3'], 'babel': ['Babel>=1.3'], 'arrow': ['arrow>=0.3.4'], - 'intervals': ['intervals>=0.2.0'], + 'intervals': ['intervals>=0.2.2'], 'phone': ['phonenumbers>=5.9.2'], 'password': ['passlib >= 1.6, < 2.0'], 'color': ['colour>=0.0.4'], diff --git a/sqlalchemy_utils/types/range.py b/sqlalchemy_utils/types/range.py index b19b5b3..1b6b7a2 100644 --- a/sqlalchemy_utils/types/range.py +++ b/sqlalchemy_utils/types/range.py @@ -182,9 +182,10 @@ class RangeType(types.TypeDecorator, ScalarCoercible): def load_dialect_impl(self, dialect): if dialect.name == 'postgresql': - # Use the native JSON type. + # Use the native range type for postgres. return dialect.type_descriptor(self.impl) else: + # Other drivers don't have native types. return dialect.type_descriptor(sa.String(255)) def process_bind_param(self, value, dialect): @@ -193,7 +194,7 @@ class RangeType(types.TypeDecorator, ScalarCoercible): return value def process_result_value(self, value, dialect): - if value: + if value is not None: if self.interval_class.step is not None: return self.canonicalize_result_value( self.interval_class(value) @@ -206,9 +207,9 @@ class RangeType(types.TypeDecorator, ScalarCoercible): return intervals.canonicalize(value, True, True) def _coerce(self, value): - if value is not None: - value = self.interval_class(value) - return value + if value is None: + return None + return self.interval_class(value) class IntRangeType(RangeType): diff --git a/tests/types/test_int_range.py b/tests/types/test_int_range.py index d153ac1..9c5b33c 100644 --- a/tests/types/test_int_range.py +++ b/tests/types/test_int_range.py @@ -32,6 +32,18 @@ class NumberRangeTestCase(TestCase): self.session.commit() return self.session.query(self.Building).first() + def test_nullify_range(self): + building = self.create_building(None) + assert building.persons_at_night == None + + def test_update_with_none(self): + interval = intervals.IntInterval('(,)') + building = self.create_building(interval) + building.persons_at_night = None + assert building.persons_at_night is None + self.session.commit() + assert building.persons_at_night is None + @mark.parametrize( 'number_range', ( diff --git a/tests/types/test_numeric_range.py b/tests/types/test_numeric_range.py index 6458858..ee30904 100644 --- a/tests/types/test_numeric_range.py +++ b/tests/types/test_numeric_range.py @@ -29,6 +29,10 @@ class NumericRangeTestCase(TestCase): self.session.commit() return self.session.query(self.Car).first() + def test_nullify_range(self): + building = self.create_car(None) + assert building.price_range == None + @mark.parametrize( 'number_range', (