diff --git a/osprofiler/web.py b/osprofiler/web.py index 88b496c..94b0c0c 100644 --- a/osprofiler/web.py +++ b/osprofiler/web.py @@ -26,16 +26,16 @@ def add_trace_id_header(headers): p = profiler.get_profiler() if p: idents = {"base_id": p.get_base_id(), "parent_id": p.get_id()} - raw_content = json.dumps(idents) - headers["X-Trace-Info"] = utils.binary_encode(raw_content) + raw_content = utils.binary_encode(json.dumps(idents)) + headers["X-Trace-Info"] = raw_content if p.hmac_key: headers["X-Trace-HMAC"] = generate_hmac(raw_content, p.hmac_key) def generate_hmac(content, hmac_key): """Generate a hmac using a known key given the provided content.""" - h = hmac.new(hmac_key, digestmod=hashlib.sha1) - h.update(content) + h = hmac.new(utils.binary_encode(hmac_key), digestmod=hashlib.sha1) + h.update(utils.binary_encode(content)) return h.hexdigest() @@ -45,9 +45,7 @@ def validate_hmac(content, expected_hmac, hmac_key): or was being faked). """ if hmac_key: - h = hmac.new(hmac_key, digestmod=hashlib.sha1) - h.update(content) - if h.hexdigest() != expected_hmac: + if generate_hmac(content, hmac_key) != expected_hmac: raise IOError("Invalid hmac detected") diff --git a/tests/test_web.py b/tests/test_web.py index cc2a98a..daa43ae 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -16,13 +16,26 @@ import json import mock +from webob import request as webob_request +from webob import response as webob_response + +from osprofiler import profiler from osprofiler import utils from osprofiler import web from tests import test +def dummy_app(environ, response): + res = webob_response.Response() + return res(environ, response) + + class WebMiddlewareTestCase(test.TestCase): + def setUp(self): + super(WebMiddlewareTestCase, self).setUp() + profiler._clean() + self.addCleanup(profiler._clean) @mock.patch("osprofiler.web.utils.binary_encode") @mock.patch("osprofiler.web.json.dumps") @@ -55,6 +68,119 @@ class WebMiddlewareTestCase(test.TestCase): web.add_trace_id_header(headers) self.assertEqual(old_headers, headers) + def test_wsgi_hmac_no_headers(self): + req = webob_request.Request.blank("/") + m = web.WsgiMiddleware(dummy_app, enabled=True, + hmac_key="secret_password") + m(req) + p = profiler.get_profiler() + self.assertIsNone(p) + + def test_wsgi_hmac_headers_init_profiler(self): + hmac_key = 'secret_password' + profiler.init(base_id="b", parent_id="a", hmac_key=hmac_key) + headers = { + 'Content-Type': 'text/javascript', + } + web.add_trace_id_header(headers) + profiler._clean() + self.assertIsNone(profiler.get_profiler()) + + req = webob_request.Request.blank("/", headers=headers) + m = web.WsgiMiddleware(dummy_app, enabled=True, hmac_key=hmac_key) + m(req) + + p = profiler.get_profiler() + self.assertIsNotNone(p) + self.assertEqual('a', p.get_id()) + self.assertEqual('b', p.get_base_id()) + + def test_wsgi_hmac_headers_init_profiler_spaces(self): + hmac_key = 'secret_password' + profiler.init(base_id="b", parent_id="a", hmac_key=hmac_key) + headers = { + 'Content-Type': 'text/javascript', + } + web.add_trace_id_header(headers) + headers['X-Trace-HMAC'] = "\t " + headers['X-Trace-HMAC'] + " \n" + profiler._clean() + self.assertIsNone(profiler.get_profiler()) + + req = webob_request.Request.blank("/", headers=headers) + m = web.WsgiMiddleware(dummy_app, enabled=True, hmac_key=hmac_key) + m(req) + + p = profiler.get_profiler() + self.assertIsNotNone(p) + self.assertEqual('a', p.get_id()) + self.assertEqual('b', p.get_base_id()) + + def test_wsgi_hmac_headers_no_init_profiler(self): + profiler.init(base_id="b", parent_id="a", hmac_key="hacked_password") + headers = { + 'Content-Type': 'text/javascript', + } + web.add_trace_id_header(headers) + profiler._clean() + self.assertIsNone(profiler.get_profiler()) + + req = webob_request.Request.blank("/", headers=headers) + m = web.WsgiMiddleware(dummy_app, enabled=True, + hmac_key="secret_password") + m(req) + + p = profiler.get_profiler() + self.assertIsNone(p) + + def test_hmac_generation(self): + profiler.init(base_id="b", parent_id="a", hmac_key="secret_password") + headers = { + 'Content-Type': 'text/javascript', + } + web.add_trace_id_header(headers) + self.assertIn('X-Trace-HMAC', headers) + self.assertTrue(len(headers['X-Trace-HMAC']) > 0) + + def test_hmac_no_generation(self): + profiler.init(base_id="b", parent_id="a") + headers = { + 'Content-Type': 'text/javascript', + } + web.add_trace_id_header(headers) + self.assertNotIn('X-Trace-HMAC', headers) + self.assertIn('X-Trace-Info', headers) + self.assertEqual(2, len(headers)) + + def test_hmac_validation(self): + profiler.init(base_id="b", parent_id="a", hmac_key="secret_password") + headers = { + 'Content-Type': 'text/javascript', + } + web.add_trace_id_header(headers) + content = headers.get("X-Trace-Info") + web.validate_hmac(content, headers['X-Trace-HMAC'], "secret_password") + + def test_invalid_hmac(self): + profiler.init(base_id="b", parent_id="a", hmac_key="secret_password") + headers = { + 'Content-Type': 'text/javascript', + } + web.add_trace_id_header(headers) + content = headers.get("X-Trace-Info") + content += b"_changed" + self.assertRaises(IOError, web.validate_hmac, content, + headers['X-Trace-HMAC'], "secret_password") + + def test_hmac_faked(self): + headers = { + 'Content-Type': 'text/javascript', + 'X-Trace-HMAC': 'fake', + 'X-Trace-Info': '{}', + } + content = headers.get("X-Trace-Info") + self.assertRaises(IOError, web.validate_hmac, content, + headers['X-Trace-HMAC'], 'secret_password') + def test_wsgi_middleware_no_trace(self): request = mock.MagicMock() request.get_response.return_value = "yeah!"