From e59fe8503d8fe442064c384ac8e48c3bde8304c9 Mon Sep 17 00:00:00 2001 From: Tyler Hobbs Date: Wed, 28 May 2014 12:10:21 -0500 Subject: [PATCH] Move auth_provider validation to attr setter --- cassandra/cluster.py | 59 ++++++++++++++-------- tests/integration/standard/test_cluster.py | 12 ++++- 2 files changed, 48 insertions(+), 23 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index dd8ab463..293a3c94 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -198,18 +198,41 @@ class Cluster(object): Setting this to :const:`False` disables compression. """ - auth_provider = None - """ - When :attr:`~.Cluster.protocol_version` is 2 or higher, this should - be an instance of a subclass of :class:`~cassandra.auth.AuthProvider`, - such ass :class:`~.PlainTextAuthProvider`. + _auth_provider = None + _auth_provider_callable = None - When :attr:`~.Cluster.protocol_version` is 1, this should be - a function that accepts one argument, the IP address of a node, - and returns a dict of credentials for that node. + @property + def auth_provider(self): + """ + When :attr:`~.Cluster.protocol_version` is 2 or higher, this should + be an instance of a subclass of :class:`~cassandra.auth.AuthProvider`, + such ass :class:`~.PlainTextAuthProvider`. - When not using authentication, this should be left as :const:`None`. - """ + When :attr:`~.Cluster.protocol_version` is 1, this should be + a function that accepts one argument, the IP address of a node, + and returns a dict of credentials for that node. + + When not using authentication, this should be left as :const:`None`. + """ + return self._auth_provider + + @auth_provider.setter # noqa + def auth_provider(self, value): + if not value: + self._auth_provider = value + return + + try: + self._auth_provider_callable = value.new_authenticator + except AttributeError: + if self.protocol_version > 1: + raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider " + "interface when protocol_version >= 2") + elif not callable(value): + raise TypeError("auth_provider must be callable when protocol_version == 1") + self._auth_provider_callable = value + + self._auth_provider = value load_balancing_policy = None """ @@ -339,15 +362,8 @@ class Cluster(object): self.contact_points = contact_points self.port = port self.compression = compression - - if auth_provider is not None: - if not hasattr(auth_provider, 'new_authenticator'): - if protocol_version > 1: - raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider " - "interface when protocol_version >= 2") - self.auth_provider = auth_provider - else: - self.auth_provider = auth_provider.new_authenticator + self.protocol_version = protocol_version + self.auth_provider = auth_provider if load_balancing_policy is not None: if isinstance(load_balancing_policy, type): @@ -381,7 +397,6 @@ class Cluster(object): self.ssl_options = ssl_options self.sockopts = sockopts self.cql_version = cql_version - self.protocol_version = protocol_version self.max_schema_agreement_wait = max_schema_agreement_wait self.control_connection_timeout = control_connection_timeout @@ -492,8 +507,8 @@ class Cluster(object): return partial(self.connection_class.factory, host.address, *args, **kwargs) def _make_connection_kwargs(self, address, kwargs_dict): - if self.auth_provider: - kwargs_dict['authenticator'] = self.auth_provider(address) + if self._auth_provider_callable: + kwargs_dict['authenticator'] = self._auth_provider_callable(address) kwargs_dict['port'] = self.port kwargs_dict['compression'] = self.compression diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index e3037a14..28ef0009 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -117,8 +117,18 @@ class ClusterTests(unittest.TestCase): """ Ensure that auth_providers are always callable """ + self.assertRaises(TypeError, Cluster, auth_provider=1, protocol_version=1) + c = Cluster(protocol_version=1) + self.assertRaises(TypeError, setattr, c, 'auth_provider', 1) - self.assertRaises(ValueError, Cluster, auth_provider=1) + def test_v2_auth_provider(self): + """ + Check for v2 auth_provider compliance + """ + bad_auth_provider = lambda x: {'username': 'foo', 'password': 'bar'} + self.assertRaises(TypeError, Cluster, auth_provider=bad_auth_provider, protocol_version=2) + c = Cluster(protocol_version=2) + self.assertRaises(TypeError, setattr, c, 'auth_provider', bad_auth_provider) def test_conviction_policy_factory_is_callable(self): """