Addad ScalarList type, fixed NumberRangeType

This commit is contained in:
Konsta Vesterinen
2013-04-02 10:13:50 +03:00
parent b6f46d539b
commit 74dbf06146
6 changed files with 116 additions and 4 deletions

View File

@@ -4,6 +4,13 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.8.0 (2013-03-27)
^^^^^^^^^^^^^^^^^^
- Added ScalarList type
- Fixed NumberRange bind param and result value processing
0.7.7 (2013-03-27)
^^^^^^^^^^^^^^^^^^

View File

@@ -24,7 +24,7 @@ class PyTest(Command):
setup(
name='SQLAlchemy-Utils',
version='0.7.7',
version='0.8.0',
url='https://github.com/kvesteri/sqlalchemy-utils',
license='BSD',
author='Konsta Vesterinen',

View File

@@ -8,7 +8,9 @@ from .types import (
NumberRange,
NumberRangeException,
NumberRangeRawType,
NumberRangeType
NumberRangeType,
ScalarList,
ScalarListException,
)
@@ -26,4 +28,6 @@ __all__ = (
NumberRangeType,
PhoneNumber,
PhoneNumberType,
ScalarList,
ScalarListException,
)

View File

@@ -1,5 +1,6 @@
import phonenumbers
from functools import wraps
import sqlalchemy as sa
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
from sqlalchemy import types
@@ -82,6 +83,42 @@ class PhoneNumberType(types.TypeDecorator):
return value
class ScalarListException(Exception):
pass
class ScalarList(types.TypeDecorator):
impl = sa.UnicodeText()
def __init__(self, coerce_func=unicode, separator=u','):
self.separator = unicode(separator)
self.coerce_func = coerce_func
def process_bind_param(self, value, dialect):
# Convert list of values to unicode separator-separated list
# Example: [1, 2, 3, 4] -> u'1, 2, 3, 4'
if value:
print value
if any(self.separator in unicode(item) for item in value):
raise ScalarListException(
"List values can't contain string '%s' (its being used as "
"separator. If you wish for scalar list values to contain "
"these strings, use a different separator string."
)
return self.separator.join(
map(unicode, value)
)
return value
def process_result_value(self, value, dialect):
if value:
# coerce each value
return map(
self.coerce_func, value.split(self.separator)
)
return value
class NumberRangeRawType(types.UserDefinedType):
"""
Raw number range type, only supports PostgreSQL for now.
@@ -96,9 +133,12 @@ class NumberRangeType(types.TypeDecorator):
def process_bind_param(self, value, dialect):
if value:
return value.normalized
return value
def process_result_value(self, value, dialect):
return NumberRange.from_normalized_str(value)
if value:
return NumberRange.from_normalized_str(value)
return value
class NumberRangeException(Exception):

View File

@@ -1,5 +1,5 @@
import sqlalchemy as sa
from pytest import raises, mark
from pytest import raises
from sqlalchemy_utils import NumberRangeType, NumberRange, NumberRangeException
from tests import DatabaseTestCase

61
tests/test_scalar_list.py Normal file
View File

@@ -0,0 +1,61 @@
import sqlalchemy as sa
from sqlalchemy_utils import ScalarList
from pytest import raises
from tests import DatabaseTestCase
class TestScalarIntegerList(DatabaseTestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
some_list = sa.Column(ScalarList(int))
def __repr__(self):
return 'User(%r)' % self.id
self.User = User
def test_save_integer_list(self):
user = self.User(
some_list=[1, 2, 3, 4]
)
self.session.add(user)
self.session.commit()
user = self.session.query(self.User).first()
assert user.some_list == [1, 2, 3, 4]
class TestScalarUnicodeList(DatabaseTestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
some_list = sa.Column(ScalarList(unicode))
def __repr__(self):
return 'User(%r)' % self.id
self.User = User
def test_throws_exception_if_using_separator_in_list_values(self):
user = self.User(
some_list=[u',']
)
self.session.add(user)
with raises(sa.exc.StatementError):
self.session.commit()
def test_save_unicode_list(self):
user = self.User(
some_list=[u'1', u'2', u'3', u'4']
)
self.session.add(user)
self.session.commit()
user = self.session.query(self.User).first()
assert user.some_list == [u'1', u'2', u'3', u'4']