diff --git a/keystoneclient/session.py b/keystoneclient/session.py index 5382ee95a..89250d584 100644 --- a/keystoneclient/session.py +++ b/keystoneclient/session.py @@ -33,8 +33,12 @@ class Session(object): user_agent = None + REDIRECT_STATUSES = (301, 302, 303, 305, 307) + DEFAULT_REDIRECT_LIMIT = 30 + def __init__(self, session=None, original_ip=None, verify=True, cert=None, - timeout=None, user_agent=None): + timeout=None, user_agent=None, + redirect=DEFAULT_REDIRECT_LIMIT): """Maintains client communication state and common functionality. As much as possible the parameters to this class reflect and are passed @@ -61,6 +65,10 @@ class Session(object): request. If not provided a default is used. (optional, defaults to 'python-keystoneclient') + :param int/bool redirect: Controls the maximum number of redirections + that can be followed by a request. Either an + integer for a specific count or True/False + for forever/never. (optional, default to 30) """ if not session: session = requests.Session() @@ -70,6 +78,7 @@ class Session(object): self.verify = verify self.cert = cert self.timeout = None + self.redirect = redirect if timeout is not None: self.timeout = float(timeout) @@ -79,7 +88,7 @@ class Session(object): self.user_agent = user_agent def request(self, url, method, json=None, original_ip=None, - user_agent=None, **kwargs): + user_agent=None, redirect=None, **kwargs): """Send an HTTP request with the specified characteristics. Wrapper around `requests.Session.request` to handle tasks such as @@ -93,13 +102,19 @@ class Session(object): :param string original_ip: Mark this request as forwarded for this ip. (optional) :param dict headers: Headers to be included in the request. (optional) - :param kwargs: any other parameter that can be passed to - requests.Session.request (such as `headers`) or `json` - that will be encoded as JSON and used as `data` argument :param json: Some data to be represented as JSON. (optional) :param string user_agent: A user_agent to use for the request. If present will override one present in headers. (optional) + :param int/bool redirect: the maximum number of redirections that + can be followed by a request. Either an + integer for a specific count or True/False + for forever/never. (optional) + :param kwargs: any other parameter that can be passed to + requests.Session.request (such as `headers`). Except: + 'data' will be overwritten by the data in 'json' param. + 'allow_redirects' is ignored as redirects are handled + by the session. :raises exceptions.ClientException: For connection failure, or to indicate an error response code. @@ -149,7 +164,17 @@ class Session(object): if data: _logger.debug('REQ BODY: %s', data) - resp = self._send_request(url, method, **kwargs) + # Force disable requests redirect handling. We will manage this below. + kwargs['allow_redirects'] = False + + if redirect is None: + redirect = self.redirect + + resp = self._send_request(url, method, redirect, **kwargs) + + # NOTE(jamielennox): we create a tuple here to be the same as what is + # returned by the requests library. + resp.history = tuple(resp.history) if resp.status_code >= 400: _logger.debug('Request returned failure status: %s', @@ -158,7 +183,12 @@ class Session(object): return resp - def _send_request(self, url, method, **kwargs): + def _send_request(self, url, method, redirect, **kwargs): + # NOTE(jamielennox): We handle redirection manually because the + # requests lib follows some browser patterns where it will redirect + # POSTs as GETs for certain statuses which is not want we want for an + # API. See: https://en.wikipedia.org/wiki/Post/Redirect/Get + try: resp = self.session.request(method, url, **kwargs) except requests.exceptions.SSLError: @@ -174,16 +204,30 @@ class Session(object): _logger.debug('RESP: [%s] %s\nRESP BODY: %s\n', resp.status_code, resp.headers, resp.text) - # NOTE(jamielennox): The requests lib will handle the majority of - # redirections. Where it fails is when POSTs are redirected which - # is apparently something handled differently by each browser which - # requests forces us to do the most compliant way (which we don't want) - # see: https://en.wikipedia.org/wiki/Post/Redirect/Get - # Nova and other direct users don't do this. Is it still relevant? - if resp.status_code in (301, 302, 305): - # Redirected. Reissue the request to the new location. - return self._send_request(resp.headers['location'], - method, **kwargs) + if resp.status_code in self.REDIRECT_STATUSES: + # be careful here in python True == 1 and False == 0 + if isinstance(redirect, bool): + redirect_allowed = redirect + else: + redirect -= 1 + redirect_allowed = redirect >= 0 + + if not redirect_allowed: + return resp + + try: + location = resp.headers['location'] + except KeyError: + _logger.warn("Failed to redirect request to %s as new " + "location was not provided.", resp.url) + else: + new_resp = self._send_request(location, method, redirect, + **kwargs) + + if not isinstance(new_resp.history, list): + new_resp.history = list(new_resp.history) + new_resp.history.insert(0, resp) + resp = new_resp return resp diff --git a/keystoneclient/tests/test_session.py b/keystoneclient/tests/test_session.py index 74fab7864..0edaac48c 100644 --- a/keystoneclient/tests/test_session.py +++ b/keystoneclient/tests/test_session.py @@ -15,6 +15,7 @@ import httpretty import mock +import requests from keystoneclient import exceptions from keystoneclient import session as client_session @@ -138,3 +139,86 @@ class SessionTests(utils.TestCase): self.stub_url(httpretty.GET, status=500) self.assertRaises(exceptions.InternalServerError, session.get, self.TEST_URL) + + +class RedirectTests(utils.TestCase): + + REDIRECT_CHAIN = ['http://myhost:3445/', + 'http://anotherhost:6555/', + 'http://thirdhost/', + 'http://finaldestination:55/'] + + DEFAULT_REDIRECT_BODY = 'Redirect' + DEFAULT_RESP_BODY = 'Found' + + def setup_redirects(self, method=httpretty.GET, status=305, + redirect_kwargs={}, final_kwargs={}): + redirect_kwargs.setdefault('body', self.DEFAULT_REDIRECT_BODY) + + for s, d in zip(self.REDIRECT_CHAIN, self.REDIRECT_CHAIN[1:]): + httpretty.register_uri(method, s, status=status, location=d, + **redirect_kwargs) + + final_kwargs.setdefault('status', 200) + final_kwargs.setdefault('body', self.DEFAULT_RESP_BODY) + httpretty.register_uri(method, self.REDIRECT_CHAIN[-1], **final_kwargs) + + def assertResponse(self, resp): + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.text, self.DEFAULT_RESP_BODY) + + @httpretty.activate + def test_basic_get(self): + session = client_session.Session() + self.setup_redirects() + resp = session.get(self.REDIRECT_CHAIN[-2]) + self.assertResponse(resp) + + @httpretty.activate + def test_basic_post_keeps_correct_method(self): + session = client_session.Session() + self.setup_redirects(method=httpretty.POST, status=301) + resp = session.post(self.REDIRECT_CHAIN[-2]) + self.assertResponse(resp) + + @httpretty.activate + def test_redirect_forever(self): + session = client_session.Session(redirect=True) + self.setup_redirects() + resp = session.get(self.REDIRECT_CHAIN[0]) + self.assertResponse(resp) + self.assertTrue(len(resp.history), len(self.REDIRECT_CHAIN)) + + @httpretty.activate + def test_no_redirect(self): + session = client_session.Session(redirect=False) + self.setup_redirects() + resp = session.get(self.REDIRECT_CHAIN[0]) + self.assertEqual(resp.status_code, 305) + self.assertEqual(resp.url, self.REDIRECT_CHAIN[0]) + + @httpretty.activate + def test_redirect_limit(self): + self.setup_redirects() + for i in (1, 2): + session = client_session.Session(redirect=i) + resp = session.get(self.REDIRECT_CHAIN[0]) + self.assertEqual(resp.status_code, 305) + self.assertEqual(resp.url, self.REDIRECT_CHAIN[i]) + self.assertEqual(resp.text, self.DEFAULT_REDIRECT_BODY) + + @httpretty.activate + def test_history_matches_requests(self): + self.setup_redirects(status=301) + session = client_session.Session(redirect=True) + req_resp = requests.get(self.REDIRECT_CHAIN[0], + allow_redirects=True) + + ses_resp = session.get(self.REDIRECT_CHAIN[0]) + + self.assertEqual(type(req_resp.history), type(ses_resp.history)) + self.assertEqual(len(req_resp.history), len(ses_resp.history)) + + for r, s in zip(req_resp.history, ses_resp.history): + self.assertEqual(r.url, s.url) + self.assertEqual(r.status_code, s.status_code) diff --git a/keystoneclient/tests/test_shell.py b/keystoneclient/tests/test_shell.py index 2a80c9db0..3e76d8cea 100644 --- a/keystoneclient/tests/test_shell.py +++ b/keystoneclient/tests/test_shell.py @@ -445,6 +445,7 @@ class ShellTest(utils.TestCase): ' --os-auth-url=blah.com endpoint-list')) request_mock.assert_called_with(mock.ANY, mock.ANY, timeout=2, + allow_redirects=False, headers=mock.ANY, verify=mock.ANY)