Move auth_provider validation to attr setter
This commit is contained in:
@@ -198,7 +198,11 @@ class Cluster(object):
|
||||
Setting this to :const:`False` disables compression.
|
||||
"""
|
||||
|
||||
auth_provider = None
|
||||
_auth_provider = None
|
||||
_auth_provider_callable = None
|
||||
|
||||
@property
|
||||
def auth_provider(self):
|
||||
"""
|
||||
When :attr:`~.Cluster.protocol_version` is 2 or higher, this should
|
||||
be an instance of a subclass of :class:`~cassandra.auth.AuthProvider`,
|
||||
@@ -210,6 +214,25 @@ class Cluster(object):
|
||||
|
||||
When not using authentication, this should be left as :const:`None`.
|
||||
"""
|
||||
return self._auth_provider
|
||||
|
||||
@auth_provider.setter # noqa
|
||||
def auth_provider(self, value):
|
||||
if not value:
|
||||
self._auth_provider = value
|
||||
return
|
||||
|
||||
try:
|
||||
self._auth_provider_callable = value.new_authenticator
|
||||
except AttributeError:
|
||||
if self.protocol_version > 1:
|
||||
raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider "
|
||||
"interface when protocol_version >= 2")
|
||||
elif not callable(value):
|
||||
raise TypeError("auth_provider must be callable when protocol_version == 1")
|
||||
self._auth_provider_callable = value
|
||||
|
||||
self._auth_provider = value
|
||||
|
||||
load_balancing_policy = None
|
||||
"""
|
||||
@@ -339,15 +362,8 @@ class Cluster(object):
|
||||
self.contact_points = contact_points
|
||||
self.port = port
|
||||
self.compression = compression
|
||||
|
||||
if auth_provider is not None:
|
||||
if not hasattr(auth_provider, 'new_authenticator'):
|
||||
if protocol_version > 1:
|
||||
raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider "
|
||||
"interface when protocol_version >= 2")
|
||||
self.protocol_version = protocol_version
|
||||
self.auth_provider = auth_provider
|
||||
else:
|
||||
self.auth_provider = auth_provider.new_authenticator
|
||||
|
||||
if load_balancing_policy is not None:
|
||||
if isinstance(load_balancing_policy, type):
|
||||
@@ -381,7 +397,6 @@ class Cluster(object):
|
||||
self.ssl_options = ssl_options
|
||||
self.sockopts = sockopts
|
||||
self.cql_version = cql_version
|
||||
self.protocol_version = protocol_version
|
||||
self.max_schema_agreement_wait = max_schema_agreement_wait
|
||||
self.control_connection_timeout = control_connection_timeout
|
||||
|
||||
@@ -492,8 +507,8 @@ class Cluster(object):
|
||||
return partial(self.connection_class.factory, host.address, *args, **kwargs)
|
||||
|
||||
def _make_connection_kwargs(self, address, kwargs_dict):
|
||||
if self.auth_provider:
|
||||
kwargs_dict['authenticator'] = self.auth_provider(address)
|
||||
if self._auth_provider_callable:
|
||||
kwargs_dict['authenticator'] = self._auth_provider_callable(address)
|
||||
|
||||
kwargs_dict['port'] = self.port
|
||||
kwargs_dict['compression'] = self.compression
|
||||
|
||||
@@ -117,8 +117,18 @@ class ClusterTests(unittest.TestCase):
|
||||
"""
|
||||
Ensure that auth_providers are always callable
|
||||
"""
|
||||
self.assertRaises(TypeError, Cluster, auth_provider=1, protocol_version=1)
|
||||
c = Cluster(protocol_version=1)
|
||||
self.assertRaises(TypeError, setattr, c, 'auth_provider', 1)
|
||||
|
||||
self.assertRaises(ValueError, Cluster, auth_provider=1)
|
||||
def test_v2_auth_provider(self):
|
||||
"""
|
||||
Check for v2 auth_provider compliance
|
||||
"""
|
||||
bad_auth_provider = lambda x: {'username': 'foo', 'password': 'bar'}
|
||||
self.assertRaises(TypeError, Cluster, auth_provider=bad_auth_provider, protocol_version=2)
|
||||
c = Cluster(protocol_version=2)
|
||||
self.assertRaises(TypeError, setattr, c, 'auth_provider', bad_auth_provider)
|
||||
|
||||
def test_conviction_policy_factory_is_callable(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user