Rework authentication
Saharaclient's authentication is relatively simple. It resolves the sahara URL and the token prior to making any requests. Provide the means to use the features of the adapter directly and adopt those features. Change-Id: I7d0f8435b836d187a0396657b485d9d080a59ba1
This commit is contained in:
@@ -86,8 +86,11 @@ class ResourceManager(object):
|
||||
|
||||
def _create(self, url, data, response_key=None, dump_json=True):
|
||||
if dump_json:
|
||||
data = json.dumps(data)
|
||||
resp = self.api.post(url, data, json=dump_json)
|
||||
kwargs = {'json': data}
|
||||
else:
|
||||
kwargs = {'data': data}
|
||||
|
||||
resp = self.api.post(url, **kwargs)
|
||||
|
||||
if resp.status_code != 202:
|
||||
self._raise_api_exception(resp)
|
||||
@@ -100,8 +103,11 @@ class ResourceManager(object):
|
||||
|
||||
def _update(self, url, data, response_key=None, dump_json=True):
|
||||
if dump_json:
|
||||
data = json.dumps(data)
|
||||
resp = self.api.put(url, data, json=dump_json)
|
||||
kwargs = {'json': data}
|
||||
else:
|
||||
kwargs = {'data': data}
|
||||
|
||||
resp = self.api.put(url, **kwargs)
|
||||
|
||||
if resp.status_code != 202:
|
||||
self._raise_api_exception(resp)
|
||||
|
||||
@@ -13,15 +13,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
|
||||
from keystoneclient import adapter
|
||||
from keystoneclient.openstack.common.apiclient import exceptions as kex
|
||||
from keystoneclient.auth.identity import v2
|
||||
from keystoneclient.auth.identity import v3
|
||||
from keystoneclient.auth import token_endpoint
|
||||
from keystoneclient import exceptions
|
||||
from keystoneclient import session as keystone_session
|
||||
from keystoneclient.v2_0 import client as keystone_client_v2
|
||||
from keystoneclient.v3 import client as keystone_client_v3
|
||||
|
||||
from saharaclient.api import cluster_templates
|
||||
from saharaclient.api import clusters
|
||||
from saharaclient.api import data_sources
|
||||
from saharaclient.api import httpclient
|
||||
from saharaclient.api import images
|
||||
from saharaclient.api import job_binaries
|
||||
from saharaclient.api import job_binary_internals
|
||||
@@ -31,72 +35,55 @@ from saharaclient.api import node_group_templates
|
||||
from saharaclient.api import plugins
|
||||
|
||||
|
||||
class HTTPClient(adapter.Adapter):
|
||||
|
||||
def request(self, *args, **kwargs):
|
||||
kwargs.setdefault('raise_exc', False)
|
||||
return super(HTTPClient, self).request(*args, **kwargs)
|
||||
|
||||
|
||||
class Client(object):
|
||||
def __init__(self, username=None, api_key=None, project_id=None,
|
||||
project_name=None, auth_url=None, sahara_url=None,
|
||||
endpoint_type='publicURL', service_type='data-processing',
|
||||
service_name=None, region_name=None,
|
||||
input_auth_token=None, session=None, auth=None,
|
||||
insecure=False, cacert=None):
|
||||
insecure=False, cacert=None, **kwargs):
|
||||
|
||||
keystone = None
|
||||
sahara_catalog_url = sahara_url
|
||||
if not session:
|
||||
warnings.warn('Passing authentication parameters to saharaclient '
|
||||
'is deprecated. Please construct and pass an '
|
||||
'authenticated session object directly.',
|
||||
DeprecationWarning)
|
||||
|
||||
if not input_auth_token:
|
||||
if input_auth_token:
|
||||
auth = token_endpoint.Token(sahara_url, input_auth_token)
|
||||
|
||||
if session:
|
||||
keystone = adapter.LegacyJsonAdapter(
|
||||
session=session,
|
||||
auth=auth,
|
||||
interface=endpoint_type,
|
||||
service_type=service_type,
|
||||
service_name=service_name,
|
||||
region_name=region_name)
|
||||
input_auth_token = keystone.session.get_token(auth)
|
||||
if not sahara_catalog_url:
|
||||
try:
|
||||
sahara_catalog_url = keystone.session.get_endpoint(
|
||||
auth, interface=endpoint_type,
|
||||
service_type=service_type)
|
||||
except kex.EndpointNotFound:
|
||||
# This is support of 'data_processing' service spelling
|
||||
# which was used for releases before Kilo
|
||||
service_type = service_type.replace('-', '_')
|
||||
sahara_catalog_url = keystone.session.get_endpoint(
|
||||
auth, interface=endpoint_type,
|
||||
service_type=service_type)
|
||||
else:
|
||||
keystone = self.get_keystone_client(
|
||||
username=username,
|
||||
api_key=api_key,
|
||||
auth_url=auth_url,
|
||||
project_id=project_id,
|
||||
project_name=project_name)
|
||||
input_auth_token = keystone.auth_token
|
||||
auth = self._get_keystone_auth(auth_url=auth_url,
|
||||
username=username,
|
||||
api_key=api_key,
|
||||
project_id=project_id,
|
||||
project_name=project_name)
|
||||
|
||||
if not input_auth_token:
|
||||
raise RuntimeError("Not Authorized")
|
||||
verify = True
|
||||
if insecure:
|
||||
verify = False
|
||||
elif cacert:
|
||||
verify = cacert
|
||||
|
||||
if not sahara_catalog_url:
|
||||
catalog = keystone.service_catalog.get_endpoints(service_type)
|
||||
if service_type not in catalog:
|
||||
# This is support of 'data_processing' service spelling
|
||||
# which was used for releases before Kilo
|
||||
service_type = service_type.replace('-', '_')
|
||||
catalog = keystone.service_catalog.get_endpoints(service_type)
|
||||
session = keystone_session.Session(verify=verify)
|
||||
service_type = self._determine_service_type(session,
|
||||
auth,
|
||||
service_type,
|
||||
endpoint_type)
|
||||
|
||||
if service_type in catalog:
|
||||
for e_type, endpoint in catalog.get(service_type)[0].items():
|
||||
if str(e_type).lower() == str(endpoint_type).lower():
|
||||
sahara_catalog_url = endpoint
|
||||
break
|
||||
if not sahara_catalog_url:
|
||||
raise RuntimeError("Could not find Sahara endpoint in catalog")
|
||||
kwargs.setdefault('interface', endpoint_type)
|
||||
kwargs.setdefault('endpoint_override', sahara_url)
|
||||
|
||||
client = httpclient.HTTPClient(sahara_catalog_url,
|
||||
input_auth_token,
|
||||
insecure=insecure,
|
||||
cacert=cacert)
|
||||
client = HTTPClient(session=session,
|
||||
auth=auth,
|
||||
service_type=service_type,
|
||||
**kwargs)
|
||||
|
||||
self.clusters = clusters.ClusterManager(client)
|
||||
self.cluster_templates = (
|
||||
@@ -116,27 +103,52 @@ class Client(object):
|
||||
job_binary_internals.JobBinaryInternalsManager(client)
|
||||
)
|
||||
|
||||
def get_keystone_client(self, username=None, api_key=None, auth_url=None,
|
||||
token=None, project_id=None, project_name=None):
|
||||
def _get_keystone_auth(self, username=None, api_key=None, auth_url=None,
|
||||
project_id=None, project_name=None):
|
||||
if not auth_url:
|
||||
raise RuntimeError("No auth url specified")
|
||||
|
||||
if not getattr(self, "keystone_client", None):
|
||||
imported_client = (keystone_client_v2 if "v2.0" in auth_url
|
||||
else keystone_client_v3)
|
||||
if 'v2.0' in auth_url:
|
||||
return v2.Password(auth_url=auth_url,
|
||||
username=username,
|
||||
password=api_key,
|
||||
tenant_id=project_id,
|
||||
tenant_name=project_name)
|
||||
else:
|
||||
# NOTE(jamielennox): Setting these to default is what
|
||||
# keystoneclient does in the event they are not passed.
|
||||
return v3.Password(auth_url=auth_url,
|
||||
username=username,
|
||||
password=api_key,
|
||||
user_domain_id='default',
|
||||
project_id=project_id,
|
||||
project_name=project_name,
|
||||
project_domain_id='default')
|
||||
|
||||
self.keystone_client = imported_client.Client(
|
||||
username=username,
|
||||
password=api_key,
|
||||
token=token,
|
||||
tenant_id=project_id,
|
||||
tenant_name=project_name,
|
||||
auth_url=auth_url,
|
||||
endpoint=auth_url)
|
||||
@staticmethod
|
||||
def _determine_service_type(session, auth, service_type, interface):
|
||||
"""Check a catalog for data-processing or data_processing"""
|
||||
|
||||
self.keystone_client.authenticate()
|
||||
# NOTE(jamielennox): calling get_endpoint forces an auth on
|
||||
# initialization which is required for backwards compatibility. It
|
||||
# also allows us to reset the service type if not in the catalog.
|
||||
for st in (service_type, service_type.replace('-', '_')):
|
||||
try:
|
||||
url = auth.get_endpoint(session,
|
||||
service_type=st,
|
||||
interface=interface)
|
||||
except exceptions.Unauthorized:
|
||||
raise RuntimeError("Not Authorized")
|
||||
except exceptions.EndpointNotFound:
|
||||
# NOTE(jamielennox): bug #1428447. This should not be
|
||||
# raised, instead None should be returned. Handle in case
|
||||
# it changes in the future
|
||||
url = None
|
||||
|
||||
return self.keystone_client
|
||||
if url:
|
||||
return st
|
||||
|
||||
raise RuntimeError("Could not find Sahara endpoint in catalog")
|
||||
|
||||
@staticmethod
|
||||
def get_projects_list(keystone_client):
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
# Copyright (c) 2013 Mirantis Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class HTTPClient(object):
|
||||
def __init__(self, base_url, token, insecure=False, cacert=None):
|
||||
self.base_url = base_url
|
||||
self.token = token
|
||||
|
||||
if insecure:
|
||||
self.verify_cert = False
|
||||
else:
|
||||
if cacert:
|
||||
self.verify_cert = cacert
|
||||
else:
|
||||
self.verify_cert = True
|
||||
|
||||
def get(self, url):
|
||||
return requests.get(self.base_url + url,
|
||||
headers={'x-auth-token': self.token},
|
||||
verify=self.verify_cert)
|
||||
|
||||
def post(self, url, body, json=True):
|
||||
headers = {'x-auth-token': self.token}
|
||||
if json:
|
||||
headers['content-type'] = 'application/json'
|
||||
return requests.post(self.base_url + url, body, headers=headers,
|
||||
verify=self.verify_cert)
|
||||
|
||||
def put(self, url, body, json=True):
|
||||
headers = {'x-auth-token': self.token}
|
||||
if json:
|
||||
headers['content-type'] = 'application/json'
|
||||
return requests.put(self.base_url + url, body, headers=headers,
|
||||
verify=self.verify_cert)
|
||||
|
||||
def delete(self, url):
|
||||
return requests.delete(self.base_url + url,
|
||||
headers={'x-auth-token': self.token},
|
||||
verify=self.verify_cert)
|
||||
@@ -13,8 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
|
||||
from saharaclient.api import base
|
||||
|
||||
|
||||
@@ -40,7 +38,7 @@ class ImageManager(base.ResourceManager):
|
||||
body = {"username": user_name,
|
||||
"description": desc}
|
||||
|
||||
resp = self.api.post('/images/%s' % image_id, json.dumps(body))
|
||||
resp = self.api.post('/images/%s' % image_id, json=body)
|
||||
if resp.status_code != 202:
|
||||
raise RuntimeError('Failed to register image %s' % image_id)
|
||||
|
||||
@@ -55,14 +53,14 @@ class ImageManager(base.ResourceManager):
|
||||
|
||||
if len(to_add) != 0:
|
||||
resp = self.api.post('/images/%s/tag' % image_id,
|
||||
json.dumps({'tags': to_add}))
|
||||
json={'tags': to_add})
|
||||
|
||||
if resp.status_code != 202:
|
||||
raise RuntimeError('Failed to add tags to image %s' % image_id)
|
||||
|
||||
if len(to_remove) != 0:
|
||||
resp = self.api.post('/images/%s/untag' % image_id,
|
||||
json.dumps({'tags': to_remove}))
|
||||
json={'tags': to_remove})
|
||||
|
||||
if resp.status_code != 202:
|
||||
raise RuntimeError('Failed to remove tags from image %s' %
|
||||
|
||||
@@ -48,7 +48,7 @@ class PluginManager(base.ResourceManager):
|
||||
(plugin_name,
|
||||
hadoop_version,
|
||||
urlparse.quote(template_name)),
|
||||
filecontent)
|
||||
data=filecontent)
|
||||
if resp.status_code != 202:
|
||||
raise RuntimeError('Failed to upload template file for plugin "%s"'
|
||||
' and version "%s"' %
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
# Copyright (c) 2013 Mirantis Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from requests_mock.contrib import fixture
|
||||
import testtools
|
||||
|
||||
from saharaclient.api import httpclient
|
||||
|
||||
|
||||
class ResourceTest(testtools.TestCase):
|
||||
|
||||
URL = 'http://localhost'
|
||||
TOKEN = 'token'
|
||||
|
||||
def setUp(self):
|
||||
super(ResourceTest, self).setUp()
|
||||
self.responses = self.useFixture(fixture.Fixture())
|
||||
|
||||
def test_post_json_content_type(self):
|
||||
m = self.responses.post(self.URL + '/test')
|
||||
|
||||
client = httpclient.HTTPClient(self.URL, self.TOKEN)
|
||||
client.post('/test', '{"json":"True"}')
|
||||
|
||||
self.assertEqual(1, m.call_count)
|
||||
self.assertEqual('application/json',
|
||||
m.last_request.headers['content-type'])
|
||||
|
||||
def test_put_json_content_type(self):
|
||||
m = self.responses.put(self.URL + '/test')
|
||||
|
||||
client = httpclient.HTTPClient(self.URL, self.TOKEN)
|
||||
client.put('/test', '{"json":"True"}')
|
||||
|
||||
self.assertEqual(1, m.call_count)
|
||||
self.assertEqual('application/json',
|
||||
m.last_request.headers['content-type'])
|
||||
|
||||
def test_post_nonjson_content_type(self):
|
||||
m = self.responses.post(self.URL + '/test')
|
||||
|
||||
client = httpclient.HTTPClient(self.URL, self.TOKEN)
|
||||
client.post('/test', 'nonjson', json=False)
|
||||
|
||||
self.assertEqual(1, m.call_count)
|
||||
self.assertNotIn("content-type", m.last_request.headers)
|
||||
|
||||
def test_put_nonjson_content_type(self):
|
||||
m = self.responses.put(self.URL + '/test')
|
||||
|
||||
client = httpclient.HTTPClient(self.URL, self.TOKEN)
|
||||
client.put('/test', 'nonjson', json=False)
|
||||
|
||||
self.assertEqual(1, m.call_count)
|
||||
self.assertNotIn("content-type", m.last_request.headers)
|
||||
Reference in New Issue
Block a user