Added some NumberRange tests

This commit is contained in:
Konsta Vesterinen
2013-03-26 13:17:29 +02:00
parent 26db1397d5
commit 98de37354b
2 changed files with 63 additions and 1 deletions

View File

@@ -84,7 +84,7 @@ class NumberRangeType(types.TypeDecorator):
impl = NumberRangeRawType impl = NumberRangeRawType
def process_bind_param(self, value, dialect): def process_bind_param(self, value, dialect):
return value return str(value)
def process_result_value(self, value, dialect): def process_result_value(self, value, dialect):
return NumberRange.from_normalized_str(value) return NumberRange.from_normalized_str(value)
@@ -97,6 +97,23 @@ class NumberRange(object):
@classmethod @classmethod
def from_normalized_str(cls, value): def from_normalized_str(cls, value):
"""
Returns new NumberRange object from normalized number range format.
Example ::
range = NumberRange.from_normalized_str('[23, 45]')
range.min_value = 23
range.max_value = 45
range = NumberRange.from_normalized_str('(23, 45]')
range.min_value = 24
range.max_value = 45
range = NumberRange.from_normalized_str('(23, 45)')
range.min_value = 24
range.max_value = 44
"""
if value is not None: if value is not None:
values = value[1:-1].split(',') values = value[1:-1].split(',')
min_value, max_value = map( min_value, max_value = map(
@@ -120,6 +137,15 @@ class NumberRange(object):
) )
return cls(min_value, max_value) return cls(min_value, max_value)
def __eq__(self, other):
try:
return (
self.min_value == other.min_value and
self.max_value == other.max_value
)
except AttributeError:
return NotImplemented
def __repr__(self): def __repr__(self):
return 'NumberRange(%r, %r)' % (self.min_value, self.max_value) return 'NumberRange(%r, %r)' % (self.min_value, self.max_value)

View File

@@ -0,0 +1,36 @@
import sqlalchemy as sa
from sqlalchemy_utils import NumberRangeType, NumberRange
from tests import DatabaseTestCase
class TestNumberRangeType(DatabaseTestCase):
def create_models(self):
class Building(self.Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
persons_at_night = sa.Column(NumberRangeType)
def __repr__(self):
return 'Building(%r)' % self.id
self.Building = Building
def test_save_number_range(self):
building = self.Building(
persons_at_night=NumberRange(1, 3)
)
self.session.add(building)
self.session.commit()
building = self.session.query(self.Building).first()
assert building.persons_at_night.min_value == 1
assert building.persons_at_night.max_value == 3
class TestNumberRange(object):
def test_equality_operator(self):
assert NumberRange(1, 3) == NumberRange(1, 3)
def test_str_representation(self):
assert str(NumberRange(1, 3)) == '[1, 3]'