Touch up IPAddressType; make work in python 2.

This commit is contained in:
Ryan Leckey
2013-07-26 00:24:10 -07:00
parent 40d1d80144
commit df5d6d1880
3 changed files with 21 additions and 22 deletions

View File

@@ -42,6 +42,7 @@ extras_require = {
],
'password': ['passlib >= 1.6, < 2.0'],
'color': ['colour>=0.0.4'],
'ipaddress': ['ipaddr'] if not PY3 else []
}

View File

@@ -1,45 +1,43 @@
import six
ipaddress = None
try:
import ipaddress
except:
pass
from ipaddress import ip_address
except ImportError:
try:
from ipaddr import IPAddress as ip_address
except ImportError:
ip_address = None
from sqlalchemy import types
from sqlalchemy_utils import ImproperlyConfigured
class IPAddressType(types.TypeDecorator):
"""
Changes Color objects to a string representation on the way in and
changes them back to Color objects on the way out.
Changes IPAddress objects to a string representation on the way in and
changes them back to IPAddress objects on the way out.
"""
impl = types.Unicode(50)
def __init__(self, max_length=50, *args, **kwargs):
if not ipaddress:
if not ip_address:
raise ImproperlyConfigured(
"'ipaddress' package is required to use 'IPAddressType'"
"'ipaddr' package is required to use 'IPAddressType' "
"in python 2"
)
super(IPAddressType, self).__init__(*args, **kwargs)
self.impl = types.Unicode(max_length)
def process_bind_param(self, value, dialect):
if value:
return six.text_type(value)
return value
return six.text_type(value) if value else None
def process_result_value(self, value, dialect):
if value:
return ipaddress.ip_address(value)
return value
return ip_address(value) if value else None
def coercion_listener(self, target, value, oldvalue, initiator):
if (
value is not None and
not isinstance(value, ipaddress.IPv4Address) and
not isinstance(value, ipaddress.IPv6Address)
):
value = ipaddress.ip_address(value)
return value
return ip_address(value) if value else None

View File

@@ -5,7 +5,7 @@ from sqlalchemy_utils.types import ip_address
from tests import TestCase
@mark.skipif('ip_address.ipaddress is None')
@mark.skipif('ip_address.ip_address is None')
class TestIPAddressType(TestCase):
def create_models(self):
class Visitor(self.Base):