diff --git a/CHANGES.rst b/CHANGES.rst index ee9761e..b823797 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,13 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.8.0 (2013-03-27) +^^^^^^^^^^^^^^^^^^ + +- Added ScalarList type +- Fixed NumberRange bind param and result value processing + + 0.7.7 (2013-03-27) ^^^^^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index 37f6182..dfb1f28 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ class PyTest(Command): setup( name='SQLAlchemy-Utils', - version='0.7.7', + version='0.8.0', url='https://github.com/kvesteri/sqlalchemy-utils', license='BSD', author='Konsta Vesterinen', diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index d4e4ceb..72691ed 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -8,7 +8,9 @@ from .types import ( NumberRange, NumberRangeException, NumberRangeRawType, - NumberRangeType + NumberRangeType, + ScalarList, + ScalarListException, ) @@ -26,4 +28,6 @@ __all__ = ( NumberRangeType, PhoneNumber, PhoneNumberType, + ScalarList, + ScalarListException, ) diff --git a/sqlalchemy_utils/types.py b/sqlalchemy_utils/types.py index e7da978..d504c0d 100644 --- a/sqlalchemy_utils/types.py +++ b/sqlalchemy_utils/types.py @@ -1,5 +1,6 @@ import phonenumbers from functools import wraps +import sqlalchemy as sa from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList from sqlalchemy import types @@ -82,6 +83,42 @@ class PhoneNumberType(types.TypeDecorator): return value +class ScalarListException(Exception): + pass + + +class ScalarList(types.TypeDecorator): + impl = sa.UnicodeText() + + def __init__(self, coerce_func=unicode, separator=u','): + self.separator = unicode(separator) + self.coerce_func = coerce_func + + def process_bind_param(self, value, dialect): + # Convert list of values to unicode separator-separated list + # Example: [1, 2, 3, 4] -> u'1, 2, 3, 4' + if value: + print value + if any(self.separator in unicode(item) for item in value): + raise ScalarListException( + "List values can't contain string '%s' (its being used as " + "separator. If you wish for scalar list values to contain " + "these strings, use a different separator string." + ) + return self.separator.join( + map(unicode, value) + ) + return value + + def process_result_value(self, value, dialect): + if value: + # coerce each value + return map( + self.coerce_func, value.split(self.separator) + ) + return value + + class NumberRangeRawType(types.UserDefinedType): """ Raw number range type, only supports PostgreSQL for now. @@ -96,9 +133,12 @@ class NumberRangeType(types.TypeDecorator): def process_bind_param(self, value, dialect): if value: return value.normalized + return value def process_result_value(self, value, dialect): - return NumberRange.from_normalized_str(value) + if value: + return NumberRange.from_normalized_str(value) + return value class NumberRangeException(Exception): diff --git a/tests/test_number_range.py b/tests/test_number_range.py index c09874e..b1dea40 100644 --- a/tests/test_number_range.py +++ b/tests/test_number_range.py @@ -1,5 +1,5 @@ import sqlalchemy as sa -from pytest import raises, mark +from pytest import raises from sqlalchemy_utils import NumberRangeType, NumberRange, NumberRangeException from tests import DatabaseTestCase diff --git a/tests/test_scalar_list.py b/tests/test_scalar_list.py new file mode 100644 index 0000000..9585888 --- /dev/null +++ b/tests/test_scalar_list.py @@ -0,0 +1,61 @@ +import sqlalchemy as sa +from sqlalchemy_utils import ScalarList +from pytest import raises +from tests import DatabaseTestCase + + +class TestScalarIntegerList(DatabaseTestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + some_list = sa.Column(ScalarList(int)) + + def __repr__(self): + return 'User(%r)' % self.id + + self.User = User + + def test_save_integer_list(self): + user = self.User( + some_list=[1, 2, 3, 4] + ) + + self.session.add(user) + self.session.commit() + + user = self.session.query(self.User).first() + assert user.some_list == [1, 2, 3, 4] + + +class TestScalarUnicodeList(DatabaseTestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + some_list = sa.Column(ScalarList(unicode)) + + def __repr__(self): + return 'User(%r)' % self.id + + self.User = User + + def test_throws_exception_if_using_separator_in_list_values(self): + user = self.User( + some_list=[u','] + ) + + self.session.add(user) + with raises(sa.exc.StatementError): + self.session.commit() + + def test_save_unicode_list(self): + user = self.User( + some_list=[u'1', u'2', u'3', u'4'] + ) + + self.session.add(user) + self.session.commit() + + user = self.session.query(self.User).first() + assert user.some_list == [u'1', u'2', u'3', u'4']