From 98de37354ba3a1c5c36b608decc9f77996af4580 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Tue, 26 Mar 2013 13:17:29 +0200 Subject: [PATCH] Added some NumberRange tests --- sqlalchemy_utils/types.py | 28 +++++++++++++++++++++++++++- tests/test_number_range.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 tests/test_number_range.py diff --git a/sqlalchemy_utils/types.py b/sqlalchemy_utils/types.py index 0261df2..a475837 100644 --- a/sqlalchemy_utils/types.py +++ b/sqlalchemy_utils/types.py @@ -84,7 +84,7 @@ class NumberRangeType(types.TypeDecorator): impl = NumberRangeRawType def process_bind_param(self, value, dialect): - return value + return str(value) def process_result_value(self, value, dialect): return NumberRange.from_normalized_str(value) @@ -97,6 +97,23 @@ class NumberRange(object): @classmethod def from_normalized_str(cls, value): + """ + Returns new NumberRange object from normalized number range format. + + Example :: + + range = NumberRange.from_normalized_str('[23, 45]') + range.min_value = 23 + range.max_value = 45 + + range = NumberRange.from_normalized_str('(23, 45]') + range.min_value = 24 + range.max_value = 45 + + range = NumberRange.from_normalized_str('(23, 45)') + range.min_value = 24 + range.max_value = 44 + """ if value is not None: values = value[1:-1].split(',') min_value, max_value = map( @@ -120,6 +137,15 @@ class NumberRange(object): ) return cls(min_value, max_value) + def __eq__(self, other): + try: + return ( + self.min_value == other.min_value and + self.max_value == other.max_value + ) + except AttributeError: + return NotImplemented + def __repr__(self): return 'NumberRange(%r, %r)' % (self.min_value, self.max_value) diff --git a/tests/test_number_range.py b/tests/test_number_range.py new file mode 100644 index 0000000..cef0367 --- /dev/null +++ b/tests/test_number_range.py @@ -0,0 +1,36 @@ +import sqlalchemy as sa +from sqlalchemy_utils import NumberRangeType, NumberRange +from tests import DatabaseTestCase + + +class TestNumberRangeType(DatabaseTestCase): + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + persons_at_night = sa.Column(NumberRangeType) + + def __repr__(self): + return 'Building(%r)' % self.id + + self.Building = Building + + def test_save_number_range(self): + building = self.Building( + persons_at_night=NumberRange(1, 3) + ) + + self.session.add(building) + self.session.commit() + + building = self.session.query(self.Building).first() + assert building.persons_at_night.min_value == 1 + assert building.persons_at_night.max_value == 3 + + +class TestNumberRange(object): + def test_equality_operator(self): + assert NumberRange(1, 3) == NumberRange(1, 3) + + def test_str_representation(self): + assert str(NumberRange(1, 3)) == '[1, 3]'