Refactor passing params to requests in K8s client

Seems like there are options to set default values for several
parameters in requests session. This commit attempts to leverage that
by globally setting certificates, SSL verification, token header and
timeout.

Change-Id: Ieecc14cef94f1678a935f23affa6ca37e3de4a91
This commit is contained in:
Michał Dulko 2020-09-30 12:56:57 +02:00
parent 44890a84d5
commit 663194a152
2 changed files with 44 additions and 79 deletions

View File

@ -13,6 +13,7 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import contextlib import contextlib
import functools
import itertools import itertools
import os import os
import ssl import ssl
@ -78,12 +79,16 @@ class K8sClient(object):
else: else:
self.verify_server = ca_crt_file self.verify_server = ca_crt_file
self._rq_params = { # Let's setup defaults for our Session.
'cert': self.cert, self.session.cert = self.cert
'verify': self.verify_server, self.session.verify = self.verify_server
'timeout': (CONF.kubernetes.watch_connection_timeout, if self.token:
CONF.kubernetes.watch_read_timeout), self.session.headers['Authorization'] = f'Bearer {self.token}'
} # NOTE(dulek): Seems like this is the only way to set is globally.
self.session.request = functools.partial(
self.session.request, timeout=(
CONF.kubernetes.watch_connection_timeout,
CONF.kubernetes.watch_read_timeout))
def _raise_from_response(self, response): def _raise_from_response(self, response):
if response.status_code == requests.codes.not_found: if response.status_code == requests.codes.not_found:
@ -108,12 +113,7 @@ class K8sClient(object):
def get(self, path, json=True, headers=None): def get(self, path, json=True, headers=None):
LOG.debug("Get %(path)s", {'path': path}) LOG.debug("Get %(path)s", {'path': path})
url = self._base_url + path url = self._base_url + path
header = {} response = self.session.get(url, headers=headers)
if self.token:
header.update({'Authorization': 'Bearer %s' % self.token})
if headers:
header.update(headers)
response = self.session.get(url, headers=header, **self._rq_params)
self._raise_from_response(response) self._raise_from_response(response)
result = response.json() if json else response.text result = response.json() if json else response.text
return result return result
@ -122,18 +122,14 @@ class K8sClient(object):
url = self._base_url + path url = self._base_url + path
header = {'Content-Type': content_type, header = {'Content-Type': content_type,
'Accept': 'application/json'} 'Accept': 'application/json'}
if self.token:
header.update({'Authorization': 'Bearer %s' % self.token})
return url, header return url, header
def patch(self, field, path, data): def patch(self, field, path, data):
LOG.debug("Patch %(path)s: %(data)s", { LOG.debug("Patch %(path)s: %(data)s", {'path': path, 'data': data})
'path': path, 'data': data})
content_type = 'application/merge-patch+json' content_type = 'application/merge-patch+json'
url, header = self._get_url_and_header(path, content_type) url, header = self._get_url_and_header(path, content_type)
response = self.session.patch(url, json={field: data}, response = self.session.patch(url, json={field: data}, headers=header)
headers=header, **self._rq_params)
self._raise_from_response(response) self._raise_from_response(response)
return response.json().get('status') return response.json().get('status')
@ -159,7 +155,7 @@ class K8sClient(object):
'path': path, 'data': data}) 'path': path, 'data': data})
response = self.session.patch(url, data=jsonutils.dumps(data), response = self.session.patch(url, data=jsonutils.dumps(data),
headers=header, **self._rq_params) headers=header)
self._raise_from_response(response) self._raise_from_response(response)
return response.json().get('status') return response.json().get('status')
@ -174,7 +170,7 @@ class K8sClient(object):
'value': value}] 'value': value}]
response = self.session.patch(url, data=jsonutils.dumps(data), response = self.session.patch(url, data=jsonutils.dumps(data),
headers=header, **self._rq_params) headers=header)
self._raise_from_response(response) self._raise_from_response(response)
return response.json().get('status') return response.json().get('status')
@ -187,7 +183,7 @@ class K8sClient(object):
'path': '/metadata/annotations/{}'.format(annotation_name)}] 'path': '/metadata/annotations/{}'.format(annotation_name)}]
response = self.session.patch(url, data=jsonutils.dumps(data), response = self.session.patch(url, data=jsonutils.dumps(data),
headers=header, **self._rq_params) headers=header)
self._raise_from_response(response) self._raise_from_response(response)
return response.json().get('status') return response.json().get('status')
@ -206,7 +202,7 @@ class K8sClient(object):
data = [{'op': 'remove', data = [{'op': 'remove',
'path': f'/metadata/annotations/{annotation_name}'}] 'path': f'/metadata/annotations/{annotation_name}'}]
response = self.session.patch(url, data=jsonutils.dumps(data), response = self.session.patch(url, data=jsonutils.dumps(data),
headers=header, **self._rq_params) headers=header)
if response.ok: if response.ok:
return response.json().get('status') return response.json().get('status')
raise exc.K8sClientException(response.text) raise exc.K8sClientException(response.text)
@ -215,11 +211,8 @@ class K8sClient(object):
LOG.debug("Post %(path)s: %(body)s", {'path': path, 'body': body}) LOG.debug("Post %(path)s: %(body)s", {'path': path, 'body': body})
url = self._base_url + path url = self._base_url + path
header = {'Content-Type': 'application/json'} header = {'Content-Type': 'application/json'}
if self.token:
header.update({'Authorization': 'Bearer %s' % self.token})
response = self.session.post(url, json=body, headers=header, response = self.session.post(url, json=body, headers=header)
**self._rq_params)
self._raise_from_response(response) self._raise_from_response(response)
return response.json() return response.json()
@ -227,10 +220,8 @@ class K8sClient(object):
LOG.debug("Delete %(path)s", {'path': path}) LOG.debug("Delete %(path)s", {'path': path})
url = self._base_url + path url = self._base_url + path
header = {'Content-Type': 'application/json'} header = {'Content-Type': 'application/json'}
if self.token:
header.update({'Authorization': 'Bearer %s' % self.token})
response = self.session.delete(url, headers=header, **self._rq_params) response = self.session.delete(url, headers=header)
self._raise_from_response(response) self._raise_from_response(response)
return response.json() return response.json()
@ -256,8 +247,7 @@ class K8sClient(object):
}, },
} }
response = self.session.patch(url, json=data, headers=headers, response = self.session.patch(url, json=data, headers=headers)
**self._rq_params)
if response.ok: if response.ok:
return True return True
@ -297,8 +287,7 @@ class K8sClient(object):
}, },
} }
response = self.session.patch(url, json=data, headers=headers, response = self.session.patch(url, json=data, headers=headers)
**self._rq_params)
if response.ok: if response.ok:
return True return True
@ -349,8 +338,7 @@ class K8sClient(object):
if resource_version: if resource_version:
metadata['resourceVersion'] = resource_version metadata['resourceVersion'] = resource_version
data = jsonutils.dumps({"metadata": metadata}, sort_keys=True) data = jsonutils.dumps({"metadata": metadata}, sort_keys=True)
response = self.session.patch(url, data=data, response = self.session.patch(url, data=data, headers=header)
headers=header, **self._rq_params)
if response.ok: if response.ok:
return response.json()['metadata'].get('annotations', {}) return response.json()['metadata'].get('annotations', {})
if response.status_code == requests.codes.conflict: if response.status_code == requests.codes.conflict:
@ -381,9 +369,6 @@ class K8sClient(object):
def watch(self, path): def watch(self, path):
url = self._base_url + path url = self._base_url + path
resource_version = None resource_version = None
header = {}
if self.token:
header.update({'Authorization': 'Bearer %s' % self.token})
attempt = 0 attempt = 0
while True: while True:
@ -393,8 +378,7 @@ class K8sClient(object):
params['resourceVersion'] = resource_version params['resourceVersion'] = resource_version
with contextlib.closing( with contextlib.closing(
self.session.get( self.session.get(
url, params=params, stream=True, headers=header, url, params=params, stream=True)) as response:
**self._rq_params)) as response:
if not response.ok: if not response.ok:
raise exc.K8sClientException(response.text) raise exc.K8sClientException(response.text)
attempt = 0 attempt = 0

View File

@ -74,9 +74,9 @@ class TestK8sClient(test_base.TestCase):
m_exist.return_value = True m_exist.return_value = True
self.assertRaises(RuntimeError, k8s_client.K8sClient, self.base_url) self.assertRaises(RuntimeError, k8s_client.K8sClient, self.base_url)
@mock.patch('requests.sessions.Session.get') @mock.patch('requests.sessions.Session.send')
@mock.patch('kuryr_kubernetes.config.CONF') @mock.patch('kuryr_kubernetes.config.CONF')
def test_bearer_token(self, m_cfg, m_get): def test_bearer_token(self, m_cfg, m_send):
token_content = ( token_content = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3Nl" "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3Nl"
"cnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc" "cnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc"
@ -102,11 +102,9 @@ class TestK8sClient(test_base.TestCase):
path = '/test' path = '/test'
client = k8s_client.K8sClient(self.base_url) client = k8s_client.K8sClient(self.base_url)
client.get(path) client.get(path)
headers = {
'Authorization': 'Bearer {}'.format(token_content)} self.assertEqual(f'Bearer {token_content}',
m_get.assert_called_once_with( m_send.call_args[0][0].headers['Authorization'])
self.base_url + path, cert=(None, None), headers=headers,
verify=False, timeout=(30, 60))
finally: finally:
os.unlink(m_cfg.kubernetes.token_file) os.unlink(m_cfg.kubernetes.token_file)
@ -121,9 +119,7 @@ class TestK8sClient(test_base.TestCase):
m_get.return_value = m_resp m_get.return_value = m_resp
self.assertEqual(ret, self.client.get(path)) self.assertEqual(ret, self.client.get(path))
m_get.assert_called_once_with( m_get.assert_called_once_with(self.base_url + path, headers=None)
self.base_url + path,
cert=(None, None), headers={}, verify=False, timeout=(30, 60))
@mock.patch('requests.sessions.Session.get') @mock.patch('requests.sessions.Session.get')
def test_get_exception(self, m_get): def test_get_exception(self, m_get):
@ -154,9 +150,7 @@ class TestK8sClient(test_base.TestCase):
self.assertEqual(annotations, self.client.annotate( self.assertEqual(annotations, self.client.annotate(
path, annotations, resource_version=resource_version)) path, annotations, resource_version=resource_version))
m_patch.assert_called_once_with(self.base_url + path, m_patch.assert_called_once_with(self.base_url + path,
data=data, headers=mock.ANY, data=data, headers=mock.ANY)
cert=(None, None), verify=False,
timeout=(30, 60))
@mock.patch('itertools.count') @mock.patch('itertools.count')
@mock.patch('requests.sessions.Session.patch') @mock.patch('requests.sessions.Session.patch')
@ -203,8 +197,7 @@ class TestK8sClient(test_base.TestCase):
m_patch.assert_has_calls([ m_patch.assert_has_calls([
mock.call(self.base_url + path, mock.call(self.base_url + path,
data=conflicting_data, data=conflicting_data,
headers=mock.ANY, headers=mock.ANY)])
cert=(None, None), verify=False, timeout=(30, 60))])
@mock.patch('itertools.count') @mock.patch('itertools.count')
@mock.patch('requests.sessions.Session.patch') @mock.patch('requests.sessions.Session.patch')
@ -243,12 +236,10 @@ class TestK8sClient(test_base.TestCase):
m_patch.assert_has_calls([ m_patch.assert_has_calls([
mock.call(self.base_url + path, mock.call(self.base_url + path,
data=annotating_data, data=annotating_data,
headers=mock.ANY, headers=mock.ANY),
cert=(None, None), verify=False, timeout=(30, 60)),
mock.call(self.base_url + path, mock.call(self.base_url + path,
data=resolution_data, data=resolution_data,
headers=mock.ANY, headers=mock.ANY)])
cert=(None, None), verify=False, timeout=(30, 60))])
@mock.patch('itertools.count') @mock.patch('itertools.count')
@mock.patch('requests.sessions.Session.patch') @mock.patch('requests.sessions.Session.patch')
@ -287,12 +278,10 @@ class TestK8sClient(test_base.TestCase):
m_patch.assert_has_calls([ m_patch.assert_has_calls([
mock.call(self.base_url + path, mock.call(self.base_url + path,
data=conflicting_data, data=conflicting_data,
headers=mock.ANY, headers=mock.ANY),
cert=(None, None), verify=False, timeout=(30, 60)),
mock.call(self.base_url + path, mock.call(self.base_url + path,
data=good_data, data=good_data,
headers=mock.ANY, headers=mock.ANY)])
cert=(None, None), verify=False, timeout=(30, 60))])
@mock.patch('itertools.count') @mock.patch('itertools.count')
@mock.patch('requests.sessions.Session.patch') @mock.patch('requests.sessions.Session.patch')
@ -318,9 +307,7 @@ class TestK8sClient(test_base.TestCase):
resource_version=resource_version) resource_version=resource_version)
m_patch.assert_called_once_with(self.base_url + path, m_patch.assert_called_once_with(self.base_url + path,
data=annotate_data, data=annotate_data,
headers=mock.ANY, headers=mock.ANY)
cert=(None, None), verify=False,
timeout=(30, 60))
@mock.patch('requests.sessions.Session.get') @mock.patch('requests.sessions.Session.get')
def test_watch(self, m_get): def test_watch(self, m_get):
@ -341,9 +328,8 @@ class TestK8sClient(test_base.TestCase):
self.assertEqual(cycles, m_get.call_count) self.assertEqual(cycles, m_get.call_count)
self.assertEqual(cycles, m_resp.close.call_count) self.assertEqual(cycles, m_resp.close.call_count)
m_get.assert_called_with(self.base_url + path, headers={}, stream=True, m_get.assert_called_with(self.base_url + path, stream=True,
params={'watch': 'true'}, cert=(None, None), params={'watch': 'true'})
verify=False, timeout=(30, 60))
@mock.patch('requests.sessions.Session.get') @mock.patch('requests.sessions.Session.get')
def test_watch_restart(self, m_get): def test_watch_restart(self, m_get):
@ -364,13 +350,10 @@ class TestK8sClient(test_base.TestCase):
self.assertEqual(3, m_get.call_count) self.assertEqual(3, m_get.call_count)
self.assertEqual(3, m_resp.close.call_count) self.assertEqual(3, m_resp.close.call_count)
m_get.assert_any_call( m_get.assert_any_call(
self.base_url + path, headers={}, stream=True, self.base_url + path, stream=True, params={"watch": "true"})
params={"watch": "true"}, cert=(None, None), verify=False,
timeout=(30, 60))
m_get.assert_any_call( m_get.assert_any_call(
self.base_url + path, headers={}, stream=True, self.base_url + path, stream=True, params={"watch": "true",
params={"watch": "true", "resourceVersion": 2}, cert=(None, None), "resourceVersion": 2})
verify=False, timeout=(30, 60))
@mock.patch('requests.sessions.Session.get') @mock.patch('requests.sessions.Session.get')
def test_watch_exception(self, m_get): def test_watch_exception(self, m_get):
@ -396,8 +379,7 @@ class TestK8sClient(test_base.TestCase):
self.assertEqual(ret, self.client.post(path, body)) self.assertEqual(ret, self.client.post(path, body))
m_post.assert_called_once_with(self.base_url + path, json=body, m_post.assert_called_once_with(self.base_url + path, json=body,
headers=mock.ANY, cert=(None, None), headers=mock.ANY)
verify=False, timeout=(30, 60))
@mock.patch('requests.sessions.Session.post') @mock.patch('requests.sessions.Session.post')
def test_post_exception(self, m_post): def test_post_exception(self, m_post):
@ -423,8 +405,7 @@ class TestK8sClient(test_base.TestCase):
self.assertEqual(ret, self.client.delete(path)) self.assertEqual(ret, self.client.delete(path))
m_delete.assert_called_once_with(self.base_url + path, m_delete.assert_called_once_with(self.base_url + path,
headers=mock.ANY, cert=(None, None), headers=mock.ANY)
verify=False, timeout=(30, 60))
@mock.patch('requests.sessions.Session.delete') @mock.patch('requests.sessions.Session.delete')
def test_delete_exception(self, m_delete): def test_delete_exception(self, m_delete):