Move auth_provider validation to attr setter

This commit is contained in:
Tyler Hobbs
2014-05-28 12:10:21 -05:00
parent 33c523b02c
commit e59fe8503d
2 changed files with 48 additions and 23 deletions

View File

@@ -198,7 +198,11 @@ class Cluster(object):
Setting this to :const:`False` disables compression.
"""
auth_provider = None
_auth_provider = None
_auth_provider_callable = None
@property
def auth_provider(self):
"""
When :attr:`~.Cluster.protocol_version` is 2 or higher, this should
be an instance of a subclass of :class:`~cassandra.auth.AuthProvider`,
@@ -210,6 +214,25 @@ class Cluster(object):
When not using authentication, this should be left as :const:`None`.
"""
return self._auth_provider
@auth_provider.setter # noqa
def auth_provider(self, value):
if not value:
self._auth_provider = value
return
try:
self._auth_provider_callable = value.new_authenticator
except AttributeError:
if self.protocol_version > 1:
raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider "
"interface when protocol_version >= 2")
elif not callable(value):
raise TypeError("auth_provider must be callable when protocol_version == 1")
self._auth_provider_callable = value
self._auth_provider = value
load_balancing_policy = None
"""
@@ -339,15 +362,8 @@ class Cluster(object):
self.contact_points = contact_points
self.port = port
self.compression = compression
if auth_provider is not None:
if not hasattr(auth_provider, 'new_authenticator'):
if protocol_version > 1:
raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider "
"interface when protocol_version >= 2")
self.protocol_version = protocol_version
self.auth_provider = auth_provider
else:
self.auth_provider = auth_provider.new_authenticator
if load_balancing_policy is not None:
if isinstance(load_balancing_policy, type):
@@ -381,7 +397,6 @@ class Cluster(object):
self.ssl_options = ssl_options
self.sockopts = sockopts
self.cql_version = cql_version
self.protocol_version = protocol_version
self.max_schema_agreement_wait = max_schema_agreement_wait
self.control_connection_timeout = control_connection_timeout
@@ -492,8 +507,8 @@ class Cluster(object):
return partial(self.connection_class.factory, host.address, *args, **kwargs)
def _make_connection_kwargs(self, address, kwargs_dict):
if self.auth_provider:
kwargs_dict['authenticator'] = self.auth_provider(address)
if self._auth_provider_callable:
kwargs_dict['authenticator'] = self._auth_provider_callable(address)
kwargs_dict['port'] = self.port
kwargs_dict['compression'] = self.compression

View File

@@ -117,8 +117,18 @@ class ClusterTests(unittest.TestCase):
"""
Ensure that auth_providers are always callable
"""
self.assertRaises(TypeError, Cluster, auth_provider=1, protocol_version=1)
c = Cluster(protocol_version=1)
self.assertRaises(TypeError, setattr, c, 'auth_provider', 1)
self.assertRaises(ValueError, Cluster, auth_provider=1)
def test_v2_auth_provider(self):
"""
Check for v2 auth_provider compliance
"""
bad_auth_provider = lambda x: {'username': 'foo', 'password': 'bar'}
self.assertRaises(TypeError, Cluster, auth_provider=bad_auth_provider, protocol_version=2)
c = Cluster(protocol_version=2)
self.assertRaises(TypeError, setattr, c, 'auth_provider', bad_auth_provider)
def test_conviction_policy_factory_is_callable(self):
"""