From 1fc40d6c296d36da5d5c48c7c2aba537ed988f1c Mon Sep 17 00:00:00 2001
From: David Goetz <david.goetz@gmail.com>
Date: Fri, 29 Oct 2010 13:30:34 -0700
Subject: [PATCH] catching invalid urls and adding tests

---
 swift/common/middleware/ratelimit.py          |  6 +++++-
 swift/common/utils.py                         |  2 ++
 test/functional/tests.py                      |  9 +++++++++
 test/unit/common/middleware/test_ratelimit.py | 20 +++++++++++++++++++
 4 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/swift/common/middleware/ratelimit.py b/swift/common/middleware/ratelimit.py
index b13a4a6ab4..707e544ae2 100644
--- a/swift/common/middleware/ratelimit.py
+++ b/swift/common/middleware/ratelimit.py
@@ -14,6 +14,7 @@
 import time
 import eventlet
 from webob import Request, Response
+from webob.exc import HTTPNotFound
 
 from swift.common.utils import split_path, cache_from_env, get_logger
 from swift.proxy.server import get_container_memcache_key
@@ -204,7 +205,10 @@ class RateLimitMiddleware(object):
         req = Request(env)
         if self.memcache_client is None:
             self.memcache_client = cache_from_env(env)
-        version, account, container, obj = split_path(req.path, 1, 4, True)
+        try:
+            version, account, container, obj = split_path(req.path, 1, 4, True)
+        except ValueError:
+            return HTTPNotFound()(env, start_response)
         ratelimit_resp = self.handle_ratelimit(req, account, container, obj)
         if ratelimit_resp is None:
             return self.app(env, start_response)
diff --git a/swift/common/utils.py b/swift/common/utils.py
index 5a8bf0e1be..bb635725c8 100644
--- a/swift/common/utils.py
+++ b/swift/common/utils.py
@@ -208,6 +208,7 @@ def split_path(path, minsegs=1, maxsegs=None, rest_with_last=False):
                            trailing data, raises ValueError.
     :returns: list of segments with a length of maxsegs (non-existant
               segments will return as None)
+    :raises: ValueError if given an invalid path
     """
     if not maxsegs:
         maxsegs = minsegs
@@ -622,6 +623,7 @@ def write_pickle(obj, dest, tmp):
         os.fsync(fd)
         renamer(tmppath, dest)
 
+
 def audit_location_generator(devices, datadir, mount_check=True, logger=None):
     '''
     Given a devices path and a data directory, yield (path, device,
diff --git a/test/functional/tests.py b/test/functional/tests.py
index 6a28d9bb3e..f1ea6232b0 100644
--- a/test/functional/tests.py
+++ b/test/functional/tests.py
@@ -170,6 +170,15 @@ class TestAccount(Base):
         self.assert_status(412)
         self.assert_body('Bad URL')
 
+    def testInvalidPath(self):
+        was_url = self.env.account.conn.storage_url
+        self.env.account.conn.storage_url = "/%s" % was_url
+        self.env.account.conn.make_request('GET')
+        try:
+            self.assert_status(404)
+        finally:
+            self.env.account.conn.storage_url = was_url
+
     def testPUT(self):
         self.env.account.conn.make_request('PUT')
         self.assert_status([403, 405])
diff --git a/test/unit/common/middleware/test_ratelimit.py b/test/unit/common/middleware/test_ratelimit.py
index 7bf4a9e445..2f709c4a41 100644
--- a/test/unit/common/middleware/test_ratelimit.py
+++ b/test/unit/common/middleware/test_ratelimit.py
@@ -366,6 +366,26 @@ class TestRateLimit(unittest.TestCase):
         time_took = time.time() - begin
         self.assert_(round(time_took, 1) == .4)
 
+    def test_call_invalid_path(self):
+        env = {'REQUEST_METHOD': 'GET',
+               'SCRIPT_NAME': '',
+               'PATH_INFO': '//v1/AUTH_1234567890',
+               'SERVER_NAME': '127.0.0.1',
+               'SERVER_PORT': '80',
+               'swift.cache': FakeMemcache(),
+               'SERVER_PROTOCOL': 'HTTP/1.0'}
+
+        app = lambda *args, **kwargs: None
+        rate_mid = ratelimit.RateLimitMiddleware(app, {},
+                                                 logger=FakeLogger())
+
+        class a_callable(object):
+
+            def __call__(self, *args, **kwargs):
+                pass
+        resp = rate_mid.__call__(env, a_callable())
+        self.assert_('404 Not Found' in resp[0])
+
 
 if __name__ == '__main__':
     unittest.main()