From f0819efe6b05e3eae125dc22cd37a833afd4910d Mon Sep 17 00:00:00 2001
From: Dean Troyer <dtroyer@gmail.com>
Date: Fri, 21 Apr 2017 11:49:03 -0500
Subject: [PATCH] Add endpoint hook to BaseAPI

Duplicate the basic endpoint cleanup (removing trailing '/') into
a method that can be overridden in subclasses to do additional things
like API-specific version handling.

Add more tests for the combinations of endpoint and url and some
corner-case checking.

Change-Id: I4b4f2abdec29e4d29b61338077f9c1320cec5bb8
---
 osc_lib/api/api.py            |  27 +++++++-
 osc_lib/tests/api/test_api.py | 118 +++++++++++++++++++++++++++++++++-
 2 files changed, 142 insertions(+), 3 deletions(-)

diff --git a/osc_lib/api/api.py b/osc_lib/api/api.py
index 027468f..1d32d99 100644
--- a/osc_lib/api/api.py
+++ b/osc_lib/api/api.py
@@ -14,6 +14,7 @@
 """Base API Library"""
 
 import simplejson as json
+import six
 
 from keystoneauth1 import exceptions as ksa_exceptions
 from keystoneauth1 import session as ksa_session
@@ -69,7 +70,28 @@ class BaseAPI(object):
             self.session = session
 
         self.service_type = service_type
-        self.endpoint = endpoint
+        self.endpoint = self._munge_endpoint(endpoint)
+
+    def _munge_endpoint(self, endpoint):
+        """Hook to allow subclasses to massage the passed-in endpoint
+
+        Hook to massage passed-in endpoints from arbitrary sources,
+        including direct user input.  By default just remove trailing
+        '/' as all of our path info strings start with '/' and not all
+        services can handle '//' in their URLs.
+
+        Some subclasses will override this to do additional work, most
+        likely with regard to API versions.
+
+        :param string endpoint: The service endpoint, generally direct
+                                from the service catalog.
+        :return: The modified endpoint
+        """
+
+        if isinstance(endpoint, six.string_types):
+            return endpoint.rstrip('/')
+        else:
+            return endpoint
 
     def _request(self, method, url, session=None, **kwargs):
         """Perform call into session
@@ -99,6 +121,9 @@ class BaseAPI(object):
             if url:
                 url = '/'.join([self.endpoint.rstrip('/'), url.lstrip('/')])
             else:
+                # NOTE(dtroyer): This is left here after _munge_endpoint() is
+                #                added because endpoint is public and there is
+                #                no accounting for what may happen.
                 url = self.endpoint.rstrip('/')
         else:
             # Pass on the lack of URL unmolested to maintain the same error
diff --git a/osc_lib/tests/api/test_api.py b/osc_lib/tests/api/test_api.py
index 585e635..24b879a 100644
--- a/osc_lib/tests/api/test_api.py
+++ b/osc_lib/tests/api/test_api.py
@@ -13,7 +13,7 @@
 
 """Base API Library Tests"""
 
-from keystoneauth1 import exceptions as ks_exceptions
+from keystoneauth1 import exceptions as ksa_exceptions
 from keystoneauth1 import session
 
 from osc_lib.api import api
@@ -27,6 +27,22 @@ class TestBaseAPIDefault(api_fakes.TestSession):
         super(TestBaseAPIDefault, self).setUp()
         self.api = api.BaseAPI()
 
+    def test_baseapi_request_no_url(self):
+        self.requests_mock.register_uri(
+            'GET',
+            self.BASE_URL + '/qaz',
+            json=api_fakes.RESP_ITEM_1,
+            status_code=200,
+        )
+        self.assertRaises(
+            ksa_exceptions.EndpointNotFound,
+            self.api._request,
+            'GET',
+            '',
+        )
+        self.assertIsNotNone(self.api.session)
+        self.assertNotEqual(self.sess, self.api.session)
+
     def test_baseapi_request_url(self):
         self.requests_mock.register_uri(
             'GET',
@@ -47,7 +63,7 @@ class TestBaseAPIDefault(api_fakes.TestSession):
             status_code=200,
         )
         self.assertRaises(
-            ks_exceptions.EndpointNotFound,
+            ksa_exceptions.EndpointNotFound,
             self.api._request,
             'GET',
             '/qaz',
@@ -72,6 +88,104 @@ class TestBaseAPIDefault(api_fakes.TestSession):
         self.assertNotEqual(self.sess, self.api.session)
 
 
+class TestBaseAPIEndpointArg(api_fakes.TestSession):
+
+    def test_baseapi_endpoint_no_endpoint(self):
+        x_api = api.BaseAPI(
+            session=self.sess,
+        )
+        self.assertIsNotNone(x_api.session)
+        self.assertEqual(self.sess, x_api.session)
+        self.assertIsNone(x_api.endpoint)
+
+        self.requests_mock.register_uri(
+            'GET',
+            self.BASE_URL + '/qaz',
+            json=api_fakes.RESP_ITEM_1,
+            status_code=200,
+        )
+
+        # Normal url
+        self.assertRaises(
+            ksa_exceptions.EndpointNotFound,
+            x_api._request,
+            'GET',
+            '/qaz',
+        )
+
+        # No leading '/' url
+        self.assertRaises(
+            ksa_exceptions.EndpointNotFound,
+            x_api._request,
+            'GET',
+            'qaz',
+        )
+
+        # Extra leading '/' url
+        self.assertRaises(
+            ksa_exceptions.connection.UnknownConnectionError,
+            x_api._request,
+            'GET',
+            '//qaz',
+        )
+
+    def test_baseapi_endpoint_no_extra(self):
+        x_api = api.BaseAPI(
+            session=self.sess,
+            endpoint=self.BASE_URL,
+        )
+        self.assertIsNotNone(x_api.session)
+        self.assertEqual(self.sess, x_api.session)
+        self.assertEqual(self.BASE_URL, x_api.endpoint)
+
+        self.requests_mock.register_uri(
+            'GET',
+            self.BASE_URL + '/qaz',
+            json=api_fakes.RESP_ITEM_1,
+            status_code=200,
+        )
+
+        # Normal url
+        ret = x_api._request('GET', '/qaz')
+        self.assertEqual(api_fakes.RESP_ITEM_1, ret.json())
+
+        # No leading '/' url
+        ret = x_api._request('GET', 'qaz')
+        self.assertEqual(api_fakes.RESP_ITEM_1, ret.json())
+
+        # Extra leading '/' url
+        ret = x_api._request('GET', '//qaz')
+        self.assertEqual(api_fakes.RESP_ITEM_1, ret.json())
+
+    def test_baseapi_endpoint_extra(self):
+        x_api = api.BaseAPI(
+            session=self.sess,
+            endpoint=self.BASE_URL + '/',
+        )
+        self.assertIsNotNone(x_api.session)
+        self.assertEqual(self.sess, x_api.session)
+        self.assertEqual(self.BASE_URL, x_api.endpoint)
+
+        self.requests_mock.register_uri(
+            'GET',
+            self.BASE_URL + '/qaz',
+            json=api_fakes.RESP_ITEM_1,
+            status_code=200,
+        )
+
+        # Normal url
+        ret = x_api._request('GET', '/qaz')
+        self.assertEqual(api_fakes.RESP_ITEM_1, ret.json())
+
+        # No leading '/' url
+        ret = x_api._request('GET', 'qaz')
+        self.assertEqual(api_fakes.RESP_ITEM_1, ret.json())
+
+        # Extra leading '/' url
+        ret = x_api._request('GET', '//qaz')
+        self.assertEqual(api_fakes.RESP_ITEM_1, ret.json())
+
+
 class TestBaseAPIArgs(api_fakes.TestSession):
 
     def setUp(self):