Merge "Fixing bug with x-trans-id. Will now be set on all incoming requests to proxy and trans-ids will not be reused."
This commit is contained in:
		@@ -22,7 +22,8 @@ from swift.common.utils import get_logger
 | 
			
		||||
 | 
			
		||||
class CatchErrorMiddleware(object):
 | 
			
		||||
    """
 | 
			
		||||
    Middleware that provides high-level error handling.
 | 
			
		||||
    Middleware that provides high-level error handling and ensures that a
 | 
			
		||||
    transaction id will be set for every request.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, app, conf):
 | 
			
		||||
@@ -30,10 +31,12 @@ class CatchErrorMiddleware(object):
 | 
			
		||||
        self.logger = get_logger(conf, log_route='catch-errors')
 | 
			
		||||
 | 
			
		||||
    def __call__(self, env, start_response):
 | 
			
		||||
        trans_id = env.get('HTTP_X_TRANS_ID')
 | 
			
		||||
        if not trans_id:
 | 
			
		||||
        """
 | 
			
		||||
        If used, this should be the first middleware in pipeline.
 | 
			
		||||
        """
 | 
			
		||||
        trans_id = 'tx' + uuid.uuid4().hex
 | 
			
		||||
            env['HTTP_X_TRANS_ID'] = trans_id
 | 
			
		||||
        env['swift.trans_id'] = trans_id
 | 
			
		||||
        self.logger.txn_id = trans_id
 | 
			
		||||
        try:
 | 
			
		||||
 | 
			
		||||
            def my_start_response(status, response_headers, exc_info=None):
 | 
			
		||||
 
 | 
			
		||||
@@ -214,7 +214,7 @@ class StaticWeb(object):
 | 
			
		||||
        """
 | 
			
		||||
        new_env = {'REQUEST_METHOD': 'GET',
 | 
			
		||||
            'HTTP_USER_AGENT': '%s StaticWeb' % env.get('HTTP_USER_AGENT')}
 | 
			
		||||
        for name in ('eventlet.posthooks', 'HTTP_X_CF_TRANS_ID', 'REMOTE_USER',
 | 
			
		||||
        for name in ('eventlet.posthooks', 'swift.trans_id', 'REMOTE_USER',
 | 
			
		||||
                     'SCRIPT_NAME', 'SERVER_NAME', 'SERVER_PORT',
 | 
			
		||||
                     'SERVER_PROTOCOL', 'swift.cache'):
 | 
			
		||||
            if name in env:
 | 
			
		||||
@@ -532,7 +532,7 @@ class StaticWeb(object):
 | 
			
		||||
            '-',
 | 
			
		||||
            '-',
 | 
			
		||||
            env.get('HTTP_ETAG', '-'),
 | 
			
		||||
            env.get('HTTP_X_CF_TRANS_ID', '-'),
 | 
			
		||||
            env.get('swift.trans_id', '-'),
 | 
			
		||||
            logged_headers or '-',
 | 
			
		||||
            trans_time)))
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -465,7 +465,7 @@ class TempAuth(object):
 | 
			
		||||
            getattr(req, 'bytes_transferred', 0) or '-',
 | 
			
		||||
            getattr(response, 'bytes_transferred', 0) or '-',
 | 
			
		||||
            req.headers.get('etag', '-'),
 | 
			
		||||
            req.headers.get('x-trans-id', '-'), logged_headers or '-',
 | 
			
		||||
            req.environ.get('swift.trans_id', '-'), logged_headers or '-',
 | 
			
		||||
            trans_time)))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -303,8 +303,7 @@ class LogAdapter(logging.LoggerAdapter, object):
 | 
			
		||||
    client ip.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    _txn_id = threading.local()
 | 
			
		||||
    _client_ip = threading.local()
 | 
			
		||||
    _cls_thread_local = threading.local()
 | 
			
		||||
 | 
			
		||||
    def __init__(self, logger, server):
 | 
			
		||||
        logging.LoggerAdapter.__init__(self, logger, {})
 | 
			
		||||
@@ -313,21 +312,21 @@ class LogAdapter(logging.LoggerAdapter, object):
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def txn_id(self):
 | 
			
		||||
        if hasattr(self._txn_id, 'value'):
 | 
			
		||||
            return self._txn_id.value
 | 
			
		||||
        if hasattr(self._cls_thread_local, 'txn_id'):
 | 
			
		||||
            return self._cls_thread_local.txn_id
 | 
			
		||||
 | 
			
		||||
    @txn_id.setter
 | 
			
		||||
    def txn_id(self, value):
 | 
			
		||||
        self._txn_id.value = value
 | 
			
		||||
        self._cls_thread_local.txn_id = value
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def client_ip(self):
 | 
			
		||||
        if hasattr(self._client_ip, 'value'):
 | 
			
		||||
            return self._client_ip.value
 | 
			
		||||
        if hasattr(self._cls_thread_local, 'client_ip'):
 | 
			
		||||
            return self._cls_thread_local.client_ip
 | 
			
		||||
 | 
			
		||||
    @client_ip.setter
 | 
			
		||||
    def client_ip(self, value):
 | 
			
		||||
        self._client_ip.value = value
 | 
			
		||||
        self._cls_thread_local.client_ip = value
 | 
			
		||||
 | 
			
		||||
    def getEffectiveLevel(self):
 | 
			
		||||
        return self.logger.getEffectiveLevel()
 | 
			
		||||
 
 | 
			
		||||
@@ -207,7 +207,7 @@ def make_pre_authed_request(env, method, path, body=None, headers=None,
 | 
			
		||||
    (Stolen from Swauth: https://github.com/gholt/swauth)
 | 
			
		||||
    """
 | 
			
		||||
    newenv = {'REQUEST_METHOD': method, 'HTTP_USER_AGENT': agent}
 | 
			
		||||
    for name in ('swift.cache', 'HTTP_X_TRANS_ID'):
 | 
			
		||||
    for name in ('swift.cache', 'swift.trans_id'):
 | 
			
		||||
        if name in env:
 | 
			
		||||
            newenv[name] = env[name]
 | 
			
		||||
    newenv['swift.authorize'] = lambda req: None
 | 
			
		||||
 
 | 
			
		||||
@@ -1632,8 +1632,13 @@ class BaseApplication(object):
 | 
			
		||||
                return HTTPPreconditionFailed(request=req, body='Bad URL')
 | 
			
		||||
 | 
			
		||||
            controller = controller(self, **path_parts)
 | 
			
		||||
            controller.trans_id = req.headers.get('x-trans-id', '-')
 | 
			
		||||
            self.logger.txn_id = req.headers.get('x-trans-id', None)
 | 
			
		||||
            if 'swift.trans_id' not in req.environ:
 | 
			
		||||
                # if this wasn't set by an earlier middleware, set it now
 | 
			
		||||
                trans_id = 'tx' + uuid.uuid4().hex
 | 
			
		||||
                req.environ['swift.trans_id'] = trans_id
 | 
			
		||||
                self.logger.txn_id = trans_id
 | 
			
		||||
            req.headers['x-trans-id'] = req.environ['swift.trans_id']
 | 
			
		||||
            controller.trans_id = req.environ['swift.trans_id']
 | 
			
		||||
            self.logger.client_ip = get_remote_client(req)
 | 
			
		||||
            try:
 | 
			
		||||
                handler = getattr(controller, req.method)
 | 
			
		||||
@@ -1708,10 +1713,12 @@ class Application(BaseApplication):
 | 
			
		||||
                getattr(req, 'bytes_transferred', 0) or '-',
 | 
			
		||||
                getattr(response, 'bytes_transferred', 0) or '-',
 | 
			
		||||
                req.headers.get('etag', '-'),
 | 
			
		||||
                req.headers.get('x-trans-id', '-'),
 | 
			
		||||
                req.environ.get('swift.trans_id', '-'),
 | 
			
		||||
                logged_headers or '-',
 | 
			
		||||
                trans_time,
 | 
			
		||||
            )))
 | 
			
		||||
        # done with this transaction
 | 
			
		||||
        self.access_logger.txn_id = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def app_factory(global_conf, **local_conf):
 | 
			
		||||
 
 | 
			
		||||
@@ -15,29 +15,36 @@
 | 
			
		||||
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
from webob import Request
 | 
			
		||||
from webob import Request, Response
 | 
			
		||||
 | 
			
		||||
from swift.common.middleware import catch_errors
 | 
			
		||||
from swift.common.utils import get_logger
 | 
			
		||||
 | 
			
		||||
class FakeApp(object):
 | 
			
		||||
    def __init__(self, error=False):
 | 
			
		||||
        self.error = error
 | 
			
		||||
 | 
			
		||||
    def __call__(self, env, start_response):
 | 
			
		||||
        if 'swift.trans_id' not in env:
 | 
			
		||||
            raise Exception('Trans id should always be in env')
 | 
			
		||||
        if self.error:
 | 
			
		||||
            raise Exception('augh!')
 | 
			
		||||
        return "FAKE APP"
 | 
			
		||||
            raise Exception('An error occurred')
 | 
			
		||||
        return ["FAKE APP"]
 | 
			
		||||
 | 
			
		||||
def start_response(*args):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
class TestCatchErrors(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.logger = get_logger({})
 | 
			
		||||
        self.logger.txn_id = None
 | 
			
		||||
 | 
			
		||||
    def test_catcherrors_passthrough(self):
 | 
			
		||||
        app = catch_errors.CatchErrorMiddleware(FakeApp(), {})
 | 
			
		||||
        req = Request.blank('/', environ={'REQUEST_METHOD': 'GET'})
 | 
			
		||||
        resp = app(req.environ, start_response)
 | 
			
		||||
        self.assertEquals(resp, 'FAKE APP')
 | 
			
		||||
        self.assertEquals(resp, ['FAKE APP'])
 | 
			
		||||
 | 
			
		||||
    def test_catcherrors(self):
 | 
			
		||||
        app = catch_errors.CatchErrorMiddleware(FakeApp(True), {})
 | 
			
		||||
@@ -45,28 +52,23 @@ class TestCatchErrors(unittest.TestCase):
 | 
			
		||||
        resp = app(req.environ, start_response)
 | 
			
		||||
        self.assertEquals(resp, ['An error occurred'])
 | 
			
		||||
 | 
			
		||||
    def test_trans_id_header(self):
 | 
			
		||||
 | 
			
		||||
    def test_trans_id_header_pass(self):
 | 
			
		||||
        self.assertEquals(self.logger.txn_id, None)
 | 
			
		||||
        def start_response(status, headers):
 | 
			
		||||
            self.assert_('x-trans-id' in (x[0] for x in headers))
 | 
			
		||||
        app = catch_errors.CatchErrorMiddleware(FakeApp(), {})
 | 
			
		||||
        req = Request.blank('/v1/a')
 | 
			
		||||
        app(req.environ, start_response)
 | 
			
		||||
        app = catch_errors.CatchErrorMiddleware(FakeApp(), {})
 | 
			
		||||
        req = Request.blank('/v1/a/c')
 | 
			
		||||
        app(req.environ, start_response)
 | 
			
		||||
        app = catch_errors.CatchErrorMiddleware(FakeApp(), {})
 | 
			
		||||
        req = Request.blank('/v1/a/c/o')
 | 
			
		||||
        app(req.environ, start_response)
 | 
			
		||||
        app = catch_errors.CatchErrorMiddleware(FakeApp(True), {})
 | 
			
		||||
        req = Request.blank('/v1/a')
 | 
			
		||||
        app(req.environ, start_response)
 | 
			
		||||
        app = catch_errors.CatchErrorMiddleware(FakeApp(True), {})
 | 
			
		||||
        req = Request.blank('/v1/a/c')
 | 
			
		||||
        app(req.environ, start_response)
 | 
			
		||||
        self.assertEquals(len(self.logger.txn_id), 34) # 32 hex + 'tx'
 | 
			
		||||
 | 
			
		||||
    def test_trans_id_header_fail(self):
 | 
			
		||||
        self.assertEquals(self.logger.txn_id, None)
 | 
			
		||||
        def start_response(status, headers):
 | 
			
		||||
            self.assert_('x-trans-id' in (x[0] for x in headers))
 | 
			
		||||
        app = catch_errors.CatchErrorMiddleware(FakeApp(True), {})
 | 
			
		||||
        req = Request.blank('/v1/a/c/o')
 | 
			
		||||
        app(req.environ, start_response)
 | 
			
		||||
        self.assertEquals(len(self.logger.txn_id), 34)
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    unittest.main()
 | 
			
		||||
 
 | 
			
		||||
@@ -179,7 +179,7 @@ class TestWSGI(unittest.TestCase):
 | 
			
		||||
            @classmethod
 | 
			
		||||
            def fake_blank(cls, path, environ={}, body='', headers={}):
 | 
			
		||||
                self.assertEquals(environ['swift.authorize']('test'), None)
 | 
			
		||||
                self.assertEquals(environ['HTTP_X_TRANS_ID'], '1234')
 | 
			
		||||
                self.assertFalse('HTTP_X_TRANS_ID' in environ)
 | 
			
		||||
        was_blank = Request.blank
 | 
			
		||||
        Request.blank = FakeReq.fake_blank
 | 
			
		||||
        wsgi.make_pre_authed_request({'HTTP_X_TRANS_ID': '1234'},
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user