diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index c160851..f18b14e 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -9,6 +9,7 @@ from .types import ( EmailType, instrumented_list, InstrumentedList, + IPAddressType, PhoneNumber, PhoneNumberType, NumberRange, @@ -37,6 +38,7 @@ __all__ = ( ColorType, EmailType, InstrumentedList, + IPAddressType, Merger, NumberRange, NumberRangeException, diff --git a/sqlalchemy_utils/types.py b/sqlalchemy_utils/types.py deleted file mode 100644 index f6efc71..0000000 --- a/sqlalchemy_utils/types.py +++ /dev/null @@ -1,361 +0,0 @@ -import six -import phonenumbers -from colour import Color -from functools import wraps -import sqlalchemy as sa -from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList -from sqlalchemy import types -from .operators import CaseInsensitiveComparator - - -class PhoneNumber(phonenumbers.phonenumber.PhoneNumber): - ''' - Extends a PhoneNumber class from `Python phonenumbers library`_. Adds - different phone number formats to attributes, so they can be easily used - in templates. Phone number validation method is also implemented. - - Takes the raw phone number and country code as params and parses them - into a PhoneNumber object. - - .. _Python phonenumbers library: - https://github.com/daviddrysdale/python-phonenumbers - - :param raw_number: - String representation of the phone number. - :param country_code: - Country code of the phone number. - ''' - def __init__(self, raw_number, country_code=None): - self._phone_number = phonenumbers.parse(raw_number, country_code) - super(PhoneNumber, self).__init__( - country_code=self._phone_number.country_code, - national_number=self._phone_number.national_number, - extension=self._phone_number.extension, - italian_leading_zero=self._phone_number.italian_leading_zero, - raw_input=self._phone_number.raw_input, - country_code_source=self._phone_number.country_code_source, - preferred_domestic_carrier_code= - self._phone_number.preferred_domestic_carrier_code - ) - self.national = phonenumbers.format_number( - self._phone_number, - phonenumbers.PhoneNumberFormat.NATIONAL - ) - self.international = phonenumbers.format_number( - self._phone_number, - phonenumbers.PhoneNumberFormat.INTERNATIONAL - ) - self.e164 = phonenumbers.format_number( - self._phone_number, - phonenumbers.PhoneNumberFormat.E164 - ) - - def is_valid_number(self): - return phonenumbers.is_valid_number(self._phone_number) - - def __unicode__(self): - return self.national - - def __str__(self): - return six.text_type(self.national).encode('utf-8') - - -class PhoneNumberType(types.TypeDecorator): - """ - Changes PhoneNumber objects to a string representation on the way in and - changes them back to PhoneNumber objects on the way out. If E164 is used - as storing format, no country code is needed for parsing the database - value to PhoneNumber object. - """ - STORE_FORMAT = 'e164' - impl = types.Unicode(20) - - def __init__(self, country_code='US', max_length=20, *args, **kwargs): - super(PhoneNumberType, self).__init__(*args, **kwargs) - self.country_code = country_code - self.impl = types.Unicode(max_length) - - def process_bind_param(self, value, dialect): - if value: - return getattr(value, self.STORE_FORMAT) - return value - - def process_result_value(self, value, dialect): - if value: - return PhoneNumber(value, self.country_code) - return value - - def coercion_listener(self, target, value, oldvalue, initiator): - if value is not None and not isinstance(value, PhoneNumber): - value = PhoneNumber(value, country_code=self.country_code) - return value - - -class ColorType(types.TypeDecorator): - """ - Changes Color objects to a string representation on the way in and - changes them back to Color objects on the way out. - """ - STORE_FORMAT = u'hex' - impl = types.Unicode(20) - - def __init__(self, max_length=20, *args, **kwargs): - super(ColorType, self).__init__(*args, **kwargs) - self.impl = types.Unicode(max_length) - - def process_bind_param(self, value, dialect): - if value: - return six.text_type(getattr(value, self.STORE_FORMAT)) - return value - - def process_result_value(self, value, dialect): - if value: - return Color(value) - return value - - def coercion_listener(self, target, value, oldvalue, initiator): - if value is not None and not isinstance(value, Color): - value = Color(value) - return value - - -class ScalarListException(Exception): - pass - - -class ScalarListType(types.TypeDecorator): - impl = sa.UnicodeText() - - def __init__(self, coerce_func=six.text_type, separator=u','): - self.separator = six.text_type(separator) - self.coerce_func = coerce_func - - def process_bind_param(self, value, dialect): - # Convert list of values to unicode separator-separated list - # Example: [1, 2, 3, 4] -> u'1, 2, 3, 4' - if value is not None: - if any(self.separator in six.text_type(item) for item in value): - raise ScalarListException( - "List values can't contain string '%s' (its being used as " - "separator. If you wish for scalar list values to contain " - "these strings, use a different separator string." - ) - return self.separator.join( - map(six.text_type, value) - ) - - def process_result_value(self, value, dialect): - if value is not None: - if value == u'': - return [] - # coerce each value - return list(map( - self.coerce_func, value.split(self.separator) - )) - - -class EmailType(sa.types.TypeDecorator): - impl = sa.Unicode(255) - comparator_factory = CaseInsensitiveComparator - - def process_bind_param(self, value, dialect): - if value is not None: - return value.lower() - return value - - -class TSVectorType(types.UserDefinedType): - """ - Text search vector type for postgresql. - """ - def get_col_spec(self): - return 'tsvector' - - -class NumberRangeRawType(types.UserDefinedType): - """ - Raw number range type, only supports PostgreSQL for now. - """ - def get_col_spec(self): - return 'int4range' - - -class NumberRangeType(types.TypeDecorator): - impl = NumberRangeRawType - - def process_bind_param(self, value, dialect): - if value is not None: - return value.normalized - return value - - def process_result_value(self, value, dialect): - if value: - if not isinstance(value, six.string_types): - value = NumberRange.from_range_object(value) - else: - return NumberRange.from_normalized_str(value) - return value - - def coercion_listener(self, target, value, oldvalue, initiator): - if value is not None and not isinstance(value, NumberRange): - if isinstance(value, six.string_types): - value = NumberRange.from_normalized_str(value) - else: - raise TypeError - return value - - -class NumberRangeException(Exception): - pass - - -class RangeBoundsException(NumberRangeException): - def __init__(self, min_value, max_value): - self.message = 'Min value %d is bigger than max value %d.' % ( - min_value, - max_value - ) - - -class NumberRange(object): - def __init__(self, min_value, max_value): - if min_value > max_value: - raise RangeBoundsException(min_value, max_value) - self.min_value = min_value - self.max_value = max_value - - @classmethod - def from_range_object(cls, value): - min_value = value.lower - max_value = value.upper - if not value.lower_inc: - min_value += 1 - - if not value.upper_inc: - max_value -= 1 - - return cls(min_value, max_value) - - @classmethod - def from_normalized_str(cls, value): - """ - Returns new NumberRange object from normalized number range format. - - Example :: - - range = NumberRange.from_normalized_str('[23, 45]') - range.min_value = 23 - range.max_value = 45 - - range = NumberRange.from_normalized_str('(23, 45]') - range.min_value = 24 - range.max_value = 45 - - range = NumberRange.from_normalized_str('(23, 45)') - range.min_value = 24 - range.max_value = 44 - """ - if value is not None: - values = value[1:-1].split(',') - try: - min_value, max_value = map( - lambda a: int(a.strip()), values - ) - except ValueError as e: - raise NumberRangeException(e.message) - - if value[0] == '(': - min_value += 1 - - if value[-1] == ')': - max_value -= 1 - - return cls(min_value, max_value) - - @classmethod - def from_str(cls, value): - if value is not None: - values = value.split('-') - if len(values) == 1: - min_value = max_value = int(value.strip()) - else: - try: - min_value, max_value = map( - lambda a: int(a.strip()), values - ) - except ValueError as e: - raise NumberRangeException(str(e)) - return cls(min_value, max_value) - - @property - def normalized(self): - return '[%s, %s]' % (self.min_value, self.max_value) - - def __eq__(self, other): - try: - return ( - self.min_value == other.min_value and - self.max_value == other.max_value - ) - except AttributeError: - return NotImplemented - - def __repr__(self): - return 'NumberRange(%r, %r)' % (self.min_value, self.max_value) - - def __str__(self): - if self.min_value != self.max_value: - return '%s - %s' % (self.min_value, self.max_value) - return str(self.min_value) - - def __add__(self, other): - try: - return NumberRange( - self.min_value + other.min_value, - self.max_value + other.max_value - ) - except AttributeError: - return NotImplemented - - def __iadd__(self, other): - try: - self.min_value += other.min_value - self.max_value += other.max_value - return self - except AttributeError: - return NotImplemented - - def __sub__(self, other): - try: - return NumberRange( - self.min_value - other.min_value, - self.max_value - other.max_value - ) - except AttributeError: - return NotImplemented - - def __isub__(self, other): - try: - self.min_value -= other.min_value - self.max_value -= other.max_value - return self - except AttributeError: - return NotImplemented - - -class InstrumentedList(_InstrumentedList): - """Enhanced version of SQLAlchemy InstrumentedList. Provides some - additional functionality.""" - - def any(self, attr): - return any(getattr(item, attr) for item in self) - - def all(self, attr): - return all(getattr(item, attr) for item in self) - - -def instrumented_list(f): - @wraps(f) - def wrapper(*args, **kwargs): - return InstrumentedList([item for item in f(*args, **kwargs)]) - return wrapper diff --git a/sqlalchemy_utils/types/__init__.py b/sqlalchemy_utils/types/__init__.py new file mode 100644 index 0000000..5586b7c --- /dev/null +++ b/sqlalchemy_utils/types/__init__.py @@ -0,0 +1,55 @@ +from functools import wraps +from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList +from sqlalchemy import types +from .color import ColorType +from .email import EmailType +from .ip_address import IPAddressType +from .number_range import ( + NumberRange, + NumberRangeException, + NumberRangeRawType, + NumberRangeType, +) +from .phone_number import PhoneNumber, PhoneNumberType +from .scalar_list import ScalarListException, ScalarListType + + +__all__ = ( + ColorType, + EmailType, + IPAddressType, + NumberRange, + NumberRangeException, + NumberRangeRawType, + NumberRangeType, + PhoneNumber, + PhoneNumberType, + ScalarListException, + ScalarListType, +) + + +class TSVectorType(types.UserDefinedType): + """ + Text search vector type for postgresql. + """ + def get_col_spec(self): + return 'tsvector' + + +class InstrumentedList(_InstrumentedList): + """Enhanced version of SQLAlchemy InstrumentedList. Provides some + additional functionality.""" + + def any(self, attr): + return any(getattr(item, attr) for item in self) + + def all(self, attr): + return all(getattr(item, attr) for item in self) + + +def instrumented_list(f): + @wraps(f) + def wrapper(*args, **kwargs): + return InstrumentedList([item for item in f(*args, **kwargs)]) + return wrapper diff --git a/sqlalchemy_utils/types/color.py b/sqlalchemy_utils/types/color.py new file mode 100644 index 0000000..5e539ea --- /dev/null +++ b/sqlalchemy_utils/types/color.py @@ -0,0 +1,31 @@ +import six +from colour import Color +from sqlalchemy import types + + +class ColorType(types.TypeDecorator): + """ + Changes Color objects to a string representation on the way in and + changes them back to Color objects on the way out. + """ + STORE_FORMAT = u'hex' + impl = types.Unicode(20) + + def __init__(self, max_length=20, *args, **kwargs): + super(ColorType, self).__init__(*args, **kwargs) + self.impl = types.Unicode(max_length) + + def process_bind_param(self, value, dialect): + if value: + return six.text_type(getattr(value, self.STORE_FORMAT)) + return value + + def process_result_value(self, value, dialect): + if value: + return Color(value) + return value + + def coercion_listener(self, target, value, oldvalue, initiator): + if value is not None and not isinstance(value, Color): + value = Color(value) + return value diff --git a/sqlalchemy_utils/types/email.py b/sqlalchemy_utils/types/email.py new file mode 100644 index 0000000..73acda1 --- /dev/null +++ b/sqlalchemy_utils/types/email.py @@ -0,0 +1,12 @@ +import sqlalchemy as sa +from ..operators import CaseInsensitiveComparator + + +class EmailType(sa.types.TypeDecorator): + impl = sa.Unicode(255) + comparator_factory = CaseInsensitiveComparator + + def process_bind_param(self, value, dialect): + if value is not None: + return value.lower() + return value diff --git a/sqlalchemy_utils/types/ip_address.py b/sqlalchemy_utils/types/ip_address.py new file mode 100644 index 0000000..457d2ca --- /dev/null +++ b/sqlalchemy_utils/types/ip_address.py @@ -0,0 +1,34 @@ +import six +import ipaddress +from sqlalchemy import types + + +class IPAddressType(types.TypeDecorator): + """ + Changes Color objects to a string representation on the way in and + changes them back to Color objects on the way out. + """ + impl = types.Unicode(50) + + def __init__(self, max_length=50, *args, **kwargs): + super(IPAddressType, self).__init__(*args, **kwargs) + self.impl = types.Unicode(max_length) + + def process_bind_param(self, value, dialect): + if value: + return six.text_type(value) + return value + + def process_result_value(self, value, dialect): + if value: + return ipaddress.ip_address(value) + return value + + def coercion_listener(self, target, value, oldvalue, initiator): + if ( + value is not None and + not isinstance(value, ipaddress.IPv4Address) and + not isinstance(value, ipaddress.IPv6Address) + ): + value = ipaddress.ip_address(value) + return value diff --git a/sqlalchemy_utils/types/number_range.py b/sqlalchemy_utils/types/number_range.py new file mode 100644 index 0000000..4b0a771 --- /dev/null +++ b/sqlalchemy_utils/types/number_range.py @@ -0,0 +1,173 @@ +import six +from sqlalchemy import types + + +class NumberRangeRawType(types.UserDefinedType): + """ + Raw number range type, only supports PostgreSQL for now. + """ + def get_col_spec(self): + return 'int4range' + + +class NumberRangeType(types.TypeDecorator): + impl = NumberRangeRawType + + def process_bind_param(self, value, dialect): + if value is not None: + return value.normalized + return value + + def process_result_value(self, value, dialect): + if value: + if not isinstance(value, six.string_types): + value = NumberRange.from_range_object(value) + else: + return NumberRange.from_normalized_str(value) + return value + + def coercion_listener(self, target, value, oldvalue, initiator): + if value is not None and not isinstance(value, NumberRange): + if isinstance(value, six.string_types): + value = NumberRange.from_normalized_str(value) + else: + raise TypeError + return value + + +class NumberRangeException(Exception): + pass + + +class RangeBoundsException(NumberRangeException): + def __init__(self, min_value, max_value): + self.message = 'Min value %d is bigger than max value %d.' % ( + min_value, + max_value + ) + + +class NumberRange(object): + def __init__(self, min_value, max_value): + if min_value > max_value: + raise RangeBoundsException(min_value, max_value) + self.min_value = min_value + self.max_value = max_value + + @classmethod + def from_range_object(cls, value): + min_value = value.lower + max_value = value.upper + if not value.lower_inc: + min_value += 1 + + if not value.upper_inc: + max_value -= 1 + + return cls(min_value, max_value) + + @classmethod + def from_normalized_str(cls, value): + """ + Returns new NumberRange object from normalized number range format. + + Example :: + + range = NumberRange.from_normalized_str('[23, 45]') + range.min_value = 23 + range.max_value = 45 + + range = NumberRange.from_normalized_str('(23, 45]') + range.min_value = 24 + range.max_value = 45 + + range = NumberRange.from_normalized_str('(23, 45)') + range.min_value = 24 + range.max_value = 44 + """ + if value is not None: + values = value[1:-1].split(',') + try: + min_value, max_value = map( + lambda a: int(a.strip()), values + ) + except ValueError as e: + raise NumberRangeException(e.message) + + if value[0] == '(': + min_value += 1 + + if value[-1] == ')': + max_value -= 1 + + return cls(min_value, max_value) + + @classmethod + def from_str(cls, value): + if value is not None: + values = value.split('-') + if len(values) == 1: + min_value = max_value = int(value.strip()) + else: + try: + min_value, max_value = map( + lambda a: int(a.strip()), values + ) + except ValueError as e: + raise NumberRangeException(str(e)) + return cls(min_value, max_value) + + @property + def normalized(self): + return '[%s, %s]' % (self.min_value, self.max_value) + + def __eq__(self, other): + try: + return ( + self.min_value == other.min_value and + self.max_value == other.max_value + ) + except AttributeError: + return NotImplemented + + def __repr__(self): + return 'NumberRange(%r, %r)' % (self.min_value, self.max_value) + + def __str__(self): + if self.min_value != self.max_value: + return '%s - %s' % (self.min_value, self.max_value) + return str(self.min_value) + + def __add__(self, other): + try: + return NumberRange( + self.min_value + other.min_value, + self.max_value + other.max_value + ) + except AttributeError: + return NotImplemented + + def __iadd__(self, other): + try: + self.min_value += other.min_value + self.max_value += other.max_value + return self + except AttributeError: + return NotImplemented + + def __sub__(self, other): + try: + return NumberRange( + self.min_value - other.min_value, + self.max_value - other.max_value + ) + except AttributeError: + return NotImplemented + + def __isub__(self, other): + try: + self.min_value -= other.min_value + self.max_value -= other.max_value + return self + except AttributeError: + return NotImplemented diff --git a/sqlalchemy_utils/types/phone_number.py b/sqlalchemy_utils/types/phone_number.py new file mode 100644 index 0000000..425d058 --- /dev/null +++ b/sqlalchemy_utils/types/phone_number.py @@ -0,0 +1,86 @@ +import six +import phonenumbers +from sqlalchemy import types + + +class PhoneNumber(phonenumbers.phonenumber.PhoneNumber): + ''' + Extends a PhoneNumber class from `Python phonenumbers library`_. Adds + different phone number formats to attributes, so they can be easily used + in templates. Phone number validation method is also implemented. + + Takes the raw phone number and country code as params and parses them + into a PhoneNumber object. + + .. _Python phonenumbers library: + https://github.com/daviddrysdale/python-phonenumbers + + :param raw_number: + String representation of the phone number. + :param country_code: + Country code of the phone number. + ''' + def __init__(self, raw_number, country_code=None): + self._phone_number = phonenumbers.parse(raw_number, country_code) + super(PhoneNumber, self).__init__( + country_code=self._phone_number.country_code, + national_number=self._phone_number.national_number, + extension=self._phone_number.extension, + italian_leading_zero=self._phone_number.italian_leading_zero, + raw_input=self._phone_number.raw_input, + country_code_source=self._phone_number.country_code_source, + preferred_domestic_carrier_code= + self._phone_number.preferred_domestic_carrier_code + ) + self.national = phonenumbers.format_number( + self._phone_number, + phonenumbers.PhoneNumberFormat.NATIONAL + ) + self.international = phonenumbers.format_number( + self._phone_number, + phonenumbers.PhoneNumberFormat.INTERNATIONAL + ) + self.e164 = phonenumbers.format_number( + self._phone_number, + phonenumbers.PhoneNumberFormat.E164 + ) + + def is_valid_number(self): + return phonenumbers.is_valid_number(self._phone_number) + + def __unicode__(self): + return self.national + + def __str__(self): + return six.text_type(self.national).encode('utf-8') + + +class PhoneNumberType(types.TypeDecorator): + """ + Changes PhoneNumber objects to a string representation on the way in and + changes them back to PhoneNumber objects on the way out. If E164 is used + as storing format, no country code is needed for parsing the database + value to PhoneNumber object. + """ + STORE_FORMAT = 'e164' + impl = types.Unicode(20) + + def __init__(self, country_code='US', max_length=20, *args, **kwargs): + super(PhoneNumberType, self).__init__(*args, **kwargs) + self.country_code = country_code + self.impl = types.Unicode(max_length) + + def process_bind_param(self, value, dialect): + if value: + return getattr(value, self.STORE_FORMAT) + return value + + def process_result_value(self, value, dialect): + if value: + return PhoneNumber(value, self.country_code) + return value + + def coercion_listener(self, target, value, oldvalue, initiator): + if value is not None and not isinstance(value, PhoneNumber): + value = PhoneNumber(value, country_code=self.country_code) + return value diff --git a/sqlalchemy_utils/types/scalar_list.py b/sqlalchemy_utils/types/scalar_list.py new file mode 100644 index 0000000..b4f4a6a --- /dev/null +++ b/sqlalchemy_utils/types/scalar_list.py @@ -0,0 +1,38 @@ +import six +import sqlalchemy as sa +from sqlalchemy import types + + +class ScalarListException(Exception): + pass + + +class ScalarListType(types.TypeDecorator): + impl = sa.UnicodeText() + + def __init__(self, coerce_func=six.text_type, separator=u','): + self.separator = six.text_type(separator) + self.coerce_func = coerce_func + + def process_bind_param(self, value, dialect): + # Convert list of values to unicode separator-separated list + # Example: [1, 2, 3, 4] -> u'1, 2, 3, 4' + if value is not None: + if any(self.separator in six.text_type(item) for item in value): + raise ScalarListException( + "List values can't contain string '%s' (its being used as " + "separator. If you wish for scalar list values to contain " + "these strings, use a different separator string.)" + ) + return self.separator.join( + map(six.text_type, value) + ) + + def process_result_value(self, value, dialect): + if value is not None: + if value == u'': + return [] + # coerce each value + return list(map( + self.coerce_func, value.split(self.separator) + )) diff --git a/sqlalchemy_utils/types/slug_type.py b/sqlalchemy_utils/types/slug_type.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/__init__.py b/tests/__init__.py index 1d62436..3a117ec 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,4 @@ +import warnings import sqlalchemy as sa from sqlalchemy import create_engine @@ -18,12 +19,17 @@ def count_sql_calls(conn, cursor, statement, parameters, context, executemany): conn.query_count = 0 +warnings.simplefilter('error', sa.exc.SAWarning) + + class TestCase(object): + dns = 'sqlite:///:memory:' + def setup_method(self, method): - self.engine = create_engine('sqlite:///:memory:') + self.engine = create_engine(self.dns) self.connection = self.engine.connect() self.Base = declarative_base() - self.Base2 = declarative_base() + self.create_models() self.Base.metadata.create_all(self.connection) diff --git a/tests/test_ip_address.py b/tests/test_ip_address.py new file mode 100644 index 0000000..4b03496 --- /dev/null +++ b/tests/test_ip_address.py @@ -0,0 +1,29 @@ +import ipaddress +import six +import sqlalchemy as sa +from sqlalchemy_utils import IPAddressType +from tests import TestCase + + +class TestIPAddressType(TestCase): + def create_models(self): + class Visitor(self.Base): + __tablename__ = 'document' + id = sa.Column(sa.Integer, primary_key=True) + ip_address = sa.Column(IPAddressType) + + def __repr__(self): + return 'Visitor(%r)' % self.id + + self.Visitor = Visitor + + def test_parameter_processing(self): + visitor = self.Visitor( + ip_address=ipaddress.ip_address(u'111.111.111.111') + ) + + self.session.add(visitor) + self.session.commit() + + visitor = self.session.query(self.Visitor).first() + assert six.text_type(visitor.ip_address) == u'111.111.111.111' diff --git a/tests/test_utility_functions.py b/tests/test_utility_functions.py index 11e5d9f..17fb538 100644 --- a/tests/test_utility_functions.py +++ b/tests/test_utility_functions.py @@ -17,6 +17,8 @@ class TestDeferExcept(TestCase): class TestFindNonIndexedForeignKeys(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + def create_models(self): class User(self.Base): __tablename__ = 'user'