Fix WSGI middleware and add unit tests

Add unit tests to cover osprofiler.web module

* Fix bug in add_trace_id_header
* Add enabled parameter for WSGI Middleware
* Small refactoring of WSGI Middleware
This commit is contained in:
Boris Pavlovic 2014-02-24 01:25:17 +04:00
parent 3c8f7da7a4
commit a105156eb0
2 changed files with 115 additions and 12 deletions

View File

@ -20,20 +20,22 @@ import webob.dec
from osprofiler import profiler
def add_trace_id_header(self, headers):
def add_trace_id_header(headers):
p = profiler.get_profiler()
if p:
kwargs = {'base_id': p.get_base_id(), 'parent_id': p.get_id[-1]}
headers['X-Trace-Info'] = base64.b64encode(pickle.dumps(kwargs))
kwargs = {"base_id": p.get_base_id(), "parent_id": p.get_id()}
headers["X-Trace-Info"] = base64.b64encode(pickle.dumps(kwargs))
class WsgiMiddleware(object):
"""WSGI Middleware that enables tracing for an application."""
def __init__(self, application, service_name='server', name='WSGI'):
def __init__(self, application, service_name='server', name='WSGI',
enabled=False):
self.application = application
self.service_name = service_name
self.name = name
self.enabled = enabled
@classmethod
def factory(cls, global_conf, **local_conf):
@ -43,15 +45,18 @@ class WsgiMiddleware(object):
@webob.dec.wsgify
def __call__(self, request):
trace_info = {}
trace_info_enc = request.headers.get('X-Trace-Info')
if not self.enabled:
return request.get_response(self.application)
trace_info_enc = request.headers.get("X-Trace-Info")
if trace_info_enc:
trace_info = pickle.loads(base64.b64decode(trace_info_enc))
p = profiler.init(trace_info.get("base_id"),
trace_info.get("parent_id"),
self.service_name)
p = profiler.init(trace_info.get("base_id"),
trace_info.get("parent_id"),
self.service_name)
with p(self.name, info={'url': request.url}):
response = request.get_response(self.application)
return response
with p(self.name, info={"url": request.url}):
return request.get_response(self.application)
return request.get_response(self.application)

98
tests/test_web.py Normal file
View File

@ -0,0 +1,98 @@
# Copyright 2014 Mirantis Inc.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import base64
import mock
import pickle
from osprofiler import profiler
from osprofiler import web
from tests import test
class WebMiddlewareTestCase(test.TestCase):
@mock.patch("osprofiler.web.base64.b64encode")
@mock.patch("osprofiler.web.pickle.dumps")
@mock.patch("osprofiler.web.profiler.get_profiler")
def test_add_trace_id_header(self, mock_get_profiler,
mock_dumps, mock_b64encode):
mock_dumps.return_value = "dump"
mock_b64encode.return_value = "b64"
p = mock.MagicMock()
p.get_base_id.return_value = 1
p.get_id.return_value = 2
mock_get_profiler.return_value = p
headers = {"a": 10, "b": 20}
web.add_trace_id_header(headers)
self.assertEqual(sorted(headers.keys()),
sorted(["a", "b", "X-Trace-Info"]))
self.assertEqual(headers["X-Trace-Info"], "b64")
mock_b64encode.assert_called_once_with("dump")
mock_dumps.assert_called_once_with({"base_id": 1, "parent_id": 2})
@mock.patch("osprofiler.profiler.get_profiler")
def test_add_trace_id_header_no_profiler(self, mock_get_profiler):
mock_get_profiler.return_value = False
headers = {"a": "a", "b": 1}
old_headers = dict(headers)
web.add_trace_id_header(headers)
self.assertEqual(old_headers, headers)
def test_wsgi_middleware_no_trace(self):
request = mock.MagicMock()
request.get_response.return_value = "yeah!"
request.headers = {"a": "1", "b": "2"}
middleware = web.WsgiMiddleware("app", enabled=True)
self.assertEqual("yeah!", middleware(request))
request.get_response.assert_called_once_with("app")
def test_wsgi_middleware_disabled(self):
request = mock.MagicMock()
request.get_response.return_value = "yeah!"
request.headers = {"a": "1", "b": "2"}
middleware = web.WsgiMiddleware("app", enabled=False)
self.assertEqual("yeah!", middleware(request))
request.get_response.assert_called_once_with("app")
def test_wsgi_middleware(self):
request = mock.MagicMock()
request.get_response.return_value = "yeah!"
request.url = "someurl"
trace_info = {"base_id": "1", "parent_id": "2"}
request.headers = {
"a": "1",
"b": "2",
"X-Trace-Info": base64.b64encode(pickle.dumps(trace_info))
}
p = profiler.init()
p.start = mock.MagicMock()
p.stop = mock.MagicMock()
with mock.patch("osprofiler.web.profiler.init") as mock_profiler_init:
mock_profiler_init.return_value = p
middleware = web.WsgiMiddleware("app", service_name="ss",
name="WSGI", enabled=True)
self.assertEqual("yeah!", middleware(request))
mock_profiler_init.assert_called_once_with("1", "2", "ss")
p.start.assert_called_once_with("WSGI", info={"url": request.url})
p.stop.assert_called_once_with()