From 663194a1528d3f9994a1ef6c60d1b19d608184bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Dulko?= Date: Wed, 30 Sep 2020 12:56:57 +0200 Subject: [PATCH] Refactor passing params to requests in K8s client Seems like there are options to set default values for several parameters in requests session. This commit attempts to leverage that by globally setting certificates, SSL verification, token header and timeout. Change-Id: Ieecc14cef94f1678a935f23affa6ca37e3de4a91 --- kuryr_kubernetes/k8s_client.py | 64 +++++++------------ .../tests/unit/test_k8s_client.py | 59 ++++++----------- 2 files changed, 44 insertions(+), 79 deletions(-) diff --git a/kuryr_kubernetes/k8s_client.py b/kuryr_kubernetes/k8s_client.py index a17d5b8a8..198cc6d17 100644 --- a/kuryr_kubernetes/k8s_client.py +++ b/kuryr_kubernetes/k8s_client.py @@ -13,6 +13,7 @@ # License for the specific language governing permissions and limitations # under the License. import contextlib +import functools import itertools import os import ssl @@ -78,12 +79,16 @@ class K8sClient(object): else: self.verify_server = ca_crt_file - self._rq_params = { - 'cert': self.cert, - 'verify': self.verify_server, - 'timeout': (CONF.kubernetes.watch_connection_timeout, - CONF.kubernetes.watch_read_timeout), - } + # Let's setup defaults for our Session. + self.session.cert = self.cert + self.session.verify = self.verify_server + if self.token: + self.session.headers['Authorization'] = f'Bearer {self.token}' + # NOTE(dulek): Seems like this is the only way to set is globally. + self.session.request = functools.partial( + self.session.request, timeout=( + CONF.kubernetes.watch_connection_timeout, + CONF.kubernetes.watch_read_timeout)) def _raise_from_response(self, response): if response.status_code == requests.codes.not_found: @@ -108,12 +113,7 @@ class K8sClient(object): def get(self, path, json=True, headers=None): LOG.debug("Get %(path)s", {'path': path}) url = self._base_url + path - header = {} - if self.token: - header.update({'Authorization': 'Bearer %s' % self.token}) - if headers: - header.update(headers) - response = self.session.get(url, headers=header, **self._rq_params) + response = self.session.get(url, headers=headers) self._raise_from_response(response) result = response.json() if json else response.text return result @@ -122,18 +122,14 @@ class K8sClient(object): url = self._base_url + path header = {'Content-Type': content_type, 'Accept': 'application/json'} - if self.token: - header.update({'Authorization': 'Bearer %s' % self.token}) return url, header def patch(self, field, path, data): - LOG.debug("Patch %(path)s: %(data)s", { - 'path': path, 'data': data}) + LOG.debug("Patch %(path)s: %(data)s", {'path': path, 'data': data}) content_type = 'application/merge-patch+json' url, header = self._get_url_and_header(path, content_type) - response = self.session.patch(url, json={field: data}, - headers=header, **self._rq_params) + response = self.session.patch(url, json={field: data}, headers=header) self._raise_from_response(response) return response.json().get('status') @@ -159,7 +155,7 @@ class K8sClient(object): 'path': path, 'data': data}) response = self.session.patch(url, data=jsonutils.dumps(data), - headers=header, **self._rq_params) + headers=header) self._raise_from_response(response) return response.json().get('status') @@ -174,7 +170,7 @@ class K8sClient(object): 'value': value}] response = self.session.patch(url, data=jsonutils.dumps(data), - headers=header, **self._rq_params) + headers=header) self._raise_from_response(response) return response.json().get('status') @@ -187,7 +183,7 @@ class K8sClient(object): 'path': '/metadata/annotations/{}'.format(annotation_name)}] response = self.session.patch(url, data=jsonutils.dumps(data), - headers=header, **self._rq_params) + headers=header) self._raise_from_response(response) return response.json().get('status') @@ -206,7 +202,7 @@ class K8sClient(object): data = [{'op': 'remove', 'path': f'/metadata/annotations/{annotation_name}'}] response = self.session.patch(url, data=jsonutils.dumps(data), - headers=header, **self._rq_params) + headers=header) if response.ok: return response.json().get('status') raise exc.K8sClientException(response.text) @@ -215,11 +211,8 @@ class K8sClient(object): LOG.debug("Post %(path)s: %(body)s", {'path': path, 'body': body}) url = self._base_url + path header = {'Content-Type': 'application/json'} - if self.token: - header.update({'Authorization': 'Bearer %s' % self.token}) - response = self.session.post(url, json=body, headers=header, - **self._rq_params) + response = self.session.post(url, json=body, headers=header) self._raise_from_response(response) return response.json() @@ -227,10 +220,8 @@ class K8sClient(object): LOG.debug("Delete %(path)s", {'path': path}) url = self._base_url + path header = {'Content-Type': 'application/json'} - if self.token: - header.update({'Authorization': 'Bearer %s' % self.token}) - response = self.session.delete(url, headers=header, **self._rq_params) + response = self.session.delete(url, headers=header) self._raise_from_response(response) return response.json() @@ -256,8 +247,7 @@ class K8sClient(object): }, } - response = self.session.patch(url, json=data, headers=headers, - **self._rq_params) + response = self.session.patch(url, json=data, headers=headers) if response.ok: return True @@ -297,8 +287,7 @@ class K8sClient(object): }, } - response = self.session.patch(url, json=data, headers=headers, - **self._rq_params) + response = self.session.patch(url, json=data, headers=headers) if response.ok: return True @@ -349,8 +338,7 @@ class K8sClient(object): if resource_version: metadata['resourceVersion'] = resource_version data = jsonutils.dumps({"metadata": metadata}, sort_keys=True) - response = self.session.patch(url, data=data, - headers=header, **self._rq_params) + response = self.session.patch(url, data=data, headers=header) if response.ok: return response.json()['metadata'].get('annotations', {}) if response.status_code == requests.codes.conflict: @@ -381,9 +369,6 @@ class K8sClient(object): def watch(self, path): url = self._base_url + path resource_version = None - header = {} - if self.token: - header.update({'Authorization': 'Bearer %s' % self.token}) attempt = 0 while True: @@ -393,8 +378,7 @@ class K8sClient(object): params['resourceVersion'] = resource_version with contextlib.closing( self.session.get( - url, params=params, stream=True, headers=header, - **self._rq_params)) as response: + url, params=params, stream=True)) as response: if not response.ok: raise exc.K8sClientException(response.text) attempt = 0 diff --git a/kuryr_kubernetes/tests/unit/test_k8s_client.py b/kuryr_kubernetes/tests/unit/test_k8s_client.py index 273f79243..8f2d5ab7f 100644 --- a/kuryr_kubernetes/tests/unit/test_k8s_client.py +++ b/kuryr_kubernetes/tests/unit/test_k8s_client.py @@ -74,9 +74,9 @@ class TestK8sClient(test_base.TestCase): m_exist.return_value = True self.assertRaises(RuntimeError, k8s_client.K8sClient, self.base_url) - @mock.patch('requests.sessions.Session.get') + @mock.patch('requests.sessions.Session.send') @mock.patch('kuryr_kubernetes.config.CONF') - def test_bearer_token(self, m_cfg, m_get): + def test_bearer_token(self, m_cfg, m_send): token_content = ( "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3Nl" "cnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc" @@ -102,11 +102,9 @@ class TestK8sClient(test_base.TestCase): path = '/test' client = k8s_client.K8sClient(self.base_url) client.get(path) - headers = { - 'Authorization': 'Bearer {}'.format(token_content)} - m_get.assert_called_once_with( - self.base_url + path, cert=(None, None), headers=headers, - verify=False, timeout=(30, 60)) + + self.assertEqual(f'Bearer {token_content}', + m_send.call_args[0][0].headers['Authorization']) finally: os.unlink(m_cfg.kubernetes.token_file) @@ -121,9 +119,7 @@ class TestK8sClient(test_base.TestCase): m_get.return_value = m_resp self.assertEqual(ret, self.client.get(path)) - m_get.assert_called_once_with( - self.base_url + path, - cert=(None, None), headers={}, verify=False, timeout=(30, 60)) + m_get.assert_called_once_with(self.base_url + path, headers=None) @mock.patch('requests.sessions.Session.get') def test_get_exception(self, m_get): @@ -154,9 +150,7 @@ class TestK8sClient(test_base.TestCase): self.assertEqual(annotations, self.client.annotate( path, annotations, resource_version=resource_version)) m_patch.assert_called_once_with(self.base_url + path, - data=data, headers=mock.ANY, - cert=(None, None), verify=False, - timeout=(30, 60)) + data=data, headers=mock.ANY) @mock.patch('itertools.count') @mock.patch('requests.sessions.Session.patch') @@ -203,8 +197,7 @@ class TestK8sClient(test_base.TestCase): m_patch.assert_has_calls([ mock.call(self.base_url + path, data=conflicting_data, - headers=mock.ANY, - cert=(None, None), verify=False, timeout=(30, 60))]) + headers=mock.ANY)]) @mock.patch('itertools.count') @mock.patch('requests.sessions.Session.patch') @@ -243,12 +236,10 @@ class TestK8sClient(test_base.TestCase): m_patch.assert_has_calls([ mock.call(self.base_url + path, data=annotating_data, - headers=mock.ANY, - cert=(None, None), verify=False, timeout=(30, 60)), + headers=mock.ANY), mock.call(self.base_url + path, data=resolution_data, - headers=mock.ANY, - cert=(None, None), verify=False, timeout=(30, 60))]) + headers=mock.ANY)]) @mock.patch('itertools.count') @mock.patch('requests.sessions.Session.patch') @@ -287,12 +278,10 @@ class TestK8sClient(test_base.TestCase): m_patch.assert_has_calls([ mock.call(self.base_url + path, data=conflicting_data, - headers=mock.ANY, - cert=(None, None), verify=False, timeout=(30, 60)), + headers=mock.ANY), mock.call(self.base_url + path, data=good_data, - headers=mock.ANY, - cert=(None, None), verify=False, timeout=(30, 60))]) + headers=mock.ANY)]) @mock.patch('itertools.count') @mock.patch('requests.sessions.Session.patch') @@ -318,9 +307,7 @@ class TestK8sClient(test_base.TestCase): resource_version=resource_version) m_patch.assert_called_once_with(self.base_url + path, data=annotate_data, - headers=mock.ANY, - cert=(None, None), verify=False, - timeout=(30, 60)) + headers=mock.ANY) @mock.patch('requests.sessions.Session.get') def test_watch(self, m_get): @@ -341,9 +328,8 @@ class TestK8sClient(test_base.TestCase): self.assertEqual(cycles, m_get.call_count) self.assertEqual(cycles, m_resp.close.call_count) - m_get.assert_called_with(self.base_url + path, headers={}, stream=True, - params={'watch': 'true'}, cert=(None, None), - verify=False, timeout=(30, 60)) + m_get.assert_called_with(self.base_url + path, stream=True, + params={'watch': 'true'}) @mock.patch('requests.sessions.Session.get') def test_watch_restart(self, m_get): @@ -364,13 +350,10 @@ class TestK8sClient(test_base.TestCase): self.assertEqual(3, m_get.call_count) self.assertEqual(3, m_resp.close.call_count) m_get.assert_any_call( - self.base_url + path, headers={}, stream=True, - params={"watch": "true"}, cert=(None, None), verify=False, - timeout=(30, 60)) + self.base_url + path, stream=True, params={"watch": "true"}) m_get.assert_any_call( - self.base_url + path, headers={}, stream=True, - params={"watch": "true", "resourceVersion": 2}, cert=(None, None), - verify=False, timeout=(30, 60)) + self.base_url + path, stream=True, params={"watch": "true", + "resourceVersion": 2}) @mock.patch('requests.sessions.Session.get') def test_watch_exception(self, m_get): @@ -396,8 +379,7 @@ class TestK8sClient(test_base.TestCase): self.assertEqual(ret, self.client.post(path, body)) m_post.assert_called_once_with(self.base_url + path, json=body, - headers=mock.ANY, cert=(None, None), - verify=False, timeout=(30, 60)) + headers=mock.ANY) @mock.patch('requests.sessions.Session.post') def test_post_exception(self, m_post): @@ -423,8 +405,7 @@ class TestK8sClient(test_base.TestCase): self.assertEqual(ret, self.client.delete(path)) m_delete.assert_called_once_with(self.base_url + path, - headers=mock.ANY, cert=(None, None), - verify=False, timeout=(30, 60)) + headers=mock.ANY) @mock.patch('requests.sessions.Session.delete') def test_delete_exception(self, m_delete):