Files
python-keystoneclient/keystoneclient/tests/unit/test_session.py
Tang Chen 09f47ab0ba Use assertEqual() instead of assertDictEqual()
OpenStack now has dropped support for python 2.6 and lower.
So we don't need to define an assertDictEqual() for python
lower than 2.7.

Further, assertEqual() in testtools is able to handle dict
comparison, just the same as assertDictEqual() in unittest2.
So we don't need to call assertDictEqual() for dicts any more.

Please also refer to: https://review.openstack.org/#/c/347097/

Change-Id: Ieaf211617c38aa0f9a38625b1009c36bd6a16fba
2016-07-26 17:15:22 +08:00

1079 lines
39 KiB
Python

# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import argparse
import itertools
import logging
import uuid
import mock
from oslo_config import cfg
from oslo_config import fixture as config
from oslo_serialization import jsonutils
import requests
import six
from testtools import matchers
from keystoneclient import adapter
from keystoneclient.auth import base
from keystoneclient import exceptions
from keystoneclient.i18n import _
from keystoneclient import session as client_session
from keystoneclient.tests.unit import utils
class SessionTests(utils.TestCase):
    """Tests for the basic HTTP behaviour of ``client_session.Session``.

    Covers the per-verb helpers, user-agent handling, options forwarded to
    the underlying requests session, error mapping, debug logging (including
    redaction of sensitive values) and connection retries.
    """

    TEST_URL = 'http://127.0.0.1:5000/'

    def setUp(self):
        super(SessionTests, self).setUp()
        # Every test in this class builds a Session, so tell the
        # deprecations fixture up front that deprecations are expected.
        self.deprecations.expect_deprecations()

    def test_get(self):
        """GET is sent with the right verb and returns the stubbed body."""
        sess = client_session.Session()
        self.stub_url('GET', text='response')

        resp = sess.get(self.TEST_URL)

        self.assertEqual('GET', self.requests_mock.last_request.method)
        self.assertEqual(resp.text, 'response')
        self.assertTrue(resp.ok)

    def test_post(self):
        """POST is sent with the right verb and a JSON request body."""
        payload = {'hello': 'world'}
        sess = client_session.Session()
        self.stub_url('POST', text='response')

        resp = sess.post(self.TEST_URL, json=payload)

        self.assertEqual('POST', self.requests_mock.last_request.method)
        self.assertEqual(resp.text, 'response')
        self.assertTrue(resp.ok)
        self.assertRequestBodyIs(json={'hello': 'world'})

    def test_head(self):
        """HEAD is sent with the right verb and an empty request body."""
        sess = client_session.Session()
        self.stub_url('HEAD')

        resp = sess.head(self.TEST_URL)

        self.assertEqual('HEAD', self.requests_mock.last_request.method)
        self.assertTrue(resp.ok)
        self.assertRequestBodyIs('')

    def test_put(self):
        """PUT is sent with the right verb and a JSON request body."""
        payload = {'hello': 'world'}
        sess = client_session.Session()
        self.stub_url('PUT', text='response')

        resp = sess.put(self.TEST_URL, json=payload)

        self.assertEqual('PUT', self.requests_mock.last_request.method)
        self.assertEqual(resp.text, 'response')
        self.assertTrue(resp.ok)
        self.assertRequestBodyIs(json={'hello': 'world'})

    def test_delete(self):
        """DELETE is sent with the right verb and returns the stubbed body."""
        sess = client_session.Session()
        self.stub_url('DELETE', text='response')

        resp = sess.delete(self.TEST_URL)

        self.assertEqual('DELETE', self.requests_mock.last_request.method)
        self.assertTrue(resp.ok)
        self.assertEqual(resp.text, 'response')

    def test_patch(self):
        """PATCH is sent with the right verb and a JSON request body."""
        payload = {'hello': 'world'}
        sess = client_session.Session()
        self.stub_url('PATCH', text='response')

        resp = sess.patch(self.TEST_URL, json=payload)

        self.assertEqual('PATCH', self.requests_mock.last_request.method)
        self.assertTrue(resp.ok)
        self.assertEqual(resp.text, 'response')
        self.assertRequestBodyIs(json={'hello': 'world'})

    def test_user_agent(self):
        """Check which User-Agent wins: session, header or request kwarg."""
        sess = client_session.Session(user_agent='test-agent')
        self.stub_url('GET', text='response')

        # (extra request kwargs, User-Agent expected on the wire)
        cases = [
            ({}, 'test-agent'),
            ({'headers': {'User-Agent': 'new-agent'}}, 'new-agent'),
            ({'headers': {'User-Agent': 'new-agent'},
              'user_agent': 'overrides-agent'}, 'overrides-agent'),
        ]
        for extra_kwargs, expected_agent in cases:
            resp = sess.get(self.TEST_URL, **extra_kwargs)
            self.assertTrue(resp.ok)
            self.assertRequestHeaderEqual('User-Agent', expected_agent)

    def test_http_session_opts(self):
        """cert, verify and timeout are forwarded to the requests session."""
        sess = client_session.Session(cert='cert.pem', timeout=5,
                                      verify='certs')

        fake_resp = utils.test_response(text='resp')
        request_mock = mock.Mock(return_value=fake_resp)

        with mock.patch.object(sess.session, 'request',
                               request_mock) as mocked:
            sess.post(self.TEST_URL, data='value')

            call_args, call_kwargs = mocked.call_args

            self.assertEqual(call_args[0], 'POST')
            self.assertEqual(call_args[1], self.TEST_URL)
            self.assertEqual(call_kwargs['data'], 'value')
            self.assertEqual(call_kwargs['cert'], 'cert.pem')
            self.assertEqual(call_kwargs['verify'], 'certs')
            self.assertEqual(call_kwargs['timeout'], 5)

    def test_not_found(self):
        """A 404 response raises ``exceptions.NotFound``."""
        sess = client_session.Session()
        self.stub_url('GET', status_code=404)
        self.assertRaises(exceptions.NotFound, sess.get, self.TEST_URL)

    def test_server_error(self):
        """A 500 response raises ``exceptions.InternalServerError``."""
        sess = client_session.Session()
        self.stub_url('GET', status_code=500)
        self.assertRaises(exceptions.InternalServerError,
                          sess.get, self.TEST_URL)

    def test_session_debug_output(self):
        """Test request and response headers in debug logs.

        Ordinary headers and bodies must show up in the log output, while
        the values of security-sensitive headers must be replaced by a
        ``{SHA1}`` marker in the log but left intact on the response.
        """
        sess = client_session.Session(verify=False)
        headers = {'HEADERA': 'HEADERVALB'}
        security_headers = {'Authorization': uuid.uuid4().hex,
                            'X-Auth-Token': uuid.uuid4().hex,
                            'X-Subject-Token': uuid.uuid4().hex, }
        body = 'BODYRESPONSE'
        data = 'BODYDATA'

        all_headers = dict(
            itertools.chain(headers.items(), security_headers.items()))
        self.stub_url('POST', text=body, headers=all_headers)
        resp = sess.post(self.TEST_URL, headers=all_headers, data=data)
        self.assertEqual(resp.status_code, 200)

        logged = self.logger.output
        for fragment in ('curl', 'POST', '--insecure', body, "'%s'" % data):
            self.assertIn(fragment, logged)

        for name, value in six.iteritems(headers):
            self.assertIn(name, self.logger.output)
            self.assertIn(value, self.logger.output)

        # Assert that response headers contains actual values and
        # only debug logs has been masked
        for name, value in six.iteritems(security_headers):
            self.assertIn('%s: {SHA1}' % name, self.logger.output)
            self.assertEqual(value, resp.headers[name])
            self.assertNotIn(value, self.logger.output)

    def test_logs_failed_output(self):
        """Test that output is logged even for failed requests."""
        sess = client_session.Session()
        body = uuid.uuid4().hex

        self.stub_url('GET', text=body, status_code=400)
        resp = sess.get(self.TEST_URL, raise_exc=False)

        self.assertEqual(resp.status_code, 400)
        self.assertIn(body, self.logger.output)

    def test_unicode_data_in_debug_output(self):
        """Verify that ascii-encodable data is logged without modification."""
        sess = client_session.Session(verify=False)

        body = 'RESP'
        data = u'unicode_data'
        self.stub_url('POST', text=body)
        sess.post(self.TEST_URL, data=data)

        self.assertIn("'%s'" % data, self.logger.output)

    def test_binary_data_not_in_debug_output(self):
        """Verify that non-ascii-encodable data causes replacement."""
        if not six.PY2:
            # Report an explicit skip rather than silently passing, so test
            # results do not claim coverage that was never exercised.
            self.skipTest('Python 3 logging handles binary data well.')

        data = "my data" + chr(255)

        sess = client_session.Session(verify=False)

        body = 'RESP'
        self.stub_url('POST', text=body)

        # Forced mixed unicode and byte strings in request
        # elements to make sure that all joins are appropriately
        # handled (any join of unicode and byte strings should
        # raise a UnicodeDecodeError)
        # six.text_type is ``unicode`` on Python 2; using it avoids
        # referencing a builtin that does not exist on Python 3.
        sess.post(six.text_type(self.TEST_URL), data=data)

        self.assertIn("Replaced characters that could not be decoded"
                      " in log output", self.logger.output)

        # Our data payload should have changed to
        # include the replacement char
        self.assertIn(u"-d 'my data\ufffd'", self.logger.output)

    def test_logging_cacerts(self):
        """A CA bundle path passed as ``verify`` appears in the logged curl."""
        path_to_certs = '/path/to/certs'
        sess = client_session.Session(verify=path_to_certs)

        self.stub_url('GET', text='text')
        sess.get(self.TEST_URL)

        self.assertIn('--cacert', self.logger.output)
        self.assertIn(path_to_certs, self.logger.output)

    def test_connect_retries(self):
        """Timeouts are retried ``connect_retries`` times with back-off."""

        def _timeout_error(request, context):
            raise requests.exceptions.Timeout()

        self.stub_url('GET', text=_timeout_error)

        sess = client_session.Session()
        retries = 3

        # Patch time.sleep so the back-off delays are recorded, not slept.
        with mock.patch('time.sleep') as sleep_mock:
            self.assertRaises(exceptions.RequestTimeout,
                              sess.get,
                              self.TEST_URL, connect_retries=retries)

            self.assertEqual(retries, sleep_mock.call_count)
            # 3 retries finishing with 2.0 means 0.5, 1.0 and 2.0
            sleep_mock.assert_called_with(2.0)

        # we count retries so there will be one initial request + 3 retries
        self.assertThat(self.requests_mock.request_history,
                        matchers.HasLength(retries + 1))

    def test_uses_tcp_keepalive_by_default(self):
        """Both URL schemes are mounted with ``TCPKeepAliveAdapter``."""
        sess = client_session.Session()
        requests_session = sess.session
        for scheme in ('http://', 'https://'):
            self.assertIsInstance(requests_session.adapters[scheme],
                                  client_session.TCPKeepAliveAdapter)

    def test_does_not_set_tcp_keepalive_on_custom_sessions(self):
        """A caller-supplied requests session has nothing mounted on it."""
        mock_session = mock.Mock()
        client_session.Session(session=mock_session)
        self.assertFalse(mock_session.mount.called)

    def test_ssl_error_message(self):
        """SSL failures raise ``exceptions.SSLError`` naming URL and cause."""
        error = uuid.uuid4().hex

        def _ssl_error(request, context):
            raise requests.exceptions.SSLError(error)

        self.stub_url('GET', text=_ssl_error)
        sess = client_session.Session()

        # The exception should contain the URL and details about the SSL error
        msg = _('SSL exception connecting to %(url)s: %(error)s') % {
            'url': self.TEST_URL, 'error': error}
        six.assertRaisesRegex(self,
                              exceptions.SSLError,
                              msg,
                              sess.get,
                              self.TEST_URL)

    def test_mask_password_in_http_log_response(self):
        """Passwords in a response body never reach the debug logger."""
        sess = client_session.Session()

        def fake_debug(msg):
            self.assertNotIn('verybadpass', msg)

        logger = mock.Mock(isEnabledFor=mock.Mock(return_value=True))
        logger.debug = mock.Mock(side_effect=fake_debug)
        body = {
            "connection_info": {
                "driver_volume_type": "iscsi",
                "data": {
                    "auth_password": "verybadpass",
                    "target_discovered": False,
                    "encrypted": False,
                    "qos_specs": None,
                    "target_iqn": ("iqn.2010-10.org.openstack:volume-"
                                   "744d2085-8e78-40a5-8659-ef3cffb2480e"),
                    "target_portal": "172.99.69.228:3260",
                    "volume_id": "744d2085-8e78-40a5-8659-ef3cffb2480e",
                    "target_lun": 1,
                    "access_mode": "rw",
                    "auth_username": "verybadusername",
                    "auth_method": "CHAP"}}}
        body_json = jsonutils.dumps(body)
        response = mock.Mock(text=body_json, status_code=200, headers={})
        sess._http_log_response(response, logger)
        self.assertEqual(1, logger.debug.call_count)
class TCPKeepAliveAdapter(utils.TestCase):
    """Tests for the socket options chosen by the keep-alive adapter.

    Both tests replace the ``socket`` module referenced by
    ``client_session`` and the parent ``HTTPAdapter.init_poolmanager`` with
    mocks, then inspect the ``socket_options`` keyword handed to the parent.
    """

    @mock.patch.object(client_session, 'socket')
    @mock.patch('requests.adapters.HTTPAdapter.init_poolmanager')
    def test_init_poolmanager_all_options(self, mock_parent_init_poolmanager,
                                          mock_socket):
        """Every keep-alive constant the socket module offers is passed on."""
        # properties expected to be in socket.
        mock_socket.TCP_KEEPIDLE = mock.sentinel.TCP_KEEPIDLE
        mock_socket.TCP_KEEPCNT = mock.sentinel.TCP_KEEPCNT
        mock_socket.TCP_KEEPINTVL = mock.sentinel.TCP_KEEPINTVL
        expected_opts = [mock_socket.TCP_KEEPIDLE,
                         mock_socket.TCP_KEEPCNT,
                         mock_socket.TCP_KEEPINTVL]

        keepalive_adapter = client_session.TCPKeepAliveAdapter()
        keepalive_adapter.init_poolmanager()

        _args, parent_kwargs = mock_parent_init_poolmanager.call_args
        passed_opts = [
            opt for (_protocol, opt, _value)
            in parent_kwargs['socket_options']
        ]
        for expected in expected_opts:
            self.assertIn(expected, passed_opts)

    @mock.patch.object(client_session, 'socket')
    @mock.patch('requests.adapters.HTTPAdapter.init_poolmanager')
    def test_init_poolmanager(self, mock_parent_init_poolmanager, mock_socket):
        """Only the attributes present on the socket module are passed on."""
        # Restrict the mocked module to these names so that any other
        # attribute lookup on it raises AttributeError.
        available = ['IPPROTO_TCP', 'TCP_NODELAY', 'SOL_SOCKET',
                     'SO_KEEPALIVE']
        mock_socket.mock_add_spec(available)

        keepalive_adapter = client_session.TCPKeepAliveAdapter()
        keepalive_adapter.init_poolmanager()

        _args, parent_kwargs = mock_parent_init_poolmanager.call_args
        passed_opts = [
            opt for (_protocol, opt, _value)
            in parent_kwargs['socket_options']
        ]
        self.assertEqual([mock_socket.TCP_NODELAY, mock_socket.SO_KEEPALIVE],
                         passed_opts)
class RedirectTests(utils.TestCase):
    """Tests for how ``client_session.Session`` follows HTTP redirects."""

    # Each URL redirects to the one after it; the last answers normally.
    REDIRECT_CHAIN = ['http://myhost:3445/',
                      'http://anotherhost:6555/',
                      'http://thirdhost/',
                      'http://finaldestination:55/']

    DEFAULT_REDIRECT_BODY = 'Redirect'
    DEFAULT_RESP_BODY = 'Found'

    def setUp(self):
        super(RedirectTests, self).setUp()
        self.deprecations.expect_deprecations()

    def setup_redirects(self, method='GET', status_code=305,
                        redirect_kwargs=None, final_kwargs=None):
        """Register the whole redirect chain with the requests mock.

        :param method: HTTP verb the stubs answer to.
        :param status_code: status returned by every redirecting hop.
        :param redirect_kwargs: extra ``register_uri`` keyword arguments for
            the redirecting hops; ``text`` defaults to
            ``DEFAULT_REDIRECT_BODY``.
        :param final_kwargs: extra ``register_uri`` keyword arguments for
            the final URL; ``status_code`` defaults to 200 and ``text`` to
            ``DEFAULT_RESP_BODY``.
        """
        # Work on copies so that filling in defaults never mutates a
        # dictionary owned by the caller.
        redirect_kwargs = dict(redirect_kwargs or {})
        final_kwargs = dict(final_kwargs or {})

        redirect_kwargs.setdefault('text', self.DEFAULT_REDIRECT_BODY)

        # Pair each URL with its successor to build the Location headers.
        for source, destination in zip(self.REDIRECT_CHAIN,
                                       self.REDIRECT_CHAIN[1:]):
            self.requests_mock.register_uri(method, source,
                                            status_code=status_code,
                                            headers={'Location': destination},
                                            **redirect_kwargs)

        final_kwargs.setdefault('status_code', 200)
        final_kwargs.setdefault('text', self.DEFAULT_RESP_BODY)
        self.requests_mock.register_uri(method, self.REDIRECT_CHAIN[-1],
                                        **final_kwargs)

    def assertResponse(self, resp):
        """Assert that ``resp`` is the final, non-redirect response."""
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(resp.text, self.DEFAULT_RESP_BODY)

    def test_basic_get(self):
        """A single GET redirect hop is followed to the final URL."""
        session = client_session.Session()
        self.setup_redirects()
        resp = session.get(self.REDIRECT_CHAIN[-2])
        self.assertResponse(resp)

    def test_basic_post_keeps_correct_method(self):
        """A redirected POST still reaches the POST-only final stub."""
        session = client_session.Session()
        self.setup_redirects(method='POST', status_code=301)
        resp = session.post(self.REDIRECT_CHAIN[-2])
        self.assertResponse(resp)

    def test_redirect_forever(self):
        """With ``redirect=True`` the entire chain is followed."""
        session = client_session.Session(redirect=True)
        self.setup_redirects()
        resp = session.get(self.REDIRECT_CHAIN[0])
        self.assertResponse(resp)
        # The original ``assertTrue(len(history), len(chain))`` treated the
        # second argument as the failure message and so checked nothing.
        # Every URL except the last yields one redirect response, so the
        # history is expected to be one entry shorter than the chain.
        self.assertThat(resp.history,
                        matchers.HasLength(len(self.REDIRECT_CHAIN) - 1))

    def test_no_redirect(self):
        """With ``redirect=False`` the first redirect response is returned."""
        session = client_session.Session(redirect=False)
        self.setup_redirects()
        resp = session.get(self.REDIRECT_CHAIN[0])
        self.assertEqual(resp.status_code, 305)
        self.assertEqual(resp.url, self.REDIRECT_CHAIN[0])

    def test_redirect_limit(self):
        """An integer ``redirect`` caps the number of hops followed."""
        self.setup_redirects()
        for limit in (1, 2):
            session = client_session.Session(redirect=limit)
            resp = session.get(self.REDIRECT_CHAIN[0])
            self.assertEqual(resp.status_code, 305)
            self.assertEqual(resp.url, self.REDIRECT_CHAIN[limit])
            self.assertEqual(resp.text, self.DEFAULT_REDIRECT_BODY)

    def test_history_matches_requests(self):
        """Session history mirrors what plain ``requests`` records."""
        self.setup_redirects(status_code=301)
        session = client_session.Session(redirect=True)
        req_resp = requests.get(self.REDIRECT_CHAIN[0],
                                allow_redirects=True)

        ses_resp = session.get(self.REDIRECT_CHAIN[0])

        self.assertEqual(len(req_resp.history), len(ses_resp.history))

        for expected, actual in zip(req_resp.history, ses_resp.history):
            self.assertEqual(expected.url, actual.url)
            self.assertEqual(expected.status_code, actual.status_code)
class ConstructSessionFromArgsTests(utils.TestCase):
    """Tests for building a session through ``Session.construct``."""

    KEY = 'keyfile'
    CERT = 'certfile'
    CACERT = 'cacert-path'

    def _s(self, k=None, **kwargs):
        """Construct a session from dict ``k`` or, if falsy, from kwargs.

        The dictionary is handed to ``Session.construct`` as-is, inside a
        block in which deprecations are expected.
        """
        construct_args = k or kwargs
        with self.deprecations.expect_deprecations_here():
            return client_session.Session.construct(construct_args)

    def test_verify(self):
        """``verify`` is derived from the verify/insecure/cacert options."""
        insecure_only = self._s(insecure=True)
        self.assertFalse(insecure_only.verify)

        verify_wins = self._s(verify=True, insecure=True)
        self.assertTrue(verify_wins.verify)

        both_off = self._s(verify=False, insecure=True)
        self.assertFalse(both_off.verify)

        with_cacert = self._s(cacert=self.CACERT)
        self.assertEqual(with_cacert.verify, self.CACERT)

    def test_cert(self):
        """``cert`` accepts a tuple or cert+key; a key alone gives None."""
        cert_and_key = (self.CERT, self.KEY)
        self.assertEqual(self._s(cert=cert_and_key).cert, cert_and_key)
        self.assertEqual(self._s(cert=self.CERT, key=self.KEY).cert,
                         cert_and_key)
        self.assertIsNone(self._s(key=self.KEY).cert)

    def test_pass_through(self):
        """Pass-through options land on the session and leave the dict."""
        value = 42  # only a number because timeout needs to be numeric
        for option in ['timeout', 'session', 'original_ip', 'user_agent']:
            args = {option: value}
            built = self._s(args)
            self.assertEqual(getattr(built, option), value)
            # construct() is expected to have consumed the option.
            self.assertNotIn(option, args)
class AuthPlugin(base.BaseAuthPlugin):
    """Minimal auth plugin that answers from fixed, in-memory data.

    The token and the value returned by ``invalidate`` are chosen at
    construction time; endpoints are looked up in ``SERVICE_URLS``.
    """

    TEST_TOKEN = utils.TestCase.TEST_TOKEN
    TEST_USER_ID = 'aUser'
    TEST_PROJECT_ID = 'aProject'

    # service type -> interface -> endpoint URL
    SERVICE_URLS = {
        'identity': {'public': 'http://identity-public:1111/v2.0',
                     'admin': 'http://identity-admin:1111/v2.0'},
        'compute': {'public': 'http://compute-public:2222/v1.0',
                    'admin': 'http://compute-admin:2222/v1.0'},
        'image': {'public': 'http://image-public:3333/v2.0',
                  'admin': 'http://image-admin:3333/v2.0'}
    }

    def __init__(self, token=TEST_TOKEN, invalidate=True):
        self.token = token
        self._invalidate = invalidate

    def get_token(self, session):
        """Return the token supplied at construction time."""
        return self.token

    def get_endpoint(self, session, service_type=None, interface=None,
                     **kwargs):
        """Return the canned URL for the service/interface, or None."""
        try:
            endpoint = self.SERVICE_URLS[service_type][interface]
        except (KeyError, AttributeError):
            endpoint = None
        return endpoint

    def invalidate(self):
        """Return the ``invalidate`` flag supplied at construction time."""
        return self._invalidate

    def get_user_id(self, session):
        """Return the fixed test user id."""
        return self.TEST_USER_ID

    def get_project_id(self, session):
        """Return the fixed test project id."""
        return self.TEST_PROJECT_ID
class CalledAuthPlugin(base.BaseAuthPlugin):
    """Auth plugin that records which of its methods have been invoked."""

    ENDPOINT = 'http://fakeendpoint/'

    def __init__(self, invalidate=True):
        # Value handed back by invalidate().
        self._invalidate = invalidate
        # Call-tracking state inspected by the tests.
        self.get_token_called = False
        self.get_endpoint_called = False
        self.invalidate_called = False
        # Keyword arguments of the most recent get_endpoint() call.
        self.endpoint_arguments = {}

    def get_token(self, session):
        """Record the call and return the shared test token."""
        self.get_token_called = True
        return utils.TestCase.TEST_TOKEN

    def get_endpoint(self, session, **kwargs):
        """Record the call and its kwargs, then return ``ENDPOINT``."""
        self.get_endpoint_called = True
        self.endpoint_arguments = kwargs
        return self.ENDPOINT

    def invalidate(self):
        """Record the call and return the configured invalidate flag."""
        self.invalidate_called = True
        return self._invalidate
class SessionAuthTests(utils.TestCase):
    """Tests for how a session cooperates with authentication plugins."""

    TEST_URL = 'http://127.0.0.1:5000/'
    TEST_JSON = {'hello': 'world'}

    def setUp(self):
        super(SessionAuthTests, self).setUp()
        self.deprecations.expect_deprecations()

    def stub_service_url(self, service_type, interface, path,
                         method='GET', **kwargs):
        """Stub ``path`` beneath the matching ``AuthPlugin`` endpoint.

        The base URL and path are joined with exactly one slash; any other
        keyword arguments go straight to ``requests_mock.register_uri``.
        """
        endpoint = AuthPlugin.SERVICE_URLS[service_type][interface]
        full_uri = "%s/%s" % (endpoint.rstrip('/'), path.lstrip('/'))

        self.requests_mock.register_uri(method, full_uri, **kwargs)

    def test_auth_plugin_default_with_plugin(self):
        """Requests are authenticated by default when a plugin is set."""
        self.stub_url('GET', base_url=self.TEST_URL, json=self.TEST_JSON)

        # if there is an auth_plugin then it should default to authenticated
        plugin = AuthPlugin()
        sess = client_session.Session(auth=plugin)
        resp = sess.get(self.TEST_URL)
        self.assertEqual(resp.json(), self.TEST_JSON)

        self.assertRequestHeaderEqual('X-Auth-Token', AuthPlugin.TEST_TOKEN)

    def test_auth_plugin_disable(self):
        """``authenticated=False`` suppresses the token header."""
        self.stub_url('GET', base_url=self.TEST_URL, json=self.TEST_JSON)

        plugin = AuthPlugin()
        sess = client_session.Session(auth=plugin)
        resp = sess.get(self.TEST_URL, authenticated=False)
        self.assertEqual(resp.json(), self.TEST_JSON)

        self.assertRequestHeaderEqual('X-Auth-Token', None)

    def test_service_type_urls(self):
        """A relative path is resolved against the plugin's endpoint."""
        service_type = 'compute'
        interface = 'public'
        path = '/instances'
        status = 200
        body = 'SUCCESS'

        self.stub_service_url(service_type=service_type,
                              interface=interface,
                              path=path,
                              status_code=status,
                              text=body)

        sess = client_session.Session(auth=AuthPlugin())
        endpoint_filter = {'service_type': service_type,
                           'interface': interface}
        resp = sess.get(path, endpoint_filter=endpoint_filter)

        expected_url = AuthPlugin.SERVICE_URLS['compute']['public'] + path
        self.assertEqual(self.requests_mock.last_request.url, expected_url)
        self.assertEqual(resp.text, body)
        self.assertEqual(resp.status_code, status)

    def test_service_url_raises_if_no_auth_plugin(self):
        """An endpoint filter without any plugin is an error."""
        sess = client_session.Session()
        self.assertRaises(exceptions.MissingAuthPlugin,
                          sess.get, '/path',
                          endpoint_filter={'service_type': 'compute',
                                           'interface': 'public'})

    def test_service_url_raises_if_no_url_returned(self):
        """A plugin that yields no endpoint leads to EndpointNotFound."""
        sess = client_session.Session(auth=AuthPlugin())
        self.assertRaises(exceptions.EndpointNotFound,
                          sess.get, '/path',
                          endpoint_filter={'service_type': 'unknown',
                                           'interface': 'public'})

    def test_raises_exc_only_when_asked(self):
        """HTTP errors raise unless ``raise_exc=False`` is passed."""
        # A request that returns a HTTP error should by default raise an
        # exception by default, if you specify raise_exc=False then it will not
        self.requests_mock.get(self.TEST_URL, status_code=401)

        sess = client_session.Session()
        self.assertRaises(exceptions.Unauthorized, sess.get, self.TEST_URL)

        resp = sess.get(self.TEST_URL, raise_exc=False)
        self.assertEqual(401, resp.status_code)

    def test_passed_auth_plugin(self):
        """A plugin given per request is used for token and endpoint."""
        passed = CalledAuthPlugin()
        sess = client_session.Session()

        self.requests_mock.get(CalledAuthPlugin.ENDPOINT + 'path',
                               status_code=200)
        endpoint_filter = {'service_type': 'identity'}

        # no plugin with authenticated won't work
        self.assertRaises(exceptions.MissingAuthPlugin, sess.get, 'path',
                          authenticated=True)

        # no plugin with an endpoint filter won't work
        self.assertRaises(exceptions.MissingAuthPlugin, sess.get, 'path',
                          authenticated=False, endpoint_filter=endpoint_filter)

        resp = sess.get('path', auth=passed, endpoint_filter=endpoint_filter)

        self.assertEqual(200, resp.status_code)
        self.assertTrue(passed.get_endpoint_called)
        self.assertTrue(passed.get_token_called)

    def test_passed_auth_plugin_overrides(self):
        """A per-request plugin is used instead of the session's own."""
        fixed = CalledAuthPlugin()
        passed = CalledAuthPlugin()

        sess = client_session.Session(fixed)

        self.requests_mock.get(CalledAuthPlugin.ENDPOINT + 'path',
                               status_code=200)

        resp = sess.get('path', auth=passed,
                        endpoint_filter={'service_type': 'identity'})

        self.assertEqual(200, resp.status_code)
        self.assertTrue(passed.get_endpoint_called)
        self.assertTrue(passed.get_token_called)
        self.assertFalse(fixed.get_endpoint_called)
        self.assertFalse(fixed.get_token_called)

    def test_requests_auth_plugin(self):
        """``requests_auth`` is handed to requests as its ``auth`` kwarg."""
        sess = client_session.Session()

        requests_auth = object()

        fake_resp = utils.test_response(text='resp')
        request_mock = mock.Mock(return_value=fake_resp)

        with mock.patch.object(sess.session, 'request',
                               request_mock) as mocked:
            sess.get(self.TEST_URL, requests_auth=requests_auth)

            mocked.assert_called_once_with('GET', self.TEST_URL,
                                           headers=mock.ANY,
                                           allow_redirects=mock.ANY,
                                           auth=requests_auth,
                                           verify=mock.ANY)

    def test_reauth_called(self):
        """After a 401 the plugin is invalidated and the request retried."""
        plugin = CalledAuthPlugin(invalidate=True)
        sess = client_session.Session(auth=plugin)

        # First answer is a 401, the second one succeeds.
        self.requests_mock.get(self.TEST_URL,
                               [{'text': 'Failed', 'status_code': 401},
                                {'text': 'Hello', 'status_code': 200}])

        # allow_reauth=True is the default
        resp = sess.get(self.TEST_URL, authenticated=True)

        self.assertEqual(200, resp.status_code)
        self.assertEqual('Hello', resp.text)
        self.assertTrue(plugin.invalidate_called)

    def test_reauth_not_called(self):
        """With ``allow_reauth=False`` a 401 is raised without invalidating."""
        plugin = CalledAuthPlugin(invalidate=True)
        sess = client_session.Session(auth=plugin)

        self.requests_mock.get(self.TEST_URL,
                               [{'text': 'Failed', 'status_code': 401},
                                {'text': 'Hello', 'status_code': 200}])

        self.assertRaises(exceptions.Unauthorized, sess.get, self.TEST_URL,
                          authenticated=True, allow_reauth=False)
        self.assertFalse(plugin.invalidate_called)

    def test_endpoint_override_overrides_filter(self):
        """``endpoint_override`` wins over the plugin's endpoint lookup."""
        plugin = CalledAuthPlugin()
        sess = client_session.Session(auth=plugin)

        override_base = 'http://mytest/'
        path = 'path'
        override_url = override_base + path
        resp_text = uuid.uuid4().hex

        self.requests_mock.get(override_url, text=resp_text)

        resp = sess.get(path,
                        endpoint_override=override_base,
                        endpoint_filter={'service_type': 'identity'})

        self.assertEqual(resp_text, resp.text)
        self.assertEqual(override_url, self.requests_mock.last_request.url)

        self.assertTrue(plugin.get_token_called)
        self.assertFalse(plugin.get_endpoint_called)

    def test_endpoint_override_ignore_full_url(self):
        """A fully qualified URL is used as-is despite an override."""
        plugin = CalledAuthPlugin()
        sess = client_session.Session(auth=plugin)

        path = 'path'
        url = self.TEST_URL + path

        resp_text = uuid.uuid4().hex
        self.requests_mock.get(url, text=resp_text)

        resp = sess.get(url,
                        endpoint_override='http://someother.url',
                        endpoint_filter={'service_type': 'identity'})

        self.assertEqual(resp_text, resp.text)
        self.assertEqual(url, self.requests_mock.last_request.url)

        self.assertTrue(plugin.get_token_called)
        self.assertFalse(plugin.get_endpoint_called)

    def test_user_and_project_id(self):
        """User and project ids are obtained from the auth plugin."""
        plugin = AuthPlugin()
        sess = client_session.Session(auth=plugin)

        self.assertEqual(plugin.TEST_USER_ID, sess.get_user_id())
        self.assertEqual(plugin.TEST_PROJECT_ID, sess.get_project_id())

    def test_logger_object_passed(self):
        """A per-request logger gets the output; the default one does not."""
        # A uniquely named, non-propagating logger writing to a buffer
        # keeps its output separate from the test case's own logger.
        logger = logging.getLogger(uuid.uuid4().hex)
        logger.setLevel(logging.DEBUG)
        logger.propagate = False

        stream = six.StringIO()
        logger.addHandler(logging.StreamHandler(stream))

        plugin = AuthPlugin()
        sess = client_session.Session(auth=plugin)
        response = uuid.uuid4().hex

        self.stub_url('GET',
                      text=response,
                      headers={'Content-Type': 'text/html'})

        resp = sess.get(self.TEST_URL, logger=logger)

        self.assertEqual(response, resp.text)
        captured = stream.getvalue()

        self.assertIn(self.TEST_URL, captured)
        self.assertIn(response, captured)

        self.assertNotIn(self.TEST_URL, self.logger.output)
        self.assertNotIn(response, self.logger.output)
class AdapterTest(utils.TestCase):
    """Tests for ``adapter.Adapter`` and ``adapter.LegacyJsonAdapter``."""

    # Random values, generated once when the class body is executed.
    SERVICE_TYPE = uuid.uuid4().hex
    SERVICE_NAME = uuid.uuid4().hex
    INTERFACE = uuid.uuid4().hex
    REGION_NAME = uuid.uuid4().hex
    USER_AGENT = uuid.uuid4().hex
    VERSION = uuid.uuid4().hex

    TEST_URL = CalledAuthPlugin.ENDPOINT

    def setUp(self):
        super(AdapterTest, self).setUp()
        self.deprecations.expect_deprecations()

    def _create_loaded_adapter(self):
        """Return an Adapter with every filter attribute populated."""
        plugin = CalledAuthPlugin()
        sess = client_session.Session()
        return adapter.Adapter(sess,
                               auth=plugin,
                               service_type=self.SERVICE_TYPE,
                               service_name=self.SERVICE_NAME,
                               interface=self.INTERFACE,
                               region_name=self.REGION_NAME,
                               user_agent=self.USER_AGENT,
                               version=self.VERSION)

    def _verify_endpoint_called(self, adpt):
        """Assert the plugin's get_endpoint saw all adapter attributes."""
        expectations = (
            ('service_type', self.SERVICE_TYPE),
            ('service_name', self.SERVICE_NAME),
            ('interface', self.INTERFACE),
            ('region_name', self.REGION_NAME),
            ('version', self.VERSION),
        )
        for argument, expected in expectations:
            self.assertEqual(expected,
                             adpt.auth.endpoint_arguments[argument])

    def test_setting_variables_on_request(self):
        """Adapter attributes are applied to a request made through it."""
        response = uuid.uuid4().hex
        self.stub_url('GET', text=response)
        adpt = self._create_loaded_adapter()
        resp = adpt.get('/')
        self.assertEqual(resp.text, response)

        self._verify_endpoint_called(adpt)
        self.assertTrue(adpt.auth.get_token_called)
        self.assertRequestHeaderEqual('User-Agent', self.USER_AGENT)

    def test_setting_variables_on_get_endpoint(self):
        """Adapter attributes are applied to ``get_endpoint`` lookups."""
        adpt = self._create_loaded_adapter()
        url = adpt.get_endpoint()

        self.assertEqual(self.TEST_URL, url)
        self._verify_endpoint_called(adpt)

    def test_legacy_binding(self):
        """LegacyJsonAdapter returns ``(response, decoded_json_body)``."""
        key = uuid.uuid4().hex
        val = uuid.uuid4().hex
        response = jsonutils.dumps({key: val})

        self.stub_url('GET', text=response)

        plugin = CalledAuthPlugin()
        sess = client_session.Session(auth=plugin)
        adpt = adapter.LegacyJsonAdapter(sess,
                                         service_type=self.SERVICE_TYPE,
                                         user_agent=self.USER_AGENT)

        resp, body = adpt.get('/')
        self.assertEqual(self.SERVICE_TYPE,
                         plugin.endpoint_arguments['service_type'])
        self.assertEqual(resp.text, response)
        self.assertEqual(val, body[key])

    def test_legacy_binding_non_json_resp(self):
        """LegacyJsonAdapter yields a ``None`` body for a non-JSON reply."""
        response = uuid.uuid4().hex
        self.stub_url('GET', text=response,
                      headers={'Content-Type': 'text/html'})

        plugin = CalledAuthPlugin()
        sess = client_session.Session(auth=plugin)
        adpt = adapter.LegacyJsonAdapter(sess,
                                         service_type=self.SERVICE_TYPE,
                                         user_agent=self.USER_AGENT)

        resp, body = adpt.get('/')
        self.assertEqual(self.SERVICE_TYPE,
                         plugin.endpoint_arguments['service_type'])
        self.assertEqual(resp.text, response)
        self.assertIsNone(body)

    def test_methods(self):
        """Each verb helper delegates to ``request`` with the upper verb."""
        sess = client_session.Session()
        adpt = adapter.Adapter(sess)
        url = 'http://url'

        for verb in ['get', 'head', 'post', 'put', 'patch', 'delete']:
            with mock.patch.object(adpt, 'request') as request_mock:
                helper = getattr(adpt, verb)
                helper(url)
                request_mock.assert_called_once_with(url, verb.upper())

    def test_setting_endpoint_override(self):
        """An adapter-level endpoint override is used for requests."""
        endpoint_override = 'http://overrideurl'
        path = '/path'
        endpoint_url = endpoint_override + path

        plugin = CalledAuthPlugin()
        sess = client_session.Session(auth=plugin)
        adpt = adapter.Adapter(sess, endpoint_override=endpoint_override)

        response = uuid.uuid4().hex
        self.requests_mock.get(endpoint_url, text=response)

        resp = adpt.get(path)

        self.assertEqual(response, resp.text)
        self.assertEqual(endpoint_url, self.requests_mock.last_request.url)

        self.assertEqual(endpoint_override, adpt.get_endpoint())

    def test_adapter_invalidate(self):
        """``Adapter.invalidate`` reaches the adapter's auth plugin."""
        plugin = CalledAuthPlugin()
        sess = client_session.Session()
        adpt = adapter.Adapter(sess, auth=plugin)

        adpt.invalidate()

        self.assertTrue(plugin.invalidate_called)

    def test_adapter_get_token(self):
        """``Adapter.get_token`` returns the adapter plugin's token."""
        plugin = CalledAuthPlugin()
        sess = client_session.Session()
        adpt = adapter.Adapter(sess, auth=plugin)

        self.assertEqual(self.TEST_TOKEN, adpt.get_token())
        self.assertTrue(plugin.get_token_called)

    def test_adapter_connect_retries(self):
        """``connect_retries`` set on the adapter applies to its requests."""
        retries = 2
        sess = client_session.Session()
        adpt = adapter.Adapter(sess, connect_retries=retries)

        def _refused_error(request, context):
            raise requests.exceptions.ConnectionError()

        self.stub_url('GET', text=_refused_error)

        # Patch time.sleep so the retry delays are recorded, not slept.
        with mock.patch('time.sleep') as sleep_mock:
            self.assertRaises(exceptions.ConnectionRefused,
                              adpt.get, self.TEST_URL)
            self.assertEqual(retries, sleep_mock.call_count)

        # we count retries so there will be one initial request + 2 retries
        self.assertThat(self.requests_mock.request_history,
                        matchers.HasLength(retries + 1))

    def test_user_and_project_id(self):
        """User and project ids come from the adapter's auth plugin."""
        plugin = AuthPlugin()
        sess = client_session.Session()
        adpt = adapter.Adapter(sess, auth=plugin)

        self.assertEqual(plugin.TEST_USER_ID, adpt.get_user_id())
        self.assertEqual(plugin.TEST_PROJECT_ID, adpt.get_project_id())

    def test_logger_object_passed(self):
        """Output goes to the supplied logger, not the default one."""
        # A uniquely named, non-propagating logger writing to a buffer
        # keeps its output separate from the test case's own logger.
        logger = logging.getLogger(uuid.uuid4().hex)
        logger.setLevel(logging.DEBUG)
        logger.propagate = False

        stream = six.StringIO()
        logger.addHandler(logging.StreamHandler(stream))

        plugin = AuthPlugin()
        sess = client_session.Session(auth=plugin)
        adpt = adapter.Adapter(sess, auth=plugin, logger=logger)

        response = uuid.uuid4().hex

        self.stub_url('GET', text=response,
                      headers={'Content-Type': 'text/html'})

        # NOTE(review): the logger is given both to the Adapter and to this
        # call, so the adapter-level binding alone is presumably not what is
        # being exercised here - confirm whether that is intended.
        resp = adpt.get(self.TEST_URL, logger=logger)

        self.assertEqual(response, resp.text)
        captured = stream.getvalue()

        self.assertIn(self.TEST_URL, captured)
        self.assertIn(response, captured)

        self.assertNotIn(self.TEST_URL, self.logger.output)
        self.assertNotIn(response, self.logger.output)
class ConfLoadingTests(utils.TestCase):
    """Tests for creating a session from oslo.config options."""

    GROUP = 'sessiongroup'

    def setUp(self):
        super(ConfLoadingTests, self).setUp()

        # Register the session options under GROUP on a fixture-managed
        # configuration object.
        self.conf_fixture = self.useFixture(config.Config())
        client_session.Session.register_conf_options(self.conf_fixture.conf,
                                                     self.GROUP)

    def config(self, **kwargs):
        """Set option overrides within ``GROUP`` on the config fixture."""
        kwargs['group'] = self.GROUP
        self.conf_fixture.config(**kwargs)

    def get_session(self, **kwargs):
        """Load a session from ``GROUP``, expecting deprecations."""
        conf = self.conf_fixture.conf
        with self.deprecations.expect_deprecations_here():
            return client_session.Session.load_from_conf_options(
                conf,
                self.GROUP,
                **kwargs)

    def test_insecure_timeout(self):
        """``insecure`` disables verification; ``timeout`` is carried over."""
        self.config(insecure=True, timeout=5)
        sess = self.get_session()

        self.assertFalse(sess.verify)
        self.assertEqual(5, sess.timeout)

    def test_client_certs(self):
        """``certfile`` and ``keyfile`` become the session's cert tuple."""
        cert = '/path/to/certfile'
        key = '/path/to/keyfile'

        self.config(certfile=cert, keyfile=key)
        sess = self.get_session()

        self.assertTrue(sess.verify)
        self.assertEqual((cert, key), sess.cert)

    def test_cacert(self):
        """``cafile`` becomes the session's ``verify`` value."""
        cafile = '/path/to/cacert'

        self.config(cafile=cafile)
        sess = self.get_session()

        self.assertEqual(cafile, sess.verify)

    def test_deprecated(self):
        """Supplied deprecated options are attached to the matching opts."""

        def new_deprecated():
            return cfg.DeprecatedOpt(uuid.uuid4().hex, group=uuid.uuid4().hex)

        opt_names = ['cafile', 'certfile', 'keyfile', 'insecure', 'timeout']
        depr = {name: [new_deprecated()] for name in opt_names}
        opts = client_session.Session.get_conf_options(deprecated_opts=depr)

        self.assertThat(opt_names, matchers.HasLength(len(opts)))
        for opt in opts:
            self.assertIn(depr[opt.name][0], opt.deprecated_opts)
class CliLoadingTests(utils.TestCase):
    """Tests for creating a session from argparse command-line options."""

    def setUp(self):
        super(CliLoadingTests, self).setUp()

        self.parser = argparse.ArgumentParser()
        client_session.Session.register_cli_options(self.parser)

    def get_session(self, val, **kwargs):
        """Parse the whitespace-separated ``val`` and load a session."""
        argv = val.split()
        parsed = self.parser.parse_args(argv)
        with self.deprecations.expect_deprecations_here():
            return client_session.Session.load_from_cli_options(parsed,
                                                                **kwargs)

    def test_insecure_timeout(self):
        """``--insecure`` disables verification; ``--timeout`` is a float."""
        sess = self.get_session('--insecure --timeout 5.5')

        self.assertFalse(sess.verify)
        self.assertEqual(5.5, sess.timeout)

    def test_client_certs(self):
        """``--os-cert`` and ``--os-key`` become the session's cert tuple."""
        cert = '/path/to/certfile'
        key = '/path/to/keyfile'

        sess = self.get_session('--os-cert %s --os-key %s' % (cert, key))

        self.assertTrue(sess.verify)
        self.assertEqual((cert, key), sess.cert)

    def test_cacert(self):
        """``--os-cacert`` becomes the session's ``verify`` value."""
        cacert = '/path/to/cacert'

        sess = self.get_session('--os-cacert %s' % cacert)

        self.assertEqual(cacert, sess.verify)