Added IPAddressType, refactored types
This commit is contained in:
		@@ -9,6 +9,7 @@ from .types import (
 | 
			
		||||
    EmailType,
 | 
			
		||||
    instrumented_list,
 | 
			
		||||
    InstrumentedList,
 | 
			
		||||
    IPAddressType,
 | 
			
		||||
    PhoneNumber,
 | 
			
		||||
    PhoneNumberType,
 | 
			
		||||
    NumberRange,
 | 
			
		||||
@@ -37,6 +38,7 @@ __all__ = (
 | 
			
		||||
    ColorType,
 | 
			
		||||
    EmailType,
 | 
			
		||||
    InstrumentedList,
 | 
			
		||||
    IPAddressType,
 | 
			
		||||
    Merger,
 | 
			
		||||
    NumberRange,
 | 
			
		||||
    NumberRangeException,
 | 
			
		||||
 
 | 
			
		||||
@@ -1,361 +0,0 @@
 | 
			
		||||
import six
 | 
			
		||||
import phonenumbers
 | 
			
		||||
from colour import Color
 | 
			
		||||
from functools import wraps
 | 
			
		||||
import sqlalchemy as sa
 | 
			
		||||
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
 | 
			
		||||
from sqlalchemy import types
 | 
			
		||||
from .operators import CaseInsensitiveComparator
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PhoneNumber(phonenumbers.phonenumber.PhoneNumber):
 | 
			
		||||
    '''
 | 
			
		||||
    Extends a PhoneNumber class from `Python phonenumbers library`_. Adds
 | 
			
		||||
    different phone number formats to attributes, so they can be easily used
 | 
			
		||||
    in templates. Phone number validation method is also implemented.
 | 
			
		||||
 | 
			
		||||
    Takes the raw phone number and country code as params and parses them
 | 
			
		||||
    into a PhoneNumber object.
 | 
			
		||||
 | 
			
		||||
    .. _Python phonenumbers library:
 | 
			
		||||
       https://github.com/daviddrysdale/python-phonenumbers
 | 
			
		||||
 | 
			
		||||
    :param raw_number:
 | 
			
		||||
        String representation of the phone number.
 | 
			
		||||
    :param country_code:
 | 
			
		||||
        Country code of the phone number.
 | 
			
		||||
    '''
 | 
			
		||||
    def __init__(self, raw_number, country_code=None):
 | 
			
		||||
        self._phone_number = phonenumbers.parse(raw_number, country_code)
 | 
			
		||||
        super(PhoneNumber, self).__init__(
 | 
			
		||||
            country_code=self._phone_number.country_code,
 | 
			
		||||
            national_number=self._phone_number.national_number,
 | 
			
		||||
            extension=self._phone_number.extension,
 | 
			
		||||
            italian_leading_zero=self._phone_number.italian_leading_zero,
 | 
			
		||||
            raw_input=self._phone_number.raw_input,
 | 
			
		||||
            country_code_source=self._phone_number.country_code_source,
 | 
			
		||||
            preferred_domestic_carrier_code=
 | 
			
		||||
            self._phone_number.preferred_domestic_carrier_code
 | 
			
		||||
        )
 | 
			
		||||
        self.national = phonenumbers.format_number(
 | 
			
		||||
            self._phone_number,
 | 
			
		||||
            phonenumbers.PhoneNumberFormat.NATIONAL
 | 
			
		||||
        )
 | 
			
		||||
        self.international = phonenumbers.format_number(
 | 
			
		||||
            self._phone_number,
 | 
			
		||||
            phonenumbers.PhoneNumberFormat.INTERNATIONAL
 | 
			
		||||
        )
 | 
			
		||||
        self.e164 = phonenumbers.format_number(
 | 
			
		||||
            self._phone_number,
 | 
			
		||||
            phonenumbers.PhoneNumberFormat.E164
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def is_valid_number(self):
 | 
			
		||||
        return phonenumbers.is_valid_number(self._phone_number)
 | 
			
		||||
 | 
			
		||||
    def __unicode__(self):
 | 
			
		||||
        return self.national
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return six.text_type(self.national).encode('utf-8')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PhoneNumberType(types.TypeDecorator):
 | 
			
		||||
    """
 | 
			
		||||
    Changes PhoneNumber objects to a string representation on the way in and
 | 
			
		||||
    changes them back to PhoneNumber objects on the way out. If E164 is used
 | 
			
		||||
    as storing format, no country code is needed for parsing the database
 | 
			
		||||
    value to PhoneNumber object.
 | 
			
		||||
    """
 | 
			
		||||
    STORE_FORMAT = 'e164'
 | 
			
		||||
    impl = types.Unicode(20)
 | 
			
		||||
 | 
			
		||||
    def __init__(self, country_code='US', max_length=20, *args, **kwargs):
 | 
			
		||||
        super(PhoneNumberType, self).__init__(*args, **kwargs)
 | 
			
		||||
        self.country_code = country_code
 | 
			
		||||
        self.impl = types.Unicode(max_length)
 | 
			
		||||
 | 
			
		||||
    def process_bind_param(self, value, dialect):
 | 
			
		||||
        if value:
 | 
			
		||||
            return getattr(value, self.STORE_FORMAT)
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def process_result_value(self, value, dialect):
 | 
			
		||||
        if value:
 | 
			
		||||
            return PhoneNumber(value, self.country_code)
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def coercion_listener(self, target, value, oldvalue, initiator):
 | 
			
		||||
        if value is not None and not isinstance(value, PhoneNumber):
 | 
			
		||||
            value = PhoneNumber(value, country_code=self.country_code)
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ColorType(types.TypeDecorator):
 | 
			
		||||
    """
 | 
			
		||||
    Changes Color objects to a string representation on the way in and
 | 
			
		||||
    changes them back to Color objects on the way out.
 | 
			
		||||
    """
 | 
			
		||||
    STORE_FORMAT = u'hex'
 | 
			
		||||
    impl = types.Unicode(20)
 | 
			
		||||
 | 
			
		||||
    def __init__(self, max_length=20, *args, **kwargs):
 | 
			
		||||
        super(ColorType, self).__init__(*args, **kwargs)
 | 
			
		||||
        self.impl = types.Unicode(max_length)
 | 
			
		||||
 | 
			
		||||
    def process_bind_param(self, value, dialect):
 | 
			
		||||
        if value:
 | 
			
		||||
            return six.text_type(getattr(value, self.STORE_FORMAT))
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def process_result_value(self, value, dialect):
 | 
			
		||||
        if value:
 | 
			
		||||
            return Color(value)
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def coercion_listener(self, target, value, oldvalue, initiator):
 | 
			
		||||
        if value is not None and not isinstance(value, Color):
 | 
			
		||||
            value = Color(value)
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ScalarListException(Exception):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ScalarListType(types.TypeDecorator):
 | 
			
		||||
    impl = sa.UnicodeText()
 | 
			
		||||
 | 
			
		||||
    def __init__(self, coerce_func=six.text_type, separator=u','):
 | 
			
		||||
        self.separator = six.text_type(separator)
 | 
			
		||||
        self.coerce_func = coerce_func
 | 
			
		||||
 | 
			
		||||
    def process_bind_param(self, value, dialect):
 | 
			
		||||
        # Convert list of values to unicode separator-separated list
 | 
			
		||||
        # Example: [1, 2, 3, 4] -> u'1, 2, 3, 4'
 | 
			
		||||
        if value is not None:
 | 
			
		||||
            if any(self.separator in six.text_type(item) for item in value):
 | 
			
		||||
                raise ScalarListException(
 | 
			
		||||
                    "List values can't contain string '%s' (its being used as "
 | 
			
		||||
                    "separator. If you wish for scalar list values to contain "
 | 
			
		||||
                    "these strings, use a different separator string."
 | 
			
		||||
                )
 | 
			
		||||
            return self.separator.join(
 | 
			
		||||
                map(six.text_type, value)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def process_result_value(self, value, dialect):
 | 
			
		||||
        if value is not None:
 | 
			
		||||
            if value == u'':
 | 
			
		||||
                return []
 | 
			
		||||
            # coerce each value
 | 
			
		||||
            return list(map(
 | 
			
		||||
                self.coerce_func, value.split(self.separator)
 | 
			
		||||
            ))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EmailType(sa.types.TypeDecorator):
 | 
			
		||||
    impl = sa.Unicode(255)
 | 
			
		||||
    comparator_factory = CaseInsensitiveComparator
 | 
			
		||||
 | 
			
		||||
    def process_bind_param(self, value, dialect):
 | 
			
		||||
        if value is not None:
 | 
			
		||||
            return value.lower()
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TSVectorType(types.UserDefinedType):
 | 
			
		||||
    """
 | 
			
		||||
    Text search vector type for postgresql.
 | 
			
		||||
    """
 | 
			
		||||
    def get_col_spec(self):
 | 
			
		||||
        return 'tsvector'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NumberRangeRawType(types.UserDefinedType):
 | 
			
		||||
    """
 | 
			
		||||
    Raw number range type, only supports PostgreSQL for now.
 | 
			
		||||
    """
 | 
			
		||||
    def get_col_spec(self):
 | 
			
		||||
        return 'int4range'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NumberRangeType(types.TypeDecorator):
 | 
			
		||||
    impl = NumberRangeRawType
 | 
			
		||||
 | 
			
		||||
    def process_bind_param(self, value, dialect):
 | 
			
		||||
        if value is not None:
 | 
			
		||||
            return value.normalized
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def process_result_value(self, value, dialect):
 | 
			
		||||
        if value:
 | 
			
		||||
            if not isinstance(value, six.string_types):
 | 
			
		||||
                value = NumberRange.from_range_object(value)
 | 
			
		||||
            else:
 | 
			
		||||
                return NumberRange.from_normalized_str(value)
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def coercion_listener(self, target, value, oldvalue, initiator):
 | 
			
		||||
        if value is not None and not isinstance(value, NumberRange):
 | 
			
		||||
            if isinstance(value, six.string_types):
 | 
			
		||||
                value = NumberRange.from_normalized_str(value)
 | 
			
		||||
            else:
 | 
			
		||||
                raise TypeError
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NumberRangeException(Exception):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RangeBoundsException(NumberRangeException):
 | 
			
		||||
    def __init__(self, min_value, max_value):
 | 
			
		||||
        self.message = 'Min value %d is bigger than max value %d.' % (
 | 
			
		||||
            min_value,
 | 
			
		||||
            max_value
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NumberRange(object):
 | 
			
		||||
    def __init__(self, min_value, max_value):
 | 
			
		||||
        if min_value > max_value:
 | 
			
		||||
            raise RangeBoundsException(min_value, max_value)
 | 
			
		||||
        self.min_value = min_value
 | 
			
		||||
        self.max_value = max_value
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_range_object(cls, value):
 | 
			
		||||
        min_value = value.lower
 | 
			
		||||
        max_value = value.upper
 | 
			
		||||
        if not value.lower_inc:
 | 
			
		||||
            min_value += 1
 | 
			
		||||
 | 
			
		||||
        if not value.upper_inc:
 | 
			
		||||
            max_value -= 1
 | 
			
		||||
 | 
			
		||||
        return cls(min_value, max_value)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_normalized_str(cls, value):
 | 
			
		||||
        """
 | 
			
		||||
        Returns new NumberRange object from normalized number range format.
 | 
			
		||||
 | 
			
		||||
        Example ::
 | 
			
		||||
 | 
			
		||||
            range = NumberRange.from_normalized_str('[23, 45]')
 | 
			
		||||
            range.min_value = 23
 | 
			
		||||
            range.max_value = 45
 | 
			
		||||
 | 
			
		||||
            range = NumberRange.from_normalized_str('(23, 45]')
 | 
			
		||||
            range.min_value = 24
 | 
			
		||||
            range.max_value = 45
 | 
			
		||||
 | 
			
		||||
            range = NumberRange.from_normalized_str('(23, 45)')
 | 
			
		||||
            range.min_value = 24
 | 
			
		||||
            range.max_value = 44
 | 
			
		||||
        """
 | 
			
		||||
        if value is not None:
 | 
			
		||||
            values = value[1:-1].split(',')
 | 
			
		||||
            try:
 | 
			
		||||
                min_value, max_value = map(
 | 
			
		||||
                    lambda a: int(a.strip()), values
 | 
			
		||||
                )
 | 
			
		||||
            except ValueError as e:
 | 
			
		||||
                raise NumberRangeException(e.message)
 | 
			
		||||
 | 
			
		||||
            if value[0] == '(':
 | 
			
		||||
                min_value += 1
 | 
			
		||||
 | 
			
		||||
            if value[-1] == ')':
 | 
			
		||||
                max_value -= 1
 | 
			
		||||
 | 
			
		||||
            return cls(min_value, max_value)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_str(cls, value):
 | 
			
		||||
        if value is not None:
 | 
			
		||||
            values = value.split('-')
 | 
			
		||||
            if len(values) == 1:
 | 
			
		||||
                min_value = max_value = int(value.strip())
 | 
			
		||||
            else:
 | 
			
		||||
                try:
 | 
			
		||||
                    min_value, max_value = map(
 | 
			
		||||
                        lambda a: int(a.strip()), values
 | 
			
		||||
                    )
 | 
			
		||||
                except ValueError as e:
 | 
			
		||||
                    raise NumberRangeException(str(e))
 | 
			
		||||
            return cls(min_value, max_value)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def normalized(self):
 | 
			
		||||
        return '[%s, %s]' % (self.min_value, self.max_value)
 | 
			
		||||
 | 
			
		||||
    def __eq__(self, other):
 | 
			
		||||
        try:
 | 
			
		||||
            return (
 | 
			
		||||
                self.min_value == other.min_value and
 | 
			
		||||
                self.max_value == other.max_value
 | 
			
		||||
            )
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            return NotImplemented
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return 'NumberRange(%r, %r)' % (self.min_value, self.max_value)
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        if self.min_value != self.max_value:
 | 
			
		||||
            return '%s - %s' % (self.min_value, self.max_value)
 | 
			
		||||
        return str(self.min_value)
 | 
			
		||||
 | 
			
		||||
    def __add__(self, other):
 | 
			
		||||
        try:
 | 
			
		||||
            return NumberRange(
 | 
			
		||||
                self.min_value + other.min_value,
 | 
			
		||||
                self.max_value + other.max_value
 | 
			
		||||
            )
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            return NotImplemented
 | 
			
		||||
 | 
			
		||||
    def __iadd__(self, other):
 | 
			
		||||
        try:
 | 
			
		||||
            self.min_value += other.min_value
 | 
			
		||||
            self.max_value += other.max_value
 | 
			
		||||
            return self
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            return NotImplemented
 | 
			
		||||
 | 
			
		||||
    def __sub__(self, other):
 | 
			
		||||
        try:
 | 
			
		||||
            return NumberRange(
 | 
			
		||||
                self.min_value - other.min_value,
 | 
			
		||||
                self.max_value - other.max_value
 | 
			
		||||
            )
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            return NotImplemented
 | 
			
		||||
 | 
			
		||||
    def __isub__(self, other):
 | 
			
		||||
        try:
 | 
			
		||||
            self.min_value -= other.min_value
 | 
			
		||||
            self.max_value -= other.max_value
 | 
			
		||||
            return self
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            return NotImplemented
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InstrumentedList(_InstrumentedList):
 | 
			
		||||
    """Enhanced version of SQLAlchemy InstrumentedList. Provides some
 | 
			
		||||
    additional functionality."""
 | 
			
		||||
 | 
			
		||||
    def any(self, attr):
 | 
			
		||||
        return any(getattr(item, attr) for item in self)
 | 
			
		||||
 | 
			
		||||
    def all(self, attr):
 | 
			
		||||
        return all(getattr(item, attr) for item in self)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def instrumented_list(f):
 | 
			
		||||
    @wraps(f)
 | 
			
		||||
    def wrapper(*args, **kwargs):
 | 
			
		||||
        return InstrumentedList([item for item in f(*args, **kwargs)])
 | 
			
		||||
    return wrapper
 | 
			
		||||
							
								
								
									
										55
									
								
								sqlalchemy_utils/types/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								sqlalchemy_utils/types/__init__.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,55 @@
 | 
			
		||||
from functools import wraps
 | 
			
		||||
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
 | 
			
		||||
from sqlalchemy import types
 | 
			
		||||
from .color import ColorType
 | 
			
		||||
from .email import EmailType
 | 
			
		||||
from .ip_address import IPAddressType
 | 
			
		||||
from .number_range import (
 | 
			
		||||
    NumberRange,
 | 
			
		||||
    NumberRangeException,
 | 
			
		||||
    NumberRangeRawType,
 | 
			
		||||
    NumberRangeType,
 | 
			
		||||
)
 | 
			
		||||
from .phone_number import PhoneNumber, PhoneNumberType
 | 
			
		||||
from .scalar_list import ScalarListException, ScalarListType
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__all__ = (
 | 
			
		||||
    ColorType,
 | 
			
		||||
    EmailType,
 | 
			
		||||
    IPAddressType,
 | 
			
		||||
    NumberRange,
 | 
			
		||||
    NumberRangeException,
 | 
			
		||||
    NumberRangeRawType,
 | 
			
		||||
    NumberRangeType,
 | 
			
		||||
    PhoneNumber,
 | 
			
		||||
    PhoneNumberType,
 | 
			
		||||
    ScalarListException,
 | 
			
		||||
    ScalarListType,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TSVectorType(types.UserDefinedType):
 | 
			
		||||
    """
 | 
			
		||||
    Text search vector type for postgresql.
 | 
			
		||||
    """
 | 
			
		||||
    def get_col_spec(self):
 | 
			
		||||
        return 'tsvector'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InstrumentedList(_InstrumentedList):
 | 
			
		||||
    """Enhanced version of SQLAlchemy InstrumentedList. Provides some
 | 
			
		||||
    additional functionality."""
 | 
			
		||||
 | 
			
		||||
    def any(self, attr):
 | 
			
		||||
        return any(getattr(item, attr) for item in self)
 | 
			
		||||
 | 
			
		||||
    def all(self, attr):
 | 
			
		||||
        return all(getattr(item, attr) for item in self)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def instrumented_list(f):
 | 
			
		||||
    @wraps(f)
 | 
			
		||||
    def wrapper(*args, **kwargs):
 | 
			
		||||
        return InstrumentedList([item for item in f(*args, **kwargs)])
 | 
			
		||||
    return wrapper
 | 
			
		||||
							
								
								
									
										31
									
								
								sqlalchemy_utils/types/color.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								sqlalchemy_utils/types/color.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,31 @@
 | 
			
		||||
import six
 | 
			
		||||
from colour import Color
 | 
			
		||||
from sqlalchemy import types
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ColorType(types.TypeDecorator):
 | 
			
		||||
    """
 | 
			
		||||
    Changes Color objects to a string representation on the way in and
 | 
			
		||||
    changes them back to Color objects on the way out.
 | 
			
		||||
    """
 | 
			
		||||
    STORE_FORMAT = u'hex'
 | 
			
		||||
    impl = types.Unicode(20)
 | 
			
		||||
 | 
			
		||||
    def __init__(self, max_length=20, *args, **kwargs):
 | 
			
		||||
        super(ColorType, self).__init__(*args, **kwargs)
 | 
			
		||||
        self.impl = types.Unicode(max_length)
 | 
			
		||||
 | 
			
		||||
    def process_bind_param(self, value, dialect):
 | 
			
		||||
        if value:
 | 
			
		||||
            return six.text_type(getattr(value, self.STORE_FORMAT))
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def process_result_value(self, value, dialect):
 | 
			
		||||
        if value:
 | 
			
		||||
            return Color(value)
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def coercion_listener(self, target, value, oldvalue, initiator):
 | 
			
		||||
        if value is not None and not isinstance(value, Color):
 | 
			
		||||
            value = Color(value)
 | 
			
		||||
        return value
 | 
			
		||||
							
								
								
									
										12
									
								
								sqlalchemy_utils/types/email.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								sqlalchemy_utils/types/email.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,12 @@
 | 
			
		||||
import sqlalchemy as sa
 | 
			
		||||
from ..operators import CaseInsensitiveComparator
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EmailType(sa.types.TypeDecorator):
 | 
			
		||||
    impl = sa.Unicode(255)
 | 
			
		||||
    comparator_factory = CaseInsensitiveComparator
 | 
			
		||||
 | 
			
		||||
    def process_bind_param(self, value, dialect):
 | 
			
		||||
        if value is not None:
 | 
			
		||||
            return value.lower()
 | 
			
		||||
        return value
 | 
			
		||||
							
								
								
									
										34
									
								
								sqlalchemy_utils/types/ip_address.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								sqlalchemy_utils/types/ip_address.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,34 @@
 | 
			
		||||
import six
 | 
			
		||||
import ipaddress
 | 
			
		||||
from sqlalchemy import types
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class IPAddressType(types.TypeDecorator):
 | 
			
		||||
    """
 | 
			
		||||
    Changes Color objects to a string representation on the way in and
 | 
			
		||||
    changes them back to Color objects on the way out.
 | 
			
		||||
    """
 | 
			
		||||
    impl = types.Unicode(50)
 | 
			
		||||
 | 
			
		||||
    def __init__(self, max_length=50, *args, **kwargs):
 | 
			
		||||
        super(IPAddressType, self).__init__(*args, **kwargs)
 | 
			
		||||
        self.impl = types.Unicode(max_length)
 | 
			
		||||
 | 
			
		||||
    def process_bind_param(self, value, dialect):
 | 
			
		||||
        if value:
 | 
			
		||||
            return six.text_type(value)
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def process_result_value(self, value, dialect):
 | 
			
		||||
        if value:
 | 
			
		||||
            return ipaddress.ip_address(value)
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def coercion_listener(self, target, value, oldvalue, initiator):
 | 
			
		||||
        if (
 | 
			
		||||
            value is not None and
 | 
			
		||||
            not isinstance(value, ipaddress.IPv4Address) and
 | 
			
		||||
            not isinstance(value, ipaddress.IPv6Address)
 | 
			
		||||
        ):
 | 
			
		||||
            value = ipaddress.ip_address(value)
 | 
			
		||||
        return value
 | 
			
		||||
							
								
								
									
										173
									
								
								sqlalchemy_utils/types/number_range.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										173
									
								
								sqlalchemy_utils/types/number_range.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,173 @@
 | 
			
		||||
import six
 | 
			
		||||
from sqlalchemy import types
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NumberRangeRawType(types.UserDefinedType):
 | 
			
		||||
    """
 | 
			
		||||
    Raw number range type, only supports PostgreSQL for now.
 | 
			
		||||
    """
 | 
			
		||||
    def get_col_spec(self):
 | 
			
		||||
        return 'int4range'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NumberRangeType(types.TypeDecorator):
 | 
			
		||||
    impl = NumberRangeRawType
 | 
			
		||||
 | 
			
		||||
    def process_bind_param(self, value, dialect):
 | 
			
		||||
        if value is not None:
 | 
			
		||||
            return value.normalized
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def process_result_value(self, value, dialect):
 | 
			
		||||
        if value:
 | 
			
		||||
            if not isinstance(value, six.string_types):
 | 
			
		||||
                value = NumberRange.from_range_object(value)
 | 
			
		||||
            else:
 | 
			
		||||
                return NumberRange.from_normalized_str(value)
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def coercion_listener(self, target, value, oldvalue, initiator):
 | 
			
		||||
        if value is not None and not isinstance(value, NumberRange):
 | 
			
		||||
            if isinstance(value, six.string_types):
 | 
			
		||||
                value = NumberRange.from_normalized_str(value)
 | 
			
		||||
            else:
 | 
			
		||||
                raise TypeError
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NumberRangeException(Exception):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RangeBoundsException(NumberRangeException):
 | 
			
		||||
    def __init__(self, min_value, max_value):
 | 
			
		||||
        self.message = 'Min value %d is bigger than max value %d.' % (
 | 
			
		||||
            min_value,
 | 
			
		||||
            max_value
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NumberRange(object):
 | 
			
		||||
    def __init__(self, min_value, max_value):
 | 
			
		||||
        if min_value > max_value:
 | 
			
		||||
            raise RangeBoundsException(min_value, max_value)
 | 
			
		||||
        self.min_value = min_value
 | 
			
		||||
        self.max_value = max_value
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_range_object(cls, value):
 | 
			
		||||
        min_value = value.lower
 | 
			
		||||
        max_value = value.upper
 | 
			
		||||
        if not value.lower_inc:
 | 
			
		||||
            min_value += 1
 | 
			
		||||
 | 
			
		||||
        if not value.upper_inc:
 | 
			
		||||
            max_value -= 1
 | 
			
		||||
 | 
			
		||||
        return cls(min_value, max_value)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_normalized_str(cls, value):
 | 
			
		||||
        """
 | 
			
		||||
        Returns new NumberRange object from normalized number range format.
 | 
			
		||||
 | 
			
		||||
        Example ::
 | 
			
		||||
 | 
			
		||||
            range = NumberRange.from_normalized_str('[23, 45]')
 | 
			
		||||
            range.min_value = 23
 | 
			
		||||
            range.max_value = 45
 | 
			
		||||
 | 
			
		||||
            range = NumberRange.from_normalized_str('(23, 45]')
 | 
			
		||||
            range.min_value = 24
 | 
			
		||||
            range.max_value = 45
 | 
			
		||||
 | 
			
		||||
            range = NumberRange.from_normalized_str('(23, 45)')
 | 
			
		||||
            range.min_value = 24
 | 
			
		||||
            range.max_value = 44
 | 
			
		||||
        """
 | 
			
		||||
        if value is not None:
 | 
			
		||||
            values = value[1:-1].split(',')
 | 
			
		||||
            try:
 | 
			
		||||
                min_value, max_value = map(
 | 
			
		||||
                    lambda a: int(a.strip()), values
 | 
			
		||||
                )
 | 
			
		||||
            except ValueError as e:
 | 
			
		||||
                raise NumberRangeException(e.message)
 | 
			
		||||
 | 
			
		||||
            if value[0] == '(':
 | 
			
		||||
                min_value += 1
 | 
			
		||||
 | 
			
		||||
            if value[-1] == ')':
 | 
			
		||||
                max_value -= 1
 | 
			
		||||
 | 
			
		||||
            return cls(min_value, max_value)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_str(cls, value):
 | 
			
		||||
        if value is not None:
 | 
			
		||||
            values = value.split('-')
 | 
			
		||||
            if len(values) == 1:
 | 
			
		||||
                min_value = max_value = int(value.strip())
 | 
			
		||||
            else:
 | 
			
		||||
                try:
 | 
			
		||||
                    min_value, max_value = map(
 | 
			
		||||
                        lambda a: int(a.strip()), values
 | 
			
		||||
                    )
 | 
			
		||||
                except ValueError as e:
 | 
			
		||||
                    raise NumberRangeException(str(e))
 | 
			
		||||
            return cls(min_value, max_value)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def normalized(self):
 | 
			
		||||
        return '[%s, %s]' % (self.min_value, self.max_value)
 | 
			
		||||
 | 
			
		||||
    def __eq__(self, other):
 | 
			
		||||
        try:
 | 
			
		||||
            return (
 | 
			
		||||
                self.min_value == other.min_value and
 | 
			
		||||
                self.max_value == other.max_value
 | 
			
		||||
            )
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            return NotImplemented
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return 'NumberRange(%r, %r)' % (self.min_value, self.max_value)
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        if self.min_value != self.max_value:
 | 
			
		||||
            return '%s - %s' % (self.min_value, self.max_value)
 | 
			
		||||
        return str(self.min_value)
 | 
			
		||||
 | 
			
		||||
    def __add__(self, other):
 | 
			
		||||
        try:
 | 
			
		||||
            return NumberRange(
 | 
			
		||||
                self.min_value + other.min_value,
 | 
			
		||||
                self.max_value + other.max_value
 | 
			
		||||
            )
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            return NotImplemented
 | 
			
		||||
 | 
			
		||||
    def __iadd__(self, other):
 | 
			
		||||
        try:
 | 
			
		||||
            self.min_value += other.min_value
 | 
			
		||||
            self.max_value += other.max_value
 | 
			
		||||
            return self
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            return NotImplemented
 | 
			
		||||
 | 
			
		||||
    def __sub__(self, other):
 | 
			
		||||
        try:
 | 
			
		||||
            return NumberRange(
 | 
			
		||||
                self.min_value - other.min_value,
 | 
			
		||||
                self.max_value - other.max_value
 | 
			
		||||
            )
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            return NotImplemented
 | 
			
		||||
 | 
			
		||||
    def __isub__(self, other):
 | 
			
		||||
        try:
 | 
			
		||||
            self.min_value -= other.min_value
 | 
			
		||||
            self.max_value -= other.max_value
 | 
			
		||||
            return self
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            return NotImplemented
 | 
			
		||||
							
								
								
									
										86
									
								
								sqlalchemy_utils/types/phone_number.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								sqlalchemy_utils/types/phone_number.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,86 @@
 | 
			
		||||
import six
 | 
			
		||||
import phonenumbers
 | 
			
		||||
from sqlalchemy import types
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PhoneNumber(phonenumbers.phonenumber.PhoneNumber):
 | 
			
		||||
    '''
 | 
			
		||||
    Extends a PhoneNumber class from `Python phonenumbers library`_. Adds
 | 
			
		||||
    different phone number formats to attributes, so they can be easily used
 | 
			
		||||
    in templates. Phone number validation method is also implemented.
 | 
			
		||||
 | 
			
		||||
    Takes the raw phone number and country code as params and parses them
 | 
			
		||||
    into a PhoneNumber object.
 | 
			
		||||
 | 
			
		||||
    .. _Python phonenumbers library:
 | 
			
		||||
       https://github.com/daviddrysdale/python-phonenumbers
 | 
			
		||||
 | 
			
		||||
    :param raw_number:
 | 
			
		||||
        String representation of the phone number.
 | 
			
		||||
    :param country_code:
 | 
			
		||||
        Country code of the phone number.
 | 
			
		||||
    '''
 | 
			
		||||
    def __init__(self, raw_number, country_code=None):
 | 
			
		||||
        self._phone_number = phonenumbers.parse(raw_number, country_code)
 | 
			
		||||
        super(PhoneNumber, self).__init__(
 | 
			
		||||
            country_code=self._phone_number.country_code,
 | 
			
		||||
            national_number=self._phone_number.national_number,
 | 
			
		||||
            extension=self._phone_number.extension,
 | 
			
		||||
            italian_leading_zero=self._phone_number.italian_leading_zero,
 | 
			
		||||
            raw_input=self._phone_number.raw_input,
 | 
			
		||||
            country_code_source=self._phone_number.country_code_source,
 | 
			
		||||
            preferred_domestic_carrier_code=
 | 
			
		||||
            self._phone_number.preferred_domestic_carrier_code
 | 
			
		||||
        )
 | 
			
		||||
        self.national = phonenumbers.format_number(
 | 
			
		||||
            self._phone_number,
 | 
			
		||||
            phonenumbers.PhoneNumberFormat.NATIONAL
 | 
			
		||||
        )
 | 
			
		||||
        self.international = phonenumbers.format_number(
 | 
			
		||||
            self._phone_number,
 | 
			
		||||
            phonenumbers.PhoneNumberFormat.INTERNATIONAL
 | 
			
		||||
        )
 | 
			
		||||
        self.e164 = phonenumbers.format_number(
 | 
			
		||||
            self._phone_number,
 | 
			
		||||
            phonenumbers.PhoneNumberFormat.E164
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def is_valid_number(self):
 | 
			
		||||
        return phonenumbers.is_valid_number(self._phone_number)
 | 
			
		||||
 | 
			
		||||
    def __unicode__(self):
 | 
			
		||||
        return self.national
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return six.text_type(self.national).encode('utf-8')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PhoneNumberType(types.TypeDecorator):
 | 
			
		||||
    """
 | 
			
		||||
    Changes PhoneNumber objects to a string representation on the way in and
 | 
			
		||||
    changes them back to PhoneNumber objects on the way out. If E164 is used
 | 
			
		||||
    as storing format, no country code is needed for parsing the database
 | 
			
		||||
    value to PhoneNumber object.
 | 
			
		||||
    """
 | 
			
		||||
    STORE_FORMAT = 'e164'
 | 
			
		||||
    impl = types.Unicode(20)
 | 
			
		||||
 | 
			
		||||
    def __init__(self, country_code='US', max_length=20, *args, **kwargs):
 | 
			
		||||
        super(PhoneNumberType, self).__init__(*args, **kwargs)
 | 
			
		||||
        self.country_code = country_code
 | 
			
		||||
        self.impl = types.Unicode(max_length)
 | 
			
		||||
 | 
			
		||||
    def process_bind_param(self, value, dialect):
 | 
			
		||||
        if value:
 | 
			
		||||
            return getattr(value, self.STORE_FORMAT)
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def process_result_value(self, value, dialect):
 | 
			
		||||
        if value:
 | 
			
		||||
            return PhoneNumber(value, self.country_code)
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def coercion_listener(self, target, value, oldvalue, initiator):
 | 
			
		||||
        if value is not None and not isinstance(value, PhoneNumber):
 | 
			
		||||
            value = PhoneNumber(value, country_code=self.country_code)
 | 
			
		||||
        return value
 | 
			
		||||
							
								
								
									
										38
									
								
								sqlalchemy_utils/types/scalar_list.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								sqlalchemy_utils/types/scalar_list.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,38 @@
 | 
			
		||||
import six
 | 
			
		||||
import sqlalchemy as sa
 | 
			
		||||
from sqlalchemy import types
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ScalarListException(Exception):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ScalarListType(types.TypeDecorator):
 | 
			
		||||
    impl = sa.UnicodeText()
 | 
			
		||||
 | 
			
		||||
    def __init__(self, coerce_func=six.text_type, separator=u','):
 | 
			
		||||
        self.separator = six.text_type(separator)
 | 
			
		||||
        self.coerce_func = coerce_func
 | 
			
		||||
 | 
			
		||||
    def process_bind_param(self, value, dialect):
 | 
			
		||||
        # Convert list of values to unicode separator-separated list
 | 
			
		||||
        # Example: [1, 2, 3, 4] -> u'1, 2, 3, 4'
 | 
			
		||||
        if value is not None:
 | 
			
		||||
            if any(self.separator in six.text_type(item) for item in value):
 | 
			
		||||
                raise ScalarListException(
 | 
			
		||||
                    "List values can't contain string '%s' (its being used as "
 | 
			
		||||
                    "separator. If you wish for scalar list values to contain "
 | 
			
		||||
                    "these strings, use a different separator string.)"
 | 
			
		||||
                )
 | 
			
		||||
            return self.separator.join(
 | 
			
		||||
                map(six.text_type, value)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def process_result_value(self, value, dialect):
 | 
			
		||||
        if value is not None:
 | 
			
		||||
            if value == u'':
 | 
			
		||||
                return []
 | 
			
		||||
            # coerce each value
 | 
			
		||||
            return list(map(
 | 
			
		||||
                self.coerce_func, value.split(self.separator)
 | 
			
		||||
            ))
 | 
			
		||||
							
								
								
									
										0
									
								
								sqlalchemy_utils/types/slug_type.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								sqlalchemy_utils/types/slug_type.py
									
									
									
									
									
										Normal file
									
								
							@@ -1,3 +1,4 @@
 | 
			
		||||
import warnings
 | 
			
		||||
import sqlalchemy as sa
 | 
			
		||||
 | 
			
		||||
from sqlalchemy import create_engine
 | 
			
		||||
@@ -18,12 +19,17 @@ def count_sql_calls(conn, cursor, statement, parameters, context, executemany):
 | 
			
		||||
        conn.query_count = 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
warnings.simplefilter('error', sa.exc.SAWarning)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestCase(object):
 | 
			
		||||
    dns = 'sqlite:///:memory:'
 | 
			
		||||
 | 
			
		||||
    def setup_method(self, method):
 | 
			
		||||
        self.engine = create_engine('sqlite:///:memory:')
 | 
			
		||||
        self.engine = create_engine(self.dns)
 | 
			
		||||
        self.connection = self.engine.connect()
 | 
			
		||||
        self.Base = declarative_base()
 | 
			
		||||
        self.Base2 = declarative_base()
 | 
			
		||||
 | 
			
		||||
        self.create_models()
 | 
			
		||||
        self.Base.metadata.create_all(self.connection)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										29
									
								
								tests/test_ip_address.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								tests/test_ip_address.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,29 @@
 | 
			
		||||
import ipaddress
 | 
			
		||||
import six
 | 
			
		||||
import sqlalchemy as sa
 | 
			
		||||
from sqlalchemy_utils import IPAddressType
 | 
			
		||||
from tests import TestCase
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestIPAddressType(TestCase):
 | 
			
		||||
    def create_models(self):
 | 
			
		||||
        class Visitor(self.Base):
 | 
			
		||||
            __tablename__ = 'document'
 | 
			
		||||
            id = sa.Column(sa.Integer, primary_key=True)
 | 
			
		||||
            ip_address = sa.Column(IPAddressType)
 | 
			
		||||
 | 
			
		||||
            def __repr__(self):
 | 
			
		||||
                return 'Visitor(%r)' % self.id
 | 
			
		||||
 | 
			
		||||
        self.Visitor = Visitor
 | 
			
		||||
 | 
			
		||||
    def test_parameter_processing(self):
 | 
			
		||||
        visitor = self.Visitor(
 | 
			
		||||
            ip_address=ipaddress.ip_address(u'111.111.111.111')
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.session.add(visitor)
 | 
			
		||||
        self.session.commit()
 | 
			
		||||
 | 
			
		||||
        visitor = self.session.query(self.Visitor).first()
 | 
			
		||||
        assert six.text_type(visitor.ip_address) == u'111.111.111.111'
 | 
			
		||||
@@ -17,6 +17,8 @@ class TestDeferExcept(TestCase):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestFindNonIndexedForeignKeys(TestCase):
 | 
			
		||||
    dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
 | 
			
		||||
 | 
			
		||||
    def create_models(self):
 | 
			
		||||
        class User(self.Base):
 | 
			
		||||
            __tablename__ = 'user'
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user