From a720b4d7577f6596966310101815bf69cc8ad2c8 Mon Sep 17 00:00:00 2001
From: Samuel Merritt <sam@swiftstack.com>
Date: Wed, 18 Dec 2013 10:43:29 -0800
Subject: [PATCH] Some functional tests for object versioning

Also fix some exception-handling stuff in functional tests; at some
point, ResponseError() started needing two more parameters, but the
functional tests (not swift_test_client, just tests.py) still had a
couple spots that were not passing in the new params. Now they're
optional again, and if you omit them, then the stringification of the
ResponseError is just a little less useful than it could be.

Change-Id: I38968c4b590fc04b97b85c5f974c8648291a6689
---
 test/functional/swift_test_client.py | 23 +++++---
 test/functional/tests.py             | 79 ++++++++++++++++++++++++++++
 2 files changed, 96 insertions(+), 6 deletions(-)

diff --git a/test/functional/swift_test_client.py b/test/functional/swift_test_client.py
index 67d4472302..c5a2cb81df 100644
--- a/test/functional/swift_test_client.py
+++ b/test/functional/swift_test_client.py
@@ -40,7 +40,7 @@ class RequestError(Exception):
 
 
 class ResponseError(Exception):
-    def __init__(self, response, method, path):
+    def __init__(self, response, method=None, path=None):
         self.status = response.status
         self.reason = response.reason
         self.method = method
@@ -310,10 +310,11 @@ class Base:
     def __str__(self):
         return self.name
 
-    def header_fields(self, fields):
+    def header_fields(self, required_fields, optional_fields=()):
         headers = dict(self.conn.response.getheaders())
         ret = {}
-        for field in fields:
+
+        for field in required_fields:
             if field[1] not in headers:
                 raise ValueError("%s was not found in response header" %
                                  (field[1]))
@@ -322,6 +323,15 @@ class Base:
                 ret[field[0]] = int(headers[field[1]])
             except ValueError:
                 ret[field[0]] = headers[field[1]]
+
+        for field in optional_fields:
+            if field[1] not in headers:
+                continue
+            try:
+                ret[field[0]] = int(headers[field[1]])
+            except ValueError:
+                ret[field[0]] = headers[field[1]]
+
         return ret
 
 
@@ -480,10 +490,11 @@ class Container(Base):
                                parms=parms, cfg=cfg)
 
         if self.conn.response.status == 204:
-            fields = [['bytes_used', 'x-container-bytes-used'],
-                      ['object_count', 'x-container-object-count']]
+            required_fields = [['bytes_used', 'x-container-bytes-used'],
+                               ['object_count', 'x-container-object-count']]
+            optional_fields = [['versions', 'x-versions-location']]
 
-            return self.header_fields(fields)
+            return self.header_fields(required_fields, optional_fields)
 
         raise ResponseError(self.conn.response, 'HEAD',
                             self.conn.make_path(self.path))
diff --git a/test/functional/tests.py b/test/functional/tests.py
index d1c626e12d..d2a4134199 100644
--- a/test/functional/tests.py
+++ b/test/functional/tests.py
@@ -1955,5 +1955,84 @@ class TestSloUTF8(Base2, TestSlo):
     set_up = False
 
 
+class TestObjectVersioningEnv(object):
+    versioning_enabled = None  # tri-state: None initially, then True/False
+
+    @classmethod
+    def setUp(cls):
+        cls.conn = Connection(config)
+        cls.conn.authenticate()
+
+        cls.account = Account(cls.conn, config.get('account',
+                                                   config['username']))
+
+        # avoid getting a prefix that stops halfway through an encoded
+        # character
+        prefix = Utils.create_name().decode("utf-8")[:10].encode("utf-8")
+
+        cls.versions_container = cls.account.container(prefix + "-versions")
+        if not cls.versions_container.create():
+            raise ResponseError(cls.conn.response)
+
+        cls.container = cls.account.container(prefix + "-objs")
+        if not cls.container.create(
+                hdrs={'X-Versions-Location': cls.versions_container.name}):
+            raise ResponseError(cls.conn.response)
+
+        container_info = cls.container.info()
+        # if versioning is off, then X-Versions-Location won't persist
+        cls.versioning_enabled = 'versions' in container_info
+
+
+class TestObjectVersioning(Base):
+    env = TestObjectVersioningEnv
+    set_up = False
+
+    def setUp(self):
+        super(TestObjectVersioning, self).setUp()
+        if self.env.versioning_enabled is False:
+            raise SkipTest("Object versioning not enabled")
+        elif self.env.versioning_enabled is not True:
+            # just some sanity checking
+            raise Exception(
+                "Expected versioning_enabled to be True/False, got %r" %
+                (self.env.versioning_enabled,))
+
+    def test_overwriting(self):
+        container = self.env.container
+        versions_container = self.env.versions_container
+        obj_name = Utils.create_name()
+
+        versioned_obj = container.file(obj_name)
+        versioned_obj.write("aaaaa")
+
+        self.assertEqual(0, versions_container.info()['object_count'])
+
+        versioned_obj.write("bbbbb")
+
+        # the old version got saved off
+        self.assertEqual(1, versions_container.info()['object_count'])
+        versioned_obj_name = versions_container.files()[0]
+        self.assertEqual(
+            "aaaaa", versions_container.file(versioned_obj_name).read())
+
+        # if we overwrite it again, there are two versions
+        versioned_obj.write("ccccc")
+        self.assertEqual(2, versions_container.info()['object_count'])
+
+        # as we delete things, the old contents return
+        self.assertEqual("ccccc", versioned_obj.read())
+        versioned_obj.delete()
+        self.assertEqual("bbbbb", versioned_obj.read())
+        versioned_obj.delete()
+        self.assertEqual("aaaaa", versioned_obj.read())
+        versioned_obj.delete()
+        self.assertRaises(ResponseError, versioned_obj.read)
+
+
+class TestObjectVersioningUTF8(Base2, TestObjectVersioning):
+    set_up = False
+
+
 if __name__ == '__main__':
     unittest.main()