diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 76f1ae4..fcf5860 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -22,6 +22,7 @@ from .functions import ( from .listeners import coercion_listener from .merge import merge, Merger from .generic import generic_relationship +from .primitives import NumberRange, NumberRangeException from .proxy_dict import ProxyDict, proxy_dict from .types import ( ArrowType, @@ -39,8 +40,6 @@ from .types import ( PasswordType, PhoneNumber, PhoneNumberType, - NumberRange, - NumberRangeException, NumberRangeRawType, NumberRangeType, ScalarListType, diff --git a/sqlalchemy_utils/primitives/__init__.py b/sqlalchemy_utils/primitives/__init__.py new file mode 100644 index 0000000..f00f3ac --- /dev/null +++ b/sqlalchemy_utils/primitives/__init__.py @@ -0,0 +1,7 @@ +from .number_range import NumberRange, NumberRangeException + + +__all__ = ( + NumberRange, + NumberRangeException +) diff --git a/sqlalchemy_utils/primitives/number_range.py b/sqlalchemy_utils/primitives/number_range.py new file mode 100644 index 0000000..26e9430 --- /dev/null +++ b/sqlalchemy_utils/primitives/number_range.py @@ -0,0 +1,162 @@ +# -*- coding: utf-8 -*- +try: + from functools import total_ordering +except ImportError: + from total_ordering import total_ordering + +import six + + +class NumberRangeException(Exception): + pass + + +class RangeBoundsException(NumberRangeException): + def __init__(self, min_value, max_value): + self.message = 'Min value %d is bigger than max value %d.' % ( + min_value, + max_value + ) + + +@total_ordering +class NumberRange(object): + def __init__(self, *args): + if len(args) > 2: + raise NumberRangeException( + 'NumberRange takes at most two arguments' + ) + elif len(args) == 2: + lower, upper = args + if lower > upper: + raise RangeBoundsException(lower, upper) + self.lower = lower + self.upper = upper + self.lower_inc = self.upper_inc = True + else: + if isinstance(args[0], six.integer_types): + self.lower = self.upper = args[0] + self.lower_inc = self.upper_inc = True + elif isinstance(args[0], six.string_types): + if ',' not in args[0]: + self.lower, self.upper = self.parse_range(args[0]) + self.lower_inc = self.upper_inc = True + else: + self.from_range_with_bounds(args[0]) + elif hasattr(args[0], 'lower') and hasattr(args[0], 'upper'): + self.lower = args[0].lower + self.upper = args[0].upper + if not args[0].lower_inc: + self.lower += 1 + + if not args[0].upper_inc: + self.upper -= 1 + + def from_range_with_bounds(self, value): + """ + Returns new NumberRange object from normalized number range format. + + Example :: + + range = NumberRange.from_normalized_str('[23, 45]') + range.lower = 23 + range.upper = 45 + + range = NumberRange.from_normalized_str('(23, 45]') + range.lower = 24 + range.upper = 45 + + range = NumberRange.from_normalized_str('(23, 45)') + range.lower = 24 + range.upper = 44 + """ + values = value[1:-1].split(',') + try: + lower, upper = map( + lambda a: int(a.strip()), values + ) + except ValueError as e: + raise NumberRangeException(e.message) + + self.lower_inc = value[0] == '(' + if self.lower_inc: + lower += 1 + + self.upper_inc = value[-1] == ')' + if self.upper_inc: + upper -= 1 + + self.lower = lower + self.upper = upper + + def parse_range(self, value): + if value is not None: + values = value.split('-') + if len(values) == 1: + lower = upper = int(value.strip()) + else: + try: + lower, upper = map( + lambda a: int(a.strip()), values + ) + except ValueError as e: + raise NumberRangeException(str(e)) + return lower, upper + + @property + def normalized(self): + return '[%s, %s]' % (self.lower, self.upper) + + def __eq__(self, other): + try: + return ( + self.lower == other.lower and + self.upper == other.upper + ) + except AttributeError: + return NotImplemented + + def __ne__(self, other): + return not (self == other) + + def __gt__(self, other): + try: + return self.lower > other.lower and self.upper > other.upper + except AttributeError: + return NotImplemented + + def __repr__(self): + return 'NumberRange(%r, %r)' % (self.lower, self.upper) + + def __str__(self): + if self.lower != self.upper: + return '%s - %s' % (self.lower, self.upper) + return str(self.lower) + + def __add__(self, other): + """ + [a, b] + [c, d] = [a + c, b + d] + """ + try: + return NumberRange( + self.lower + other.lower, + self.upper + other.upper + ) + except AttributeError: + return NotImplemented + + def __sub__(self, other): + """ + Defines the substraction operator. + + As defined in wikipedia: + + [a, b] − [c, d] = [a − d, b − c] + """ + try: + return NumberRange( + self.lower - other.upper, + self.upper - other.lower + ) + except AttributeError: + return NotImplemented diff --git a/sqlalchemy_utils/types/__init__.py b/sqlalchemy_utils/types/__init__.py index 1f634b7..f0abb15 100644 --- a/sqlalchemy_utils/types/__init__.py +++ b/sqlalchemy_utils/types/__init__.py @@ -8,8 +8,6 @@ from .email import EmailType from .ip_address import IPAddressType from .locale import LocaleType from .number_range import ( - NumberRange, - NumberRangeException, NumberRangeRawType, NumberRangeType, ) @@ -33,8 +31,6 @@ __all__ = ( EmailType, IPAddressType, LocaleType, - NumberRange, - NumberRangeException, NumberRangeRawType, NumberRangeType, Password, diff --git a/sqlalchemy_utils/types/number_range.py b/sqlalchemy_utils/types/number_range.py index f80837d..801c848 100644 --- a/sqlalchemy_utils/types/number_range.py +++ b/sqlalchemy_utils/types/number_range.py @@ -1,5 +1,6 @@ import six from sqlalchemy import types +from sqlalchemy_utils.primitives import NumberRange from .scalar_coercible import ScalarCoercible @@ -64,155 +65,18 @@ class NumberRangeType(types.TypeDecorator, ScalarCoercible): def process_result_value(self, value, dialect): if value: if not isinstance(value, six.string_types): - value = NumberRange.from_range_object(value) + value = NumberRange(value) else: - return NumberRange.from_normalized_str(value) + return NumberRange(value) return value def _coerce(self, value): if value is not None and not isinstance(value, NumberRange): - if isinstance(value, six.string_types): - value = NumberRange.from_normalized_str(value) - elif isinstance(value, six.integer_types): - value = NumberRange(value, value) + if ( + isinstance(value, six.string_types) or + isinstance(value, six.integer_types) + ): + value = NumberRange(value) else: raise TypeError('Could not coerce value to NumberRange.') return value - - -class NumberRangeException(Exception): - pass - - -class RangeBoundsException(NumberRangeException): - def __init__(self, min_value, max_value): - self.message = 'Min value %d is bigger than max value %d.' % ( - min_value, - max_value - ) - - -class NumberRange(object): - def __init__(self, min_value, max_value): - if min_value > max_value: - raise RangeBoundsException(min_value, max_value) - self.min_value = min_value - self.max_value = max_value - - @classmethod - def from_range_object(cls, value): - min_value = value.lower - max_value = value.upper - if not value.lower_inc: - min_value += 1 - - if not value.upper_inc: - max_value -= 1 - - return cls(min_value, max_value) - - @classmethod - def from_normalized_str(cls, value): - """ - Returns new NumberRange object from normalized number range format. - - Example :: - - range = NumberRange.from_normalized_str('[23, 45]') - range.min_value = 23 - range.max_value = 45 - - range = NumberRange.from_normalized_str('(23, 45]') - range.min_value = 24 - range.max_value = 45 - - range = NumberRange.from_normalized_str('(23, 45)') - range.min_value = 24 - range.max_value = 44 - """ - if value is not None: - values = value[1:-1].split(',') - try: - min_value, max_value = map( - lambda a: int(a.strip()), values - ) - except ValueError as e: - raise NumberRangeException(e.message) - - if value[0] == '(': - min_value += 1 - - if value[-1] == ')': - max_value -= 1 - - return cls(min_value, max_value) - - @classmethod - def from_str(cls, value): - if value is not None: - values = value.split('-') - if len(values) == 1: - min_value = max_value = int(value.strip()) - else: - try: - min_value, max_value = map( - lambda a: int(a.strip()), values - ) - except ValueError as e: - raise NumberRangeException(str(e)) - return cls(min_value, max_value) - - @property - def normalized(self): - return '[%s, %s]' % (self.min_value, self.max_value) - - def __eq__(self, other): - try: - return ( - self.min_value == other.min_value and - self.max_value == other.max_value - ) - except AttributeError: - return NotImplemented - - def __repr__(self): - return 'NumberRange(%r, %r)' % (self.min_value, self.max_value) - - def __str__(self): - if self.min_value != self.max_value: - return '%s - %s' % (self.min_value, self.max_value) - return str(self.min_value) - - def __add__(self, other): - try: - return NumberRange( - self.min_value + other.min_value, - self.max_value + other.max_value - ) - except AttributeError: - return NotImplemented - - def __iadd__(self, other): - try: - self.min_value += other.min_value - self.max_value += other.max_value - return self - except AttributeError: - return NotImplemented - - def __sub__(self, other): - try: - return NumberRange( - self.min_value - other.min_value, - self.max_value - other.max_value - ) - except AttributeError: - return NotImplemented - - def __isub__(self, other): - try: - self.min_value -= other.min_value - self.max_value -= other.max_value - return self - except AttributeError: - return NotImplemented diff --git a/tests/types/test_number_range.py b/tests/types/test_number_range.py index c413fb1..3203b96 100644 --- a/tests/types/test_number_range.py +++ b/tests/types/test_number_range.py @@ -1,10 +1,8 @@ import sqlalchemy as sa -from pytest import raises from tests import TestCase from sqlalchemy_utils import ( NumberRangeType, NumberRange, - NumberRangeException, coercion_listener ) @@ -30,8 +28,8 @@ class TestNumberRangeType(TestCase): self.session.add(building) self.session.commit() building = self.session.query(self.Building).first() - assert building.persons_at_night.min_value == 1 - assert building.persons_at_night.max_value == 3 + assert building.persons_at_night.lower == 1 + assert building.persons_at_night.upper == 3 def test_nullify_number_range(self): building = self.Building( @@ -55,38 +53,5 @@ class TestNumberRangeType(TestCase): def test_integer_coercion(self): building = self.Building(persons_at_night=15) - assert building.persons_at_night.min_value == 15 - assert building.persons_at_night.max_value == 15 - - -class TestNumberRange(object): - def test_equality_operator(self): - assert NumberRange(1, 3) == NumberRange(1, 3) - - def test_str_representation(self): - assert str(NumberRange(1, 3)) == '1 - 3' - assert str(NumberRange(1, 1)) == '1' - - def test_raises_exception_for_badly_constructed_range(self): - with raises(NumberRangeException): - NumberRange(3, 2) - - def test_from_str_supports_single_integers(self): - number_range = NumberRange.from_str('1') - assert number_range.min_value == 1 - assert number_range.max_value == 1 - - def test_from_str_exception_handling(self): - with raises(NumberRangeException): - NumberRange.from_str('1 - ') - - def test_from_normalized_str(self): - assert str(NumberRange.from_normalized_str('[1,2]')) == '1 - 2' - assert str(NumberRange.from_normalized_str('[1,3)')) == '1 - 2' - assert str(NumberRange.from_normalized_str('(1,3)')) == '2' - - def test_add_operator(self): - assert NumberRange(1, 2) + NumberRange(1, 2) == NumberRange(2, 4) - - def test_sub_operator(self): - assert NumberRange(1, 3) - NumberRange(1, 2) == NumberRange(0, 1) + assert building.persons_at_night.lower == 15 + assert building.persons_at_night.upper == 15