Move TestClient to test_versions

TestClient is only used in test_versions, so move it there to keep
tests.unit.core more focused.

Change-Id: Ic8f0b6426d3332d940e26cd371c349495971c056
This commit is contained in:
Brant Knudson
2015-08-29 11:10:28 -05:00
parent 48ad4b7663
commit ffedc364e1
2 changed files with 48 additions and 48 deletions

View File

@@ -37,7 +37,6 @@ from paste.deploy import loadwsgi
import six
from sqlalchemy import exc
from testtools import testcase
import webob
# NOTE(ayoung)
# environment.use_eventlet must run before any of the code that will
@@ -224,36 +223,6 @@ class UnexpectedExit(Exception):
pass
class TestClient(object):
def __init__(self, app=None, token=None):
self.app = app
self.token = token
def request(self, method, path, headers=None, body=None):
if headers is None:
headers = {}
if self.token:
headers.setdefault('X-Auth-Token', self.token)
req = webob.Request.blank(path)
req.method = method
for k, v in headers.items():
req.headers[k] = v
if body:
req.body = body
return req.get_response(self.app)
def get(self, path, headers=None):
return self.request('GET', path=path, headers=headers)
def post(self, path, headers=None, body=None):
return self.request('POST', path=path, headers=headers, body=body)
def put(self, path, headers=None, body=None):
return self.request('PUT', path=path, headers=headers, body=body)
def new_ref():
"""Populates a ref with attributes common to some API entities."""
return {

View File

@@ -21,6 +21,7 @@ import mock
from oslo_config import cfg
from oslo_serialization import jsonutils
from testtools import matchers as tt_matchers
import webob
from keystone.common import json_home
from keystone import controllers
@@ -621,6 +622,36 @@ V3_JSON_HOME_RESOURCES_INHERIT_ENABLED.update(
)
class TestClient(object):
def __init__(self, app=None, token=None):
self.app = app
self.token = token
def request(self, method, path, headers=None, body=None):
if headers is None:
headers = {}
if self.token:
headers.setdefault('X-Auth-Token', self.token)
req = webob.Request.blank(path)
req.method = method
for k, v in headers.items():
req.headers[k] = v
if body:
req.body = body
return req.get_response(self.app)
def get(self, path, headers=None):
return self.request('GET', path=path, headers=headers)
def post(self, path, headers=None, body=None):
return self.request('POST', path=path, headers=headers, body=body)
def put(self, path, headers=None, body=None):
return self.request('PUT', path=path, headers=headers, body=body)
class _VersionsEqual(tt_matchers.MatchesListwise):
def __init__(self, expected):
super(_VersionsEqual, self).__init__([
@@ -664,7 +695,7 @@ class VersionTestCase(unit.TestCase):
link['href'] = port
def test_public_versions(self):
client = unit.TestClient(self.public_app)
client = TestClient(self.public_app)
resp = client.get('/')
self.assertEqual(300, resp.status_int)
data = jsonutils.loads(resp.body)
@@ -681,7 +712,7 @@ class VersionTestCase(unit.TestCase):
self.assertThat(data, _VersionsEqual(expected))
def test_admin_versions(self):
client = unit.TestClient(self.admin_app)
client = TestClient(self.admin_app)
resp = client.get('/')
self.assertEqual(300, resp.status_int)
data = jsonutils.loads(resp.body)
@@ -701,7 +732,7 @@ class VersionTestCase(unit.TestCase):
self.config_fixture.config(public_endpoint=None, admin_endpoint=None)
for app in (self.public_app, self.admin_app):
client = unit.TestClient(app)
client = TestClient(app)
resp = client.get('/')
self.assertEqual(300, resp.status_int)
data = jsonutils.loads(resp.body)
@@ -717,7 +748,7 @@ class VersionTestCase(unit.TestCase):
self.assertThat(data, _VersionsEqual(expected))
def test_public_version_v2(self):
client = unit.TestClient(self.public_app)
client = TestClient(self.public_app)
resp = client.get('/v2.0/')
self.assertEqual(200, resp.status_int)
data = jsonutils.loads(resp.body)
@@ -728,7 +759,7 @@ class VersionTestCase(unit.TestCase):
self.assertEqual(expected, data)
def test_admin_version_v2(self):
client = unit.TestClient(self.admin_app)
client = TestClient(self.admin_app)
resp = client.get('/v2.0/')
self.assertEqual(200, resp.status_int)
data = jsonutils.loads(resp.body)
@@ -741,7 +772,7 @@ class VersionTestCase(unit.TestCase):
def test_use_site_url_if_endpoint_unset_v2(self):
self.config_fixture.config(public_endpoint=None, admin_endpoint=None)
for app in (self.public_app, self.admin_app):
client = unit.TestClient(app)
client = TestClient(app)
resp = client.get('/v2.0/')
self.assertEqual(200, resp.status_int)
data = jsonutils.loads(resp.body)
@@ -750,7 +781,7 @@ class VersionTestCase(unit.TestCase):
self.assertEqual(data, expected)
def test_public_version_v3(self):
client = unit.TestClient(self.public_app)
client = TestClient(self.public_app)
resp = client.get('/v3/')
self.assertEqual(200, resp.status_int)
data = jsonutils.loads(resp.body)
@@ -762,7 +793,7 @@ class VersionTestCase(unit.TestCase):
@utils.wip('waiting on bug #1381961')
def test_admin_version_v3(self):
client = unit.TestClient(self.admin_app)
client = TestClient(self.admin_app)
resp = client.get('/v3/')
self.assertEqual(200, resp.status_int)
data = jsonutils.loads(resp.body)
@@ -775,7 +806,7 @@ class VersionTestCase(unit.TestCase):
def test_use_site_url_if_endpoint_unset_v3(self):
self.config_fixture.config(public_endpoint=None, admin_endpoint=None)
for app in (self.public_app, self.admin_app):
client = unit.TestClient(app)
client = TestClient(app)
resp = client.get('/v3/')
self.assertEqual(200, resp.status_int)
data = jsonutils.loads(resp.body)
@@ -785,7 +816,7 @@ class VersionTestCase(unit.TestCase):
@mock.patch.object(controllers, '_VERSIONS', ['v3'])
def test_v2_disabled(self):
client = unit.TestClient(self.public_app)
client = TestClient(self.public_app)
# request to /v2.0 should fail
resp = client.get('/v2.0/')
self.assertEqual(404, resp.status_int)
@@ -818,7 +849,7 @@ class VersionTestCase(unit.TestCase):
@mock.patch.object(controllers, '_VERSIONS', ['v2.0'])
def test_v3_disabled(self):
client = unit.TestClient(self.public_app)
client = TestClient(self.public_app)
# request to /v3 should fail
resp = client.get('/v3/')
self.assertEqual(404, resp.status_int)
@@ -850,7 +881,7 @@ class VersionTestCase(unit.TestCase):
self.assertEqual(v2_only_response, data)
def _test_json_home(self, path, exp_json_home_data):
client = unit.TestClient(self.public_app)
client = TestClient(self.public_app)
resp = client.get(path, headers={'Accept': 'application/json-home'})
self.assertThat(resp.status, tt_matchers.Equals('200 OK'))
@@ -883,7 +914,7 @@ class VersionTestCase(unit.TestCase):
# Accept headers with multiple types and qvalues are handled.
def make_request(accept_types=None):
client = unit.TestClient(self.public_app)
client = TestClient(self.public_app)
headers = None
if accept_types:
headers = {'Accept': accept_types}
@@ -969,7 +1000,7 @@ class VersionSingleAppTestCase(unit.TestCase):
else:
return CONF.eventlet_server.public_port
app = self.loadapp('keystone', app_name)
client = unit.TestClient(app)
client = TestClient(app)
resp = client.get('/')
self.assertEqual(300, resp.status_int)
data = jsonutils.loads(resp.body)
@@ -1015,7 +1046,7 @@ class VersionInheritEnabledTestCase(unit.TestCase):
# If the request is /v3 and the Accept header is application/json-home
# then the server responds with a JSON Home document.
client = unit.TestClient(self.public_app)
client = TestClient(self.public_app)
resp = client.get('/v3/', headers={'Accept': 'application/json-home'})
self.assertThat(resp.status, tt_matchers.Equals('200 OK'))
@@ -1055,7 +1086,7 @@ class VersionBehindSslTestCase(unit.TestCase):
return expected
def test_versions_without_headers(self):
client = unit.TestClient(self.public_app)
client = TestClient(self.public_app)
host_name = 'host-%d' % random.randint(10, 30)
host_port = random.randint(10000, 30000)
host = 'http://%s:%s/' % (host_name, host_port)
@@ -1066,7 +1097,7 @@ class VersionBehindSslTestCase(unit.TestCase):
self.assertThat(data, _VersionsEqual(expected))
def test_versions_with_header(self):
client = unit.TestClient(self.public_app)
client = TestClient(self.public_app)
host_name = 'host-%d' % random.randint(10, 30)
host_port = random.randint(10000, 30000)
resp = client.get('http://%s:%s/' % (host_name, host_port),