diff --git a/setup.py b/setup.py index 443a5f9..76905af 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ extras_require = { ], 'password': ['passlib >= 1.6, < 2.0'], 'color': ['colour>=0.0.4'], + 'ipaddress': ['ipaddr'] if not PY3 else [] } diff --git a/sqlalchemy_utils/types/ip_address.py b/sqlalchemy_utils/types/ip_address.py index 8065762..237ada2 100644 --- a/sqlalchemy_utils/types/ip_address.py +++ b/sqlalchemy_utils/types/ip_address.py @@ -1,45 +1,43 @@ import six -ipaddress = None try: - import ipaddress -except: - pass + from ipaddress import ip_address + +except ImportError: + try: + from ipaddr import IPAddress as ip_address + + except ImportError: + ip_address = None + + from sqlalchemy import types from sqlalchemy_utils import ImproperlyConfigured class IPAddressType(types.TypeDecorator): """ - Changes Color objects to a string representation on the way in and - changes them back to Color objects on the way out. + Changes IPAddress objects to a string representation on the way in and + changes them back to IPAddress objects on the way out. """ + impl = types.Unicode(50) def __init__(self, max_length=50, *args, **kwargs): - if not ipaddress: + if not ip_address: raise ImproperlyConfigured( - "'ipaddress' package is required to use 'IPAddressType'" + "'ipaddr' package is required to use 'IPAddressType' " + "in python 2" ) super(IPAddressType, self).__init__(*args, **kwargs) self.impl = types.Unicode(max_length) def process_bind_param(self, value, dialect): - if value: - return six.text_type(value) - return value + return six.text_type(value) if value else None def process_result_value(self, value, dialect): - if value: - return ipaddress.ip_address(value) - return value + return ip_address(value) if value else None def coercion_listener(self, target, value, oldvalue, initiator): - if ( - value is not None and - not isinstance(value, ipaddress.IPv4Address) and - not isinstance(value, ipaddress.IPv6Address) - ): - value = ipaddress.ip_address(value) - return value + return ip_address(value) if value else None diff --git a/tests/test_ip_address.py b/tests/test_ip_address.py index da3c13c..a3725ba 100644 --- a/tests/test_ip_address.py +++ b/tests/test_ip_address.py @@ -5,7 +5,7 @@ from sqlalchemy_utils.types import ip_address from tests import TestCase -@mark.skipif('ip_address.ipaddress is None') +@mark.skipif('ip_address.ip_address is None') class TestIPAddressType(TestCase): def create_models(self): class Visitor(self.Base):