Move auth_provider validation to attr setter
This commit is contained in:
@@ -198,7 +198,11 @@ class Cluster(object):
|
|||||||
Setting this to :const:`False` disables compression.
|
Setting this to :const:`False` disables compression.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
auth_provider = None
|
_auth_provider = None
|
||||||
|
_auth_provider_callable = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def auth_provider(self):
|
||||||
"""
|
"""
|
||||||
When :attr:`~.Cluster.protocol_version` is 2 or higher, this should
|
When :attr:`~.Cluster.protocol_version` is 2 or higher, this should
|
||||||
be an instance of a subclass of :class:`~cassandra.auth.AuthProvider`,
|
be an instance of a subclass of :class:`~cassandra.auth.AuthProvider`,
|
||||||
@@ -210,6 +214,25 @@ class Cluster(object):
|
|||||||
|
|
||||||
When not using authentication, this should be left as :const:`None`.
|
When not using authentication, this should be left as :const:`None`.
|
||||||
"""
|
"""
|
||||||
|
return self._auth_provider
|
||||||
|
|
||||||
|
@auth_provider.setter # noqa
|
||||||
|
def auth_provider(self, value):
|
||||||
|
if not value:
|
||||||
|
self._auth_provider = value
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._auth_provider_callable = value.new_authenticator
|
||||||
|
except AttributeError:
|
||||||
|
if self.protocol_version > 1:
|
||||||
|
raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider "
|
||||||
|
"interface when protocol_version >= 2")
|
||||||
|
elif not callable(value):
|
||||||
|
raise TypeError("auth_provider must be callable when protocol_version == 1")
|
||||||
|
self._auth_provider_callable = value
|
||||||
|
|
||||||
|
self._auth_provider = value
|
||||||
|
|
||||||
load_balancing_policy = None
|
load_balancing_policy = None
|
||||||
"""
|
"""
|
||||||
@@ -339,15 +362,8 @@ class Cluster(object):
|
|||||||
self.contact_points = contact_points
|
self.contact_points = contact_points
|
||||||
self.port = port
|
self.port = port
|
||||||
self.compression = compression
|
self.compression = compression
|
||||||
|
self.protocol_version = protocol_version
|
||||||
if auth_provider is not None:
|
|
||||||
if not hasattr(auth_provider, 'new_authenticator'):
|
|
||||||
if protocol_version > 1:
|
|
||||||
raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider "
|
|
||||||
"interface when protocol_version >= 2")
|
|
||||||
self.auth_provider = auth_provider
|
self.auth_provider = auth_provider
|
||||||
else:
|
|
||||||
self.auth_provider = auth_provider.new_authenticator
|
|
||||||
|
|
||||||
if load_balancing_policy is not None:
|
if load_balancing_policy is not None:
|
||||||
if isinstance(load_balancing_policy, type):
|
if isinstance(load_balancing_policy, type):
|
||||||
@@ -381,7 +397,6 @@ class Cluster(object):
|
|||||||
self.ssl_options = ssl_options
|
self.ssl_options = ssl_options
|
||||||
self.sockopts = sockopts
|
self.sockopts = sockopts
|
||||||
self.cql_version = cql_version
|
self.cql_version = cql_version
|
||||||
self.protocol_version = protocol_version
|
|
||||||
self.max_schema_agreement_wait = max_schema_agreement_wait
|
self.max_schema_agreement_wait = max_schema_agreement_wait
|
||||||
self.control_connection_timeout = control_connection_timeout
|
self.control_connection_timeout = control_connection_timeout
|
||||||
|
|
||||||
@@ -492,8 +507,8 @@ class Cluster(object):
|
|||||||
return partial(self.connection_class.factory, host.address, *args, **kwargs)
|
return partial(self.connection_class.factory, host.address, *args, **kwargs)
|
||||||
|
|
||||||
def _make_connection_kwargs(self, address, kwargs_dict):
|
def _make_connection_kwargs(self, address, kwargs_dict):
|
||||||
if self.auth_provider:
|
if self._auth_provider_callable:
|
||||||
kwargs_dict['authenticator'] = self.auth_provider(address)
|
kwargs_dict['authenticator'] = self._auth_provider_callable(address)
|
||||||
|
|
||||||
kwargs_dict['port'] = self.port
|
kwargs_dict['port'] = self.port
|
||||||
kwargs_dict['compression'] = self.compression
|
kwargs_dict['compression'] = self.compression
|
||||||
|
|||||||
@@ -117,8 +117,18 @@ class ClusterTests(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
Ensure that auth_providers are always callable
|
Ensure that auth_providers are always callable
|
||||||
"""
|
"""
|
||||||
|
self.assertRaises(TypeError, Cluster, auth_provider=1, protocol_version=1)
|
||||||
|
c = Cluster(protocol_version=1)
|
||||||
|
self.assertRaises(TypeError, setattr, c, 'auth_provider', 1)
|
||||||
|
|
||||||
self.assertRaises(ValueError, Cluster, auth_provider=1)
|
def test_v2_auth_provider(self):
|
||||||
|
"""
|
||||||
|
Check for v2 auth_provider compliance
|
||||||
|
"""
|
||||||
|
bad_auth_provider = lambda x: {'username': 'foo', 'password': 'bar'}
|
||||||
|
self.assertRaises(TypeError, Cluster, auth_provider=bad_auth_provider, protocol_version=2)
|
||||||
|
c = Cluster(protocol_version=2)
|
||||||
|
self.assertRaises(TypeError, setattr, c, 'auth_provider', bad_auth_provider)
|
||||||
|
|
||||||
def test_conviction_policy_factory_is_callable(self):
|
def test_conviction_policy_factory_is_callable(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user