python-keystoneclient/keystoneclient/tests/unit/test_session.py

1126 lines
41 KiB
Python

# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import argparse
import itertools
import logging
import uuid
import mock
from oslo_config import cfg
from oslo_config import fixture as config
from oslo_serialization import jsonutils
import requests
import six
from testtools import matchers
from keystoneclient import adapter
from keystoneclient.auth import base
from keystoneclient import exceptions
from keystoneclient.i18n import _
from keystoneclient import session as client_session
from keystoneclient.tests.unit import utils
class SessionTests(utils.TestCase):
    """Basic request, logging and error behaviour of ``Session``."""

    TEST_URL = 'http://127.0.0.1:5000/'

    def setUp(self):
        super(SessionTests, self).setUp()
        # The tests in this class are allowed to trigger deprecations.
        self.deprecations.expect_deprecations()

    def test_get(self):
        """GET uses the GET verb and returns the stubbed text."""
        sess = client_session.Session()
        self.stub_url('GET', text='response')

        response = sess.get(self.TEST_URL)

        self.assertEqual('GET', self.requests_mock.last_request.method)
        self.assertEqual(response.text, 'response')
        self.assertTrue(response.ok)

    def test_post(self):
        """POST sends the JSON payload and returns the stubbed text."""
        payload = {'hello': 'world'}
        sess = client_session.Session()
        self.stub_url('POST', text='response')

        response = sess.post(self.TEST_URL, json=payload)

        self.assertEqual('POST', self.requests_mock.last_request.method)
        self.assertEqual(response.text, 'response')
        self.assertTrue(response.ok)
        self.assertRequestBodyIs(json={'hello': 'world'})

    def test_head(self):
        """HEAD uses the HEAD verb and sends an empty body."""
        sess = client_session.Session()
        self.stub_url('HEAD')

        response = sess.head(self.TEST_URL)

        self.assertEqual('HEAD', self.requests_mock.last_request.method)
        self.assertTrue(response.ok)
        self.assertRequestBodyIs('')

    def test_put(self):
        """PUT sends the JSON payload and returns the stubbed text."""
        payload = {'hello': 'world'}
        sess = client_session.Session()
        self.stub_url('PUT', text='response')

        response = sess.put(self.TEST_URL, json=payload)

        self.assertEqual('PUT', self.requests_mock.last_request.method)
        self.assertEqual(response.text, 'response')
        self.assertTrue(response.ok)
        self.assertRequestBodyIs(json={'hello': 'world'})

    def test_delete(self):
        """DELETE uses the DELETE verb and returns the stubbed text."""
        sess = client_session.Session()
        self.stub_url('DELETE', text='response')

        response = sess.delete(self.TEST_URL)

        self.assertEqual('DELETE', self.requests_mock.last_request.method)
        self.assertTrue(response.ok)
        self.assertEqual(response.text, 'response')

    def test_patch(self):
        """PATCH sends the JSON payload and returns the stubbed text."""
        payload = {'hello': 'world'}
        sess = client_session.Session()
        self.stub_url('PATCH', text='response')

        response = sess.patch(self.TEST_URL, json=payload)

        self.assertEqual('PATCH', self.requests_mock.last_request.method)
        self.assertTrue(response.ok)
        self.assertEqual(response.text, 'response')
        self.assertRequestBodyIs(json={'hello': 'world'})

    def test_user_agent(self):
        """Check which User-Agent wins: session, header or keyword."""
        sess = client_session.Session(user_agent='test-agent')
        self.stub_url('GET', text='response')

        # Only the session level value is available.
        response = sess.get(self.TEST_URL)
        self.assertTrue(response.ok)
        self.assertRequestHeaderEqual('User-Agent', 'test-agent')

        # An explicit header replaces the session level value.
        response = sess.get(self.TEST_URL,
                            headers={'User-Agent': 'new-agent'})
        self.assertTrue(response.ok)
        self.assertRequestHeaderEqual('User-Agent', 'new-agent')

        # The user_agent keyword replaces even an explicit header.
        response = sess.get(self.TEST_URL,
                            headers={'User-Agent': 'new-agent'},
                            user_agent='overrides-agent')
        self.assertTrue(response.ok)
        self.assertRequestHeaderEqual('User-Agent', 'overrides-agent')

    def test_http_session_opts(self):
        """cert, verify and timeout reach the underlying request call."""
        sess = client_session.Session(cert='cert.pem', timeout=5,
                                      verify='certs')

        fake_response = utils.test_response(text='resp')
        request_stub = mock.Mock(return_value=fake_response)

        with mock.patch.object(sess.session, 'request',
                               request_stub) as mocked:
            sess.post(self.TEST_URL, data='value')

            positional, keywords = mocked.call_args

            self.assertEqual(positional[0], 'POST')
            self.assertEqual(positional[1], self.TEST_URL)
            self.assertEqual(keywords['data'], 'value')
            self.assertEqual(keywords['cert'], 'cert.pem')
            self.assertEqual(keywords['verify'], 'certs')
            self.assertEqual(keywords['timeout'], 5)

    def test_not_found(self):
        """A 404 response is raised as ``exceptions.NotFound``."""
        sess = client_session.Session()
        self.stub_url('GET', status_code=404)
        self.assertRaises(exceptions.NotFound, sess.get, self.TEST_URL)

    def test_server_error(self):
        """A 500 response is raised as ``InternalServerError``."""
        sess = client_session.Session()
        self.stub_url('GET', status_code=500)
        self.assertRaises(exceptions.InternalServerError,
                          sess.get, self.TEST_URL)

    def test_session_debug_output(self):
        """Check headers in the debug log, with secure ones redacted.

        Ordinary headers and bodies must appear in the log output, while
        the values of security related headers must not.
        """
        sess = client_session.Session(verify=False)
        plain_headers = {'HEADERA': 'HEADERVALB',
                         'Content-Type': 'application/json'}
        secret_headers = {'Authorization': uuid.uuid4().hex,
                          'X-Auth-Token': uuid.uuid4().hex,
                          'X-Subject-Token': uuid.uuid4().hex,
                          'X-Service-Token': uuid.uuid4().hex}
        response_body = '{"a": "b"}'
        request_data = '{"c": "d"}'

        combined_headers = dict(plain_headers)
        combined_headers.update(secret_headers)

        self.stub_url('POST', text=response_body, headers=combined_headers)
        response = sess.post(self.TEST_URL, headers=combined_headers,
                             data=request_data)
        self.assertEqual(response.status_code, 200)

        log_output = self.logger.output
        self.assertIn('curl', log_output)
        self.assertIn('POST', log_output)
        self.assertIn('--insecure', log_output)
        self.assertIn(response_body, log_output)
        self.assertIn("'%s'" % request_data, log_output)

        for name, value in plain_headers.items():
            self.assertIn(name, self.logger.output)
            self.assertIn(value, self.logger.output)

        # The response object still carries the real values; only the
        # debug log output is masked.
        for name, value in secret_headers.items():
            self.assertIn('%s: {SHA1}' % name, self.logger.output)
            self.assertEqual(value, response.headers[name])
            self.assertNotIn(value, self.logger.output)

    def test_logs_failed_output(self):
        """The body is logged even when the request fails."""
        sess = client_session.Session()
        key = uuid.uuid4().hex
        value = uuid.uuid4().hex
        body = {key: value}

        self.stub_url('GET', json=body, status_code=400,
                      headers={'Content-Type': 'application/json'})
        response = sess.get(self.TEST_URL, raise_exc=False)

        self.assertEqual(response.status_code, 400)
        self.assertIn(key, self.logger.output)
        self.assertIn(value, self.logger.output)

    def test_logging_body_only_for_specified_content_types(self):
        """Verify the response body is logged only for JSON responses.

        Response bodies are logged only when the response's Content-Type
        header is set to application/json. This avoids an unexpected
        MemoryError when reading arbitrary responses, such as streams.
        """
        OMITTED_BODY = ('Omitted, Content-Type is set to %s. Only '
                        'application/json responses have their bodies logged.')
        sess = client_session.Session(verify=False)

        # No Content-Type at all.
        payload = jsonutils.dumps({'token': {'id': '...'}})
        self.stub_url('POST', text=payload)
        sess.post(self.TEST_URL)
        self.assertNotIn(payload, self.logger.output)
        self.assertIn(OMITTED_BODY % None, self.logger.output)

        # Content-Type: text/xml
        payload = '<token><id>...</id></token>'
        self.stub_url('POST', text=payload,
                      headers={'Content-Type': 'text/xml'})
        sess.post(self.TEST_URL)
        self.assertNotIn(payload, self.logger.output)
        self.assertIn(OMITTED_BODY % 'text/xml', self.logger.output)

        # Content-Type: application/json
        payload = jsonutils.dumps({'token': {'id': '...'}})
        self.stub_url('POST', text=payload,
                      headers={'Content-Type': 'application/json'})
        sess.post(self.TEST_URL)
        self.assertIn(payload, self.logger.output)
        self.assertNotIn(OMITTED_BODY % 'application/json',
                         self.logger.output)

        # Content-Type: application/json; charset=UTF-8
        payload = jsonutils.dumps({'token': {'id': '...'}})
        self.stub_url(
            'POST', text=payload,
            headers={'Content-Type': 'application/json; charset=UTF-8'})
        sess.post(self.TEST_URL)
        self.assertIn(payload, self.logger.output)
        self.assertNotIn(OMITTED_BODY % 'application/json; charset=UTF-8',
                         self.logger.output)

    def test_unicode_data_in_debug_output(self):
        """Verify that unicode request data is logged unmodified."""
        sess = client_session.Session(verify=False)

        response_body = 'RESP'
        request_data = u'αβγδ'
        self.stub_url('POST', text=response_body)
        sess.post(self.TEST_URL, data=request_data)

        self.assertIn("'%s'" % request_data, self.logger.output)

    def test_binary_data_not_in_debug_output(self):
        """Verify that non-ascii-encodable data causes replacement."""
        if not six.PY2:
            # Python 3 logging handles binary data well.
            return

        request_data = "my data" + chr(255)

        sess = client_session.Session(verify=False)

        response_body = 'RESP'
        self.stub_url('POST', text=response_body)

        # Unicode and byte strings are mixed on purpose in the request
        # elements to make sure every join is handled appropriately (any
        # join of unicode and byte strings should raise a
        # UnicodeDecodeError).
        sess.post(six.text_type(self.TEST_URL), data=request_data)

        self.assertNotIn('my data', self.logger.output)

    def test_logging_cacerts(self):
        """The CA certificate path is logged with ``--cacert``."""
        ca_path = '/path/to/certs'
        sess = client_session.Session(verify=ca_path)

        self.stub_url('GET', text='text')
        sess.get(self.TEST_URL)

        self.assertIn('--cacert', self.logger.output)
        self.assertIn(ca_path, self.logger.output)

    def test_connect_retries(self):
        """Timeouts are retried ``connect_retries`` times with sleeps."""

        def _timeout_error(request, context):
            raise requests.exceptions.Timeout()

        self.stub_url('GET', text=_timeout_error)

        sess = client_session.Session()
        retries = 3

        with mock.patch('time.sleep') as sleep_mock:
            self.assertRaises(exceptions.RequestTimeout,
                              sess.get,
                              self.TEST_URL, connect_retries=retries)

            self.assertEqual(retries, sleep_mock.call_count)
            # 3 retries finishing with 2.0 means 0.5, 1.0 and 2.0
            sleep_mock.assert_called_with(2.0)

        # Retries are counted, so there is one initial request + 3 retries.
        self.assertThat(self.requests_mock.request_history,
                        matchers.HasLength(retries + 1))

    def test_uses_tcp_keepalive_by_default(self):
        """A default session mounts the keepalive adapter for both schemes."""
        sess = client_session.Session()
        mounted = sess.session.adapters
        self.assertIsInstance(mounted['http://'],
                              client_session.TCPKeepAliveAdapter)
        self.assertIsInstance(mounted['https://'],
                              client_session.TCPKeepAliveAdapter)

    def test_does_not_set_tcp_keepalive_on_custom_sessions(self):
        """No adapter is mounted on a session object supplied by the caller."""
        custom_session = mock.Mock()
        client_session.Session(session=custom_session)
        self.assertFalse(custom_session.mount.called)

    def test_ssl_error_message(self):
        """The raised SSLError message has the URL and the error text."""
        error = uuid.uuid4().hex

        def _ssl_error(request, context):
            raise requests.exceptions.SSLError(error)

        self.stub_url('GET', text=_ssl_error)
        sess = client_session.Session()

        # The exception should contain the URL and details about the SSL
        # error.
        expected_msg = _('SSL exception connecting to %(url)s: %(error)s') % {
            'url': self.TEST_URL, 'error': error}
        six.assertRaisesRegex(self,
                              exceptions.SSLError,
                              expected_msg,
                              sess.get,
                              self.TEST_URL)

    def test_mask_password_in_http_log_response(self):
        """A password in a JSON response body is not sent to the logger."""
        sess = client_session.Session()

        def fake_debug(msg):
            self.assertNotIn('verybadpass', msg)

        logger = mock.Mock(isEnabledFor=mock.Mock(return_value=True))
        logger.debug = mock.Mock(side_effect=fake_debug)
        body = {
            "connection_info": {
                "driver_volume_type": "iscsi",
                "data": {
                    "auth_password": "verybadpass",
                    "target_discovered": False,
                    "encrypted": False,
                    "qos_specs": None,
                    "target_iqn": ("iqn.2010-10.org.openstack:volume-"
                                   "744d2085-8e78-40a5-8659-ef3cffb2480e"),
                    "target_portal": "172.99.69.228:3260",
                    "volume_id": "744d2085-8e78-40a5-8659-ef3cffb2480e",
                    "target_lun": 1,
                    "access_mode": "rw",
                    "auth_username": "verybadusername",
                    "auth_method": "CHAP"}}}
        serialized = jsonutils.dumps(body)
        response = mock.Mock(text=serialized, status_code=200,
                             headers={'content-type': 'application/json'})

        sess._http_log_response(response, logger)

        self.assertEqual(1, logger.debug.call_count)
class TCPKeepAliveAdapter(utils.TestCase):
    """Check the socket options ``client_session.TCPKeepAliveAdapter`` sets.

    The ``socket`` name inside ``client_session`` and the parent
    ``HTTPAdapter.init_poolmanager`` are both mocked, so the tests only look
    at the ``socket_options`` keyword handed to the parent implementation.
    """

    @mock.patch.object(client_session, 'socket')
    @mock.patch('requests.adapters.HTTPAdapter.init_poolmanager')
    def test_init_poolmanager_all_options(self, mock_parent_init_poolmanager,
                                          mock_socket):
        """The keepalive tuning options are passed when socket has them."""
        # Properties expected to be present on the socket module.
        mock_socket.TCP_KEEPIDLE = mock.sentinel.TCP_KEEPIDLE
        mock_socket.TCP_KEEPCNT = mock.sentinel.TCP_KEEPCNT
        mock_socket.TCP_KEEPINTVL = mock.sentinel.TCP_KEEPINTVL
        expected_opts = [mock_socket.TCP_KEEPIDLE, mock_socket.TCP_KEEPCNT,
                         mock_socket.TCP_KEEPINTVL]

        keepalive_adapter = client_session.TCPKeepAliveAdapter()
        keepalive_adapter.init_poolmanager()

        _args, parent_kwargs = mock_parent_init_poolmanager.call_args
        socket_options = parent_kwargs['socket_options']
        passed_opts = [opt for (_proto, opt, _val) in socket_options]
        for opt in expected_opts:
            self.assertIn(opt, passed_opts)

    @mock.patch.object(client_session, 'socket')
    @mock.patch('requests.adapters.HTTPAdapter.init_poolmanager')
    def test_init_poolmanager(self, mock_parent_init_poolmanager, mock_socket):
        """Only the two basic options are passed for a restricted socket."""
        # Limit the mocked module to these attributes only.
        allowed_attrs = ['IPPROTO_TCP', 'TCP_NODELAY', 'SOL_SOCKET',
                         'SO_KEEPALIVE']
        mock_socket.mock_add_spec(allowed_attrs)

        keepalive_adapter = client_session.TCPKeepAliveAdapter()
        keepalive_adapter.init_poolmanager()

        _args, parent_kwargs = mock_parent_init_poolmanager.call_args
        socket_options = parent_kwargs['socket_options']
        passed_opts = [opt for (_proto, opt, _val) in socket_options]
        self.assertEqual([mock_socket.TCP_NODELAY, mock_socket.SO_KEEPALIVE],
                         passed_opts)
class RedirectTests(utils.TestCase):
    """Check how ``client_session.Session`` follows HTTP redirects.

    ``REDIRECT_CHAIN`` is a list of URLs in which every entry redirects to
    the next one; the last entry answers with the final response.
    """

    REDIRECT_CHAIN = ['http://myhost:3445/',
                      'http://anotherhost:6555/',
                      'http://thirdhost/',
                      'http://finaldestination:55/']

    DEFAULT_REDIRECT_BODY = 'Redirect'
    DEFAULT_RESP_BODY = 'Found'

    def setUp(self):
        super(RedirectTests, self).setUp()
        # The tests in this class are allowed to trigger deprecations.
        self.deprecations.expect_deprecations()

    def setup_redirects(self, method='GET', status_code=305,
                        redirect_kwargs=None, final_kwargs=None):
        """Register the whole redirect chain with ``requests_mock``.

        :param method: HTTP method every URL of the chain is registered for.
        :param status_code: status returned by each redirecting URL.
        :param redirect_kwargs: extra ``register_uri`` keywords for the
            redirecting URLs; ``text`` defaults to
            ``DEFAULT_REDIRECT_BODY``.
        :param final_kwargs: extra ``register_uri`` keywords for the last
            URL; ``status_code`` defaults to 200 and ``text`` to
            ``DEFAULT_RESP_BODY``.
        """
        # Work on copies so the setdefault() calls below never modify a
        # dictionary that belongs to the caller.
        redirect_kwargs = dict(redirect_kwargs or {})
        final_kwargs = dict(final_kwargs or {})

        redirect_kwargs.setdefault('text', self.DEFAULT_REDIRECT_BODY)

        # Pair every URL with its successor: source -> destination.
        for s, d in zip(self.REDIRECT_CHAIN, self.REDIRECT_CHAIN[1:]):
            self.requests_mock.register_uri(method, s, status_code=status_code,
                                            headers={'Location': d},
                                            **redirect_kwargs)

        final_kwargs.setdefault('status_code', 200)
        final_kwargs.setdefault('text', self.DEFAULT_RESP_BODY)
        self.requests_mock.register_uri(method, self.REDIRECT_CHAIN[-1],
                                        **final_kwargs)

    def assertResponse(self, resp):
        """Assert that ``resp`` is the final response of the chain."""
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(resp.text, self.DEFAULT_RESP_BODY)

    def test_basic_get(self):
        """A single GET redirect is followed to the final URL."""
        session = client_session.Session()
        self.setup_redirects()
        resp = session.get(self.REDIRECT_CHAIN[-2])
        self.assertResponse(resp)

    def test_basic_post_keeps_correct_method(self):
        """A redirected POST reaches the final URL registered for POST."""
        session = client_session.Session()
        self.setup_redirects(method='POST', status_code=301)
        resp = session.post(self.REDIRECT_CHAIN[-2])
        self.assertResponse(resp)

    def test_redirect_forever(self):
        """With ``redirect=True`` the whole chain is followed."""
        session = client_session.Session(redirect=True)
        self.setup_redirects()
        resp = session.get(self.REDIRECT_CHAIN[0])
        self.assertResponse(resp)
        # The original used assertTrue(a, b), which treats b as the failure
        # message and compares nothing. The history is expected to hold one
        # entry per redirect hop, i.e. every URL except the final one.
        self.assertEqual(len(self.REDIRECT_CHAIN) - 1, len(resp.history))

    def test_no_redirect(self):
        """With ``redirect=False`` the first redirect is returned as is."""
        session = client_session.Session(redirect=False)
        self.setup_redirects()
        resp = session.get(self.REDIRECT_CHAIN[0])
        self.assertEqual(resp.status_code, 305)
        self.assertEqual(resp.url, self.REDIRECT_CHAIN[0])

    def test_redirect_limit(self):
        """An integer ``redirect`` stops after that many hops."""
        self.setup_redirects()
        for i in (1, 2):
            session = client_session.Session(redirect=i)
            resp = session.get(self.REDIRECT_CHAIN[0])
            self.assertEqual(resp.status_code, 305)
            self.assertEqual(resp.url, self.REDIRECT_CHAIN[i])
            self.assertEqual(resp.text, self.DEFAULT_REDIRECT_BODY)

    def test_history_matches_requests(self):
        """The session history mirrors what plain ``requests`` records."""
        self.setup_redirects(status_code=301)
        session = client_session.Session(redirect=True)
        req_resp = requests.get(self.REDIRECT_CHAIN[0],
                                allow_redirects=True)

        ses_resp = session.get(self.REDIRECT_CHAIN[0])

        self.assertEqual(len(req_resp.history), len(ses_resp.history))

        for r, s in zip(req_resp.history, ses_resp.history):
            self.assertEqual(r.url, s.url)
            self.assertEqual(r.status_code, s.status_code)
class ConstructSessionFromArgsTests(utils.TestCase):
    """Check the attributes of sessions built by ``Session.construct``."""

    KEY = 'keyfile'
    CERT = 'certfile'
    CACERT = 'cacert-path'

    def _s(self, k=None, **kwargs):
        """Build a session from the dict ``k`` or, if falsy, from kwargs."""
        construct_args = k if k else kwargs
        with self.deprecations.expect_deprecations_here():
            return client_session.Session.construct(construct_args)

    def test_verify(self):
        """``verify`` is derived from insecure, verify and cacert."""
        insecure_only = self._s(insecure=True)
        self.assertFalse(insecure_only.verify)

        verify_and_insecure = self._s(verify=True, insecure=True)
        self.assertTrue(verify_and_insecure.verify)

        no_verify_and_insecure = self._s(verify=False, insecure=True)
        self.assertFalse(no_verify_and_insecure.verify)

        with_cacert = self._s(cacert=self.CACERT)
        self.assertEqual(with_cacert.verify, self.CACERT)

    def test_cert(self):
        """``cert`` accepts a tuple or separate cert and key values."""
        pair = (self.CERT, self.KEY)
        self.assertEqual(self._s(cert=pair).cert, pair)
        self.assertEqual(self._s(cert=self.CERT, key=self.KEY).cert, pair)
        self.assertIsNone(self._s(key=self.KEY).cert)

    def test_pass_through(self):
        """Listed options are stored on the session and consumed."""
        # A number is used for every option, presumably because timeout
        # has to be numeric.
        value = 42
        for option in ['timeout', 'session', 'original_ip', 'user_agent']:
            args = {option: value}
            built = self._s(args)
            self.assertEqual(getattr(built, option), value)
            # construct() is expected to remove the option it consumed.
            self.assertNotIn(option, args)
class AuthPlugin(base.BaseAuthPlugin):
    """Small authentication plugin that answers with canned values.

    The token and the value returned by ``invalidate`` are taken from the
    constructor; endpoints are looked up in ``SERVICE_URLS``.
    """

    TEST_TOKEN = utils.TestCase.TEST_TOKEN
    TEST_USER_ID = 'aUser'
    TEST_PROJECT_ID = 'aProject'

    # service type -> interface -> endpoint URL
    SERVICE_URLS = {
        'identity': {'public': 'http://identity-public:1111/v2.0',
                     'admin': 'http://identity-admin:1111/v2.0'},
        'compute': {'public': 'http://compute-public:2222/v1.0',
                    'admin': 'http://compute-admin:2222/v1.0'},
        'image': {'public': 'http://image-public:3333/v2.0',
                  'admin': 'http://image-admin:3333/v2.0'}
    }

    def __init__(self, token=TEST_TOKEN, invalidate=True):
        self.token = token
        self._invalidate = invalidate

    def get_token(self, session):
        """Return the token given to the constructor."""
        return self.token

    def get_endpoint(self, session, service_type=None, interface=None,
                     **kwargs):
        """Return the URL for the service type and interface, or None."""
        try:
            urls_by_interface = self.SERVICE_URLS[service_type]
            return urls_by_interface[interface]
        except (KeyError, AttributeError):
            return None

    def invalidate(self):
        """Return the ``invalidate`` flag given to the constructor."""
        return self._invalidate

    def get_user_id(self, session):
        """Return the fixed test user id."""
        return self.TEST_USER_ID

    def get_project_id(self, session):
        """Return the fixed test project id."""
        return self.TEST_PROJECT_ID
class CalledAuthPlugin(base.BaseAuthPlugin):
    """Authentication plugin that records which of its methods were called.

    Every method sets a ``*_called`` flag; ``get_endpoint`` additionally
    keeps the keyword arguments it received in ``endpoint_arguments``.
    """

    ENDPOINT = 'http://fakeendpoint/'

    def __init__(self, invalidate=True):
        self._invalidate = invalidate
        self.invalidate_called = False
        self.get_token_called = False
        self.get_endpoint_called = False
        self.endpoint_arguments = {}

    def get_token(self, session):
        """Record the call and return the shared test token."""
        self.get_token_called = True
        return utils.TestCase.TEST_TOKEN

    def get_endpoint(self, session, **kwargs):
        """Record the call and its keywords, then return ``ENDPOINT``."""
        self.endpoint_arguments = kwargs
        self.get_endpoint_called = True
        return self.ENDPOINT

    def invalidate(self):
        """Record the call and return the constructor's flag."""
        self.invalidate_called = True
        return self._invalidate
class SessionAuthTests(utils.TestCase):
    """Check how ``Session`` cooperates with authentication plugins."""

    TEST_URL = 'http://127.0.0.1:5000/'
    TEST_JSON = {'hello': 'world'}

    def setUp(self):
        super(SessionAuthTests, self).setUp()
        # The tests in this class are allowed to trigger deprecations.
        self.deprecations.expect_deprecations()

    def stub_service_url(self, service_type, interface, path,
                         method='GET', **kwargs):
        """Register ``path`` below an ``AuthPlugin.SERVICE_URLS`` entry."""
        base_url = AuthPlugin.SERVICE_URLS[service_type][interface]
        # Join with exactly one slash between the base URL and the path.
        uri = '/'.join((base_url.rstrip('/'), path.lstrip('/')))

        self.requests_mock.register_uri(method, uri, **kwargs)

    def test_auth_plugin_default_with_plugin(self):
        """With a plugin on the session, requests carry the token."""
        self.stub_url('GET', base_url=self.TEST_URL, json=self.TEST_JSON)

        # If there is an auth plugin then requests default to authenticated.
        plugin = AuthPlugin()
        sess = client_session.Session(auth=plugin)
        response = sess.get(self.TEST_URL)
        self.assertEqual(response.json(), self.TEST_JSON)

        self.assertRequestHeaderEqual('X-Auth-Token', AuthPlugin.TEST_TOKEN)

    def test_auth_plugin_disable(self):
        """``authenticated=False`` leaves the token header out."""
        self.stub_url('GET', base_url=self.TEST_URL, json=self.TEST_JSON)

        plugin = AuthPlugin()
        sess = client_session.Session(auth=plugin)
        response = sess.get(self.TEST_URL, authenticated=False)
        self.assertEqual(response.json(), self.TEST_JSON)

        self.assertRequestHeaderEqual('X-Auth-Token', None)

    def test_service_type_urls(self):
        """A relative path is resolved using the endpoint filter."""
        service_type = 'compute'
        interface = 'public'
        path = '/instances'
        status = 200
        body = 'SUCCESS'

        self.stub_service_url(service_type=service_type,
                              interface=interface,
                              path=path,
                              status_code=status,
                              text=body)

        sess = client_session.Session(auth=AuthPlugin())
        endpoint_filter = {'service_type': service_type,
                           'interface': interface}
        response = sess.get(path, endpoint_filter=endpoint_filter)

        self.assertEqual(self.requests_mock.last_request.url,
                         AuthPlugin.SERVICE_URLS['compute']['public'] + path)
        self.assertEqual(response.text, body)
        self.assertEqual(response.status_code, status)

    def test_service_url_raises_if_no_auth_plugin(self):
        """An endpoint filter without any plugin raises."""
        sess = client_session.Session()
        self.assertRaises(exceptions.MissingAuthPlugin,
                          sess.get, '/path',
                          endpoint_filter={'service_type': 'compute',
                                           'interface': 'public'})

    def test_service_url_raises_if_no_url_returned(self):
        """A plugin that finds no endpoint leads to EndpointNotFound."""
        sess = client_session.Session(auth=AuthPlugin())
        self.assertRaises(exceptions.EndpointNotFound,
                          sess.get, '/path',
                          endpoint_filter={'service_type': 'unknown',
                                           'interface': 'public'})

    def test_raises_exc_only_when_asked(self):
        """HTTP errors raise unless ``raise_exc=False`` is given."""
        # A request that returns an HTTP error raises an exception by
        # default; with raise_exc=False the response is returned instead.
        self.requests_mock.get(self.TEST_URL, status_code=401)

        sess = client_session.Session()
        self.assertRaises(exceptions.Unauthorized, sess.get, self.TEST_URL)

        response = sess.get(self.TEST_URL, raise_exc=False)
        self.assertEqual(401, response.status_code)

    def test_passed_auth_plugin(self):
        """A plugin passed per request is used for token and endpoint."""
        per_request_plugin = CalledAuthPlugin()
        sess = client_session.Session()

        self.requests_mock.get(CalledAuthPlugin.ENDPOINT + 'path',
                               status_code=200)
        endpoint_filter = {'service_type': 'identity'}

        # Without a plugin, an authenticated request cannot work.
        self.assertRaises(exceptions.MissingAuthPlugin, sess.get, 'path',
                          authenticated=True)

        # Without a plugin, an endpoint filter cannot work either.
        self.assertRaises(exceptions.MissingAuthPlugin, sess.get, 'path',
                          authenticated=False,
                          endpoint_filter=endpoint_filter)

        response = sess.get('path', auth=per_request_plugin,
                            endpoint_filter=endpoint_filter)

        self.assertEqual(200, response.status_code)
        self.assertTrue(per_request_plugin.get_endpoint_called)
        self.assertTrue(per_request_plugin.get_token_called)

    def test_passed_auth_plugin_overrides(self):
        """A per-request plugin is used instead of the session plugin."""
        session_plugin = CalledAuthPlugin()
        per_request_plugin = CalledAuthPlugin()

        sess = client_session.Session(session_plugin)

        self.requests_mock.get(CalledAuthPlugin.ENDPOINT + 'path',
                               status_code=200)

        response = sess.get('path', auth=per_request_plugin,
                            endpoint_filter={'service_type': 'identity'})

        self.assertEqual(200, response.status_code)
        self.assertTrue(per_request_plugin.get_endpoint_called)
        self.assertTrue(per_request_plugin.get_token_called)
        self.assertFalse(session_plugin.get_endpoint_called)
        self.assertFalse(session_plugin.get_token_called)

    def test_requests_auth_plugin(self):
        """``requests_auth`` is handed to requests as ``auth``."""
        sess = client_session.Session()

        requests_auth = object()

        fake_response = utils.test_response(text='resp')
        request_stub = mock.Mock(return_value=fake_response)

        with mock.patch.object(sess.session, 'request',
                               request_stub) as mocked:
            sess.get(self.TEST_URL, requests_auth=requests_auth)

            mocked.assert_called_once_with('GET', self.TEST_URL,
                                           headers=mock.ANY,
                                           allow_redirects=mock.ANY,
                                           auth=requests_auth,
                                           verify=mock.ANY)

    def test_reauth_called(self):
        """A 401 invalidates the plugin and the request is retried."""
        plugin = CalledAuthPlugin(invalidate=True)
        sess = client_session.Session(auth=plugin)

        # First answer is a 401, the second one succeeds.
        self.requests_mock.get(self.TEST_URL,
                               [{'text': 'Failed', 'status_code': 401},
                                {'text': 'Hello', 'status_code': 200}])

        # allow_reauth=True is the default
        response = sess.get(self.TEST_URL, authenticated=True)

        self.assertEqual(200, response.status_code)
        self.assertEqual('Hello', response.text)
        self.assertTrue(plugin.invalidate_called)

    def test_reauth_not_called(self):
        """With ``allow_reauth=False`` a 401 is raised straight away."""
        plugin = CalledAuthPlugin(invalidate=True)
        sess = client_session.Session(auth=plugin)

        self.requests_mock.get(self.TEST_URL,
                               [{'text': 'Failed', 'status_code': 401},
                                {'text': 'Hello', 'status_code': 200}])

        self.assertRaises(exceptions.Unauthorized, sess.get, self.TEST_URL,
                          authenticated=True, allow_reauth=False)
        self.assertFalse(plugin.invalidate_called)

    def test_endpoint_override_overrides_filter(self):
        """``endpoint_override`` is used instead of asking the plugin."""
        plugin = CalledAuthPlugin()
        sess = client_session.Session(auth=plugin)

        override_base = 'http://mytest/'
        path = 'path'
        override_url = override_base + path
        expected_text = uuid.uuid4().hex

        self.requests_mock.get(override_url, text=expected_text)

        response = sess.get(path,
                            endpoint_override=override_base,
                            endpoint_filter={'service_type': 'identity'})

        self.assertEqual(expected_text, response.text)
        self.assertEqual(override_url, self.requests_mock.last_request.url)

        self.assertTrue(plugin.get_token_called)
        self.assertFalse(plugin.get_endpoint_called)

    def test_endpoint_override_ignore_full_url(self):
        """A full URL is requested as is, ignoring the override."""
        plugin = CalledAuthPlugin()
        sess = client_session.Session(auth=plugin)

        path = 'path'
        full_url = self.TEST_URL + path

        expected_text = uuid.uuid4().hex
        self.requests_mock.get(full_url, text=expected_text)

        response = sess.get(full_url,
                            endpoint_override='http://someother.url',
                            endpoint_filter={'service_type': 'identity'})

        self.assertEqual(expected_text, response.text)
        self.assertEqual(full_url, self.requests_mock.last_request.url)

        self.assertTrue(plugin.get_token_called)
        self.assertFalse(plugin.get_endpoint_called)

    def test_user_and_project_id(self):
        """The session returns the plugin's user and project ids."""
        plugin = AuthPlugin()
        sess = client_session.Session(auth=plugin)

        self.assertEqual(plugin.TEST_USER_ID, sess.get_user_id())
        self.assertEqual(plugin.TEST_PROJECT_ID, sess.get_project_id())

    def test_logger_object_passed(self):
        """A logger passed to the request gets the output instead."""
        # A uniquely named, non-propagating logger writing to a buffer.
        logger = logging.getLogger(uuid.uuid4().hex)
        logger.setLevel(logging.DEBUG)
        logger.propagate = False

        buf = six.StringIO()
        logger.addHandler(logging.StreamHandler(buf))

        plugin = AuthPlugin()
        sess = client_session.Session(auth=plugin)

        key = uuid.uuid4().hex
        value = uuid.uuid4().hex
        payload = {key: value}

        self.stub_url('GET',
                      json=payload,
                      headers={'Content-Type': 'application/json'})

        response = sess.get(self.TEST_URL, logger=logger)

        self.assertEqual(payload, response.json())
        captured = buf.getvalue()

        self.assertIn(self.TEST_URL, captured)
        self.assertIn(key, captured)
        self.assertIn(value, captured)

        # Nothing may have reached the default test logger.
        self.assertNotIn(key, self.logger.output)
        self.assertNotIn(value, self.logger.output)
class AdapterTest(utils.TestCase):
    """Check that ``adapter.Adapter`` forwards its settings to the session."""

    SERVICE_TYPE = uuid.uuid4().hex
    SERVICE_NAME = uuid.uuid4().hex
    INTERFACE = uuid.uuid4().hex
    REGION_NAME = uuid.uuid4().hex
    USER_AGENT = uuid.uuid4().hex
    VERSION = uuid.uuid4().hex

    TEST_URL = CalledAuthPlugin.ENDPOINT

    def setUp(self):
        super(AdapterTest, self).setUp()
        # The tests in this class are allowed to trigger deprecations.
        self.deprecations.expect_deprecations()

    def _create_loaded_adapter(self):
        """Return an adapter with every class level setting applied."""
        plugin = CalledAuthPlugin()
        sess = client_session.Session()
        return adapter.Adapter(sess,
                               auth=plugin,
                               service_type=self.SERVICE_TYPE,
                               service_name=self.SERVICE_NAME,
                               interface=self.INTERFACE,
                               region_name=self.REGION_NAME,
                               user_agent=self.USER_AGENT,
                               version=self.VERSION)

    def _verify_endpoint_called(self, adpt):
        """Assert the plugin's get_endpoint saw the adapter settings."""
        expected_arguments = (
            ('service_type', self.SERVICE_TYPE),
            ('service_name', self.SERVICE_NAME),
            ('interface', self.INTERFACE),
            ('region_name', self.REGION_NAME),
            ('version', self.VERSION),
        )
        for name, expected in expected_arguments:
            self.assertEqual(expected, adpt.auth.endpoint_arguments[name])

    def test_setting_variables_on_request(self):
        """A request made through the adapter uses its settings."""
        expected_text = uuid.uuid4().hex
        self.stub_url('GET', text=expected_text)
        adpt = self._create_loaded_adapter()

        response = adpt.get('/')
        self.assertEqual(response.text, expected_text)

        self._verify_endpoint_called(adpt)
        self.assertTrue(adpt.auth.get_token_called)
        self.assertRequestHeaderEqual('User-Agent', self.USER_AGENT)

    def test_setting_variables_on_get_endpoint(self):
        """``get_endpoint`` passes the adapter settings to the plugin."""
        adpt = self._create_loaded_adapter()
        endpoint = adpt.get_endpoint()

        self.assertEqual(self.TEST_URL, endpoint)
        self._verify_endpoint_called(adpt)

    def test_legacy_binding(self):
        """The legacy adapter returns the response and decoded JSON."""
        key = uuid.uuid4().hex
        val = uuid.uuid4().hex
        serialized = jsonutils.dumps({key: val})

        self.stub_url('GET', text=serialized)

        plugin = CalledAuthPlugin()
        sess = client_session.Session(auth=plugin)
        adpt = adapter.LegacyJsonAdapter(sess,
                                         service_type=self.SERVICE_TYPE,
                                         user_agent=self.USER_AGENT)

        response, body = adpt.get('/')
        self.assertEqual(self.SERVICE_TYPE,
                         plugin.endpoint_arguments['service_type'])
        self.assertEqual(response.text, serialized)
        self.assertEqual(val, body[key])

    def test_legacy_binding_non_json_resp(self):
        """The legacy adapter returns a ``None`` body for non-JSON text."""
        expected_text = uuid.uuid4().hex
        self.stub_url('GET', text=expected_text,
                      headers={'Content-Type': 'text/html'})

        plugin = CalledAuthPlugin()
        sess = client_session.Session(auth=plugin)
        adpt = adapter.LegacyJsonAdapter(sess,
                                         service_type=self.SERVICE_TYPE,
                                         user_agent=self.USER_AGENT)

        response, body = adpt.get('/')
        self.assertEqual(self.SERVICE_TYPE,
                         plugin.endpoint_arguments['service_type'])
        self.assertEqual(response.text, expected_text)
        self.assertIsNone(body)

    def test_methods(self):
        """Each verb helper calls ``request`` with the upper-cased verb."""
        sess = client_session.Session()
        adpt = adapter.Adapter(sess)
        url = 'http://url'

        for verb in ['get', 'head', 'post', 'put', 'patch', 'delete']:
            with mock.patch.object(adpt, 'request') as request_mock:
                helper = getattr(adpt, verb)
                helper(url)
                request_mock.assert_called_once_with(url, verb.upper())

    def test_setting_endpoint_override(self):
        """An adapter level override is used for requests and lookups."""
        endpoint_override = 'http://overrideurl'
        path = '/path'
        endpoint_url = endpoint_override + path

        plugin = CalledAuthPlugin()
        sess = client_session.Session(auth=plugin)
        adpt = adapter.Adapter(sess, endpoint_override=endpoint_override)

        expected_text = uuid.uuid4().hex
        self.requests_mock.get(endpoint_url, text=expected_text)

        response = adpt.get(path)

        self.assertEqual(expected_text, response.text)
        self.assertEqual(endpoint_url, self.requests_mock.last_request.url)

        self.assertEqual(endpoint_override, adpt.get_endpoint())

    def test_adapter_invalidate(self):
        """``Adapter.invalidate`` reaches the adapter's own plugin."""
        plugin = CalledAuthPlugin()
        sess = client_session.Session()
        adpt = adapter.Adapter(sess, auth=plugin)

        adpt.invalidate()

        self.assertTrue(plugin.invalidate_called)

    def test_adapter_get_token(self):
        """``Adapter.get_token`` returns the token from its plugin."""
        plugin = CalledAuthPlugin()
        sess = client_session.Session()
        adpt = adapter.Adapter(sess, auth=plugin)

        self.assertEqual(self.TEST_TOKEN, adpt.get_token())
        self.assertTrue(plugin.get_token_called)

    def test_adapter_connect_retries(self):
        """``connect_retries`` set on the adapter applies to requests."""
        retries = 2
        sess = client_session.Session()
        adpt = adapter.Adapter(sess, connect_retries=retries)

        def _refused_error(request, context):
            raise requests.exceptions.ConnectionError()

        self.stub_url('GET', text=_refused_error)

        with mock.patch('time.sleep') as sleep_mock:
            self.assertRaises(exceptions.ConnectionRefused,
                              adpt.get, self.TEST_URL)
            self.assertEqual(retries, sleep_mock.call_count)

        # Retries are counted, so there is one initial request + 2 retries.
        self.assertThat(self.requests_mock.request_history,
                        matchers.HasLength(retries + 1))

    def test_user_and_project_id(self):
        """The adapter returns the plugin's user and project ids."""
        plugin = AuthPlugin()
        sess = client_session.Session()
        adpt = adapter.Adapter(sess, auth=plugin)

        self.assertEqual(plugin.TEST_USER_ID, adpt.get_user_id())
        self.assertEqual(plugin.TEST_PROJECT_ID, adpt.get_project_id())

    def test_logger_object_passed(self):
        """A logger given to the adapter gets the output instead."""
        # A uniquely named, non-propagating logger writing to a buffer.
        logger = logging.getLogger(uuid.uuid4().hex)
        logger.setLevel(logging.DEBUG)
        logger.propagate = False

        buf = six.StringIO()
        logger.addHandler(logging.StreamHandler(buf))

        plugin = AuthPlugin()
        sess = client_session.Session(auth=plugin)
        adpt = adapter.Adapter(sess, auth=plugin, logger=logger)

        key = uuid.uuid4().hex
        value = uuid.uuid4().hex
        payload = {key: value}

        self.stub_url('GET', json=payload,
                      headers={'Content-Type': 'application/json'})

        response = adpt.get(self.TEST_URL, logger=logger)

        self.assertEqual(payload, response.json())
        captured = buf.getvalue()

        self.assertIn(self.TEST_URL, captured)
        self.assertIn(key, captured)
        self.assertIn(value, captured)

        # Nothing may have reached the default test logger.
        self.assertNotIn(key, self.logger.output)
        self.assertNotIn(value, self.logger.output)
class ConfLoadingTests(utils.TestCase):
    """Check sessions built from oslo.config options."""

    GROUP = 'sessiongroup'

    def setUp(self):
        super(ConfLoadingTests, self).setUp()

        # Register the session options on a throwaway configuration.
        self.conf_fixture = self.useFixture(config.Config())
        client_session.Session.register_conf_options(self.conf_fixture.conf,
                                                     self.GROUP)

    def config(self, **kwargs):
        """Set option values in ``GROUP``; any given group is replaced."""
        overrides = dict(kwargs, group=self.GROUP)
        self.conf_fixture.config(**overrides)

    def get_session(self, **kwargs):
        """Load a session from the fixture configuration."""
        with self.deprecations.expect_deprecations_here():
            return client_session.Session.load_from_conf_options(
                self.conf_fixture.conf,
                self.GROUP,
                **kwargs)

    def test_insecure_timeout(self):
        """insecure turns verification off and timeout is kept."""
        self.config(insecure=True, timeout=5)
        loaded = self.get_session()

        self.assertFalse(loaded.verify)
        self.assertEqual(5, loaded.timeout)

    def test_client_certs(self):
        """certfile and keyfile end up as the ``cert`` tuple."""
        cert = '/path/to/certfile'
        key = '/path/to/keyfile'

        self.config(certfile=cert, keyfile=key)
        loaded = self.get_session()

        self.assertTrue(loaded.verify)
        self.assertEqual((cert, key), loaded.cert)

    def test_cacert(self):
        """cafile becomes the session's ``verify`` value."""
        cafile = '/path/to/cacert'

        self.config(cafile=cafile)
        loaded = self.get_session()

        self.assertEqual(cafile, loaded.verify)

    def test_deprecated(self):
        """Deprecated options given per name are attached to each option."""

        def new_deprecated():
            return cfg.DeprecatedOpt(uuid.uuid4().hex, group=uuid.uuid4().hex)

        opt_names = ['cafile', 'certfile', 'keyfile', 'insecure', 'timeout']
        depr = {name: [new_deprecated()] for name in opt_names}
        opts = client_session.Session.get_conf_options(deprecated_opts=depr)

        self.assertThat(opt_names, matchers.HasLength(len(opts)))
        for opt in opts:
            self.assertIn(depr[opt.name][0], opt.deprecated_opts)
class CliLoadingTests(utils.TestCase):
    """Check sessions built from argparse command line options."""

    def setUp(self):
        super(CliLoadingTests, self).setUp()

        # Register the session options on a fresh parser.
        self.parser = argparse.ArgumentParser()
        client_session.Session.register_cli_options(self.parser)

    def get_session(self, val, **kwargs):
        """Parse the whitespace separated ``val`` and load a session."""
        argv = val.split()
        parsed = self.parser.parse_args(argv)
        with self.deprecations.expect_deprecations_here():
            return client_session.Session.load_from_cli_options(parsed,
                                                                **kwargs)

    def test_insecure_timeout(self):
        """--insecure turns verification off; --timeout is kept."""
        loaded = self.get_session('--insecure --timeout 5.5')

        self.assertFalse(loaded.verify)
        self.assertEqual(5.5, loaded.timeout)

    def test_client_certs(self):
        """--os-cert and --os-key end up as the ``cert`` tuple."""
        cert = '/path/to/certfile'
        key = '/path/to/keyfile'

        loaded = self.get_session('--os-cert %s --os-key %s' % (cert, key))

        self.assertTrue(loaded.verify)
        self.assertEqual((cert, key), loaded.cert)

    def test_cacert(self):
        """--os-cacert becomes the session's ``verify`` value."""
        cacert = '/path/to/cacert'

        loaded = self.get_session('--os-cacert %s' % cacert)

        self.assertEqual(cacert, loaded.verify)