diff --git a/sqlalchemy_utils/types/range.py b/sqlalchemy_utils/types/range.py index 1b6b7a2..379e0ea 100644 --- a/sqlalchemy_utils/types/range.py +++ b/sqlalchemy_utils/types/range.py @@ -1,13 +1,15 @@ """ -SQLAlchemy-Utils provides wide variety of range data types. All range data types return -Interval objects of intervals_ package. In order to use range data types you need to install intervals_ with: +SQLAlchemy-Utils provides wide variety of range data types. All range data +types return Interval objects of intervals_ package. In order to use range data +types you need to install intervals_ with: :: pip install intervals -Intervals package provides good chunk of additional interval operators that for example psycopg2 range objects do not support. +Intervals package provides good chunk of additional interval operators that for +example psycopg2 range objects do not support. @@ -16,10 +18,51 @@ Some good reading for practical interval implementations: http://wiki.postgresql.org/images/f/f0/Range-types.pdf +Range type initialization +------------------------- + +:: + + + + from sqlalchemy_utils import IntRangeType + + + class Event(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(255)) + estimated_number_of_persons = sa.Column(IntRangeType) + + + +You can also set a step parameter for range type. The values that are not +multipliers of given step will be rounded up to nearest step multiplier. + + +:: + + + from sqlalchemy_utils import IntRangeType + + + class Event(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(255)) + estimated_number_of_persons = sa.Column(IntRangeType(step=1000)) + + + event = Event(estimated_number_of_persons=[100, 1200]) + event.estimated_number_of_persons.lower # 0 + event.estimated_number_of_persons.upper # 1000 + + Range type operators -------------------- -SQLAlchemy-Utils supports many range type operators. These operators follow the `intervals` package interval coercion rules. +SQLAlchemy-Utils supports many range type operators. These operators follow the +`intervals` package interval coercion rules. So for example when we make a query such as: @@ -50,6 +93,13 @@ All range types support all comparison operators (>, >=, ==, !=, <=, <). Car.price_range > (300, 500) + # Whether or not range is strictly left of another range + Car.price_range << [300, 500] + + # Whether or not range is strictly right of another range + Car.price_range << [300, 500] + + Membership operators ^^^^^^^^^^^^^^^^^^^^ @@ -65,6 +115,7 @@ Membership operators ~ Car.price_range.in_([[300, 400], [700, 800]]) + .. _intervals: https://github.com/kvesteri/intervals """ from collections import Iterable @@ -178,6 +229,7 @@ class RangeType(types.TypeDecorator, ScalarCoercible): raise ImproperlyConfigured( 'RangeType needs intervals package installed.' ) + self.step = kwargs.pop('step', None) super(RangeType, self).__init__(*args, **kwargs) def load_dialect_impl(self, dialect): @@ -197,10 +249,10 @@ class RangeType(types.TypeDecorator, ScalarCoercible): if value is not None: if self.interval_class.step is not None: return self.canonicalize_result_value( - self.interval_class(value) + self.interval_class(value, step=self.step) ) else: - return self.interval_class(value) + return self.interval_class(value, step=self.step) return value def canonicalize_result_value(self, value): @@ -209,7 +261,7 @@ class RangeType(types.TypeDecorator, ScalarCoercible): def _coerce(self, value): if value is None: return None - return self.interval_class(value) + return self.interval_class(value, step=self.step) class IntRangeType(RangeType): @@ -263,7 +315,6 @@ class IntRangeType(RangeType): self.interval_class = intervals.IntInterval - class DateRangeType(RangeType): """ DateRangeType provides way for saving ranges of dates into database. On @@ -290,6 +341,24 @@ class DateRangeType(RangeType): class NumericRangeType(RangeType): + """ + NumericRangeType provides way for saving ranges of decimals into database. + On PostgreSQL this type maps to native NUMRANGE type while on other drivers + this maps to simple string column. + + Example:: + + + from sqlalchemy_utils import NumericRangeType + + + class Car(Base): + __tablename__ = 'car' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(255))) + price_range = sa.Column(NumericRangeType) + """ + impl = NUMRANGE def __init__(self, *args, **kwargs): diff --git a/tests/types/test_numeric_range.py b/tests/types/test_numeric_range.py index ee30904..7007adf 100644 --- a/tests/types/test_numeric_range.py +++ b/tests/types/test_numeric_range.py @@ -1,3 +1,6 @@ +from decimal import Decimal + + from pytest import mark import sqlalchemy as sa intervals = None @@ -30,8 +33,8 @@ class NumericRangeTestCase(TestCase): return self.session.query(self.Car).first() def test_nullify_range(self): - building = self.create_car(None) - assert building.price_range == None + car = self.create_car(None) + assert car.price_range is None @mark.parametrize( 'number_range', @@ -82,3 +85,37 @@ class NumericRangeTestCase(TestCase): class TestNumericRangeOnPostgres(NumericRangeTestCase): dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + +@mark.skipif('intervals is None') +class TestNumericRangeWithStep(TestCase): + def create_models(self): + class Car(self.Base): + __tablename__ = 'car' + id = sa.Column(sa.Integer, primary_key=True) + price_range = sa.Column(NumericRangeType(step=Decimal('0.5'))) + + self.Car = Car + + def create_car(self, number_range): + car = self.Car( + price_range=number_range + ) + + self.session.add(car) + self.session.commit() + return self.session.query(self.Car).first() + + def test_passes_step_argument_to_interval_object(self): + car = self.create_car([Decimal('0.2'), Decimal('0.8')]) + assert car.price_range.lower == Decimal('0') + assert car.price_range.upper == Decimal('1') + assert car.price_range.step == Decimal('0.5') + + def test_passes_step_fetched_objects(self): + self.create_car([Decimal('0.2'), Decimal('0.8')]) + self.session.expunge_all() + car = self.session.query(self.Car).first() + assert car.price_range.lower == Decimal('0') + assert car.price_range.upper == Decimal('1') + assert car.price_range.step == Decimal('0.5')