mixmatch/mixmatch/proxy.py

412 lines
15 KiB
Python

# Copyright 2016 Massachusetts Open Cloud
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import collections
import six
import requests
from urllib3.util import retry
import flask
from flask import abort
import functools
from mixmatch import config
from mixmatch.config import LOG, CONF, service_providers
from mixmatch.session import app
from mixmatch.session import chunked_reader
from mixmatch.session import request
from mixmatch import auth
from mixmatch import extend
from mixmatch import model
from mixmatch import services
from mixmatch import utils
METHODS_ACCEPTED = ['GET', 'POST', 'PUT', 'DELETE', 'HEAD', 'PATCH']
RESOURCES_AGGREGATE = ['images', 'volumes', 'snapshots']
def stream_response(response):
yield response.raw.read()
def get_service(a):
"""Determine service type based on path."""
service = a.pop(0)
if service in ['image', 'volume', 'network']:
return service
else:
abort(404)
def is_json_response(response):
return response.headers.get('Content-Type') == 'application/json'
def is_token_header_key(string):
return string in ['X-AUTH-TOKEN', 'X-SERVICE-TOKEN']
def strip_tokens_from_headers(headers):
return {k: '<token omitted>' if is_token_header_key(k)
else headers[k] for k in headers}
def format_for_log(title=None, method=None, url=None, headers=None,
status_code=None):
return ''.join([
'{}: '.format(title) if title else '',
'Method: {}, '.format(method) if method else '',
'URL: {}, '.format(url) if url else '',
'Headers: {}, '.format(
strip_tokens_from_headers(headers)) if headers else '',
'Status Code: {}\n'.format(status_code) if status_code else ''
])
class RequestDetails(object):
def __init__(self, method, orig_path, headers):
local_path = orig_path.split('/')
# NOTE(knikolla): Request usually looks like this:
# /<service>/<version>/<project_id:uuid>/<res_type>/<res_id:uuid>
# or
# /<service>/<version>/<res_type>/<specific action>
self.method = method
self.service = get_service(local_path)
self.version = utils.safe_pop(local_path)
self.project_id = utils.pop_if_uuid(local_path)
self.project_in_path = True if self.project_id else False
self.action = local_path[:] # NOTE(knikolla): This includes
self.resource_type = utils.safe_pop(local_path) # this
self.resource_id = utils.pop_if_uuid(local_path) # and this
self.token = headers.get('X-AUTH-TOKEN', None)
self.headers = {k.upper(): v for k, v in dict(headers).items()}
self.path = orig_path
self.args = dict(request.args)
# NOTE(jfreud): if chunked transfer, body must be accessed through
# utilities found in mixmatch.session
self.body = request.data
self.environ = request.environ
class RequestHandler(object):
def __init__(self, method, path, headers):
self.details = RequestDetails(method, path, headers)
self.extensions = extend.get_matched_extensions(self.details)
self._set_strip_details(self.details)
self.enabled_sps = filter(
lambda sp: (self.details.service in
service_providers.get(CONF, sp).enabled_services),
CONF.service_providers
)
if 'MM-PROXY-LIST' in self.details.headers:
proxy_csv = self.details.headers.get('MM-PROXY-LIST', None)
if proxy_csv:
proxy_list = proxy_csv.split(',')
self.enabled_sps = filter(
lambda sp: (sp not in proxy_list), self.enabled_sps)
self.append_proxy(self.details.headers)
# TODO(jfreud): more sophisticated/ordered invocation of extensions
for extension in self.extensions:
out = extension.handle_request(self.details)
if out is not None:
self._forward = functools.partial(self._finalize, out)
return
if not self.details.version:
if CONF.aggregation:
# unversioned calls with no action
self._forward = self._list_api_versions
else:
self._forward = self._local_forward
return
if not self.details.resource_type:
# versioned calls with no action
abort(400)
mapping = None
if self.details.resource_id:
mapping = model.ResourceMapping.find(
resource_type=self.details.action[0],
resource_id=self.details.resource_id)
LOG.debug('Local Token: %s ' % self.details.token)
if 'MM-SERVICE-PROVIDER' in self.details.headers:
# The user wants a specific service provider, use that SP.
# FIXME(knikolla): This isn't exercised by any unit test
(self.service_provider, self.project_id) = (
self.details.headers['MM-SERVICE-PROVIDER'],
self.details.headers.get('MM-PROJECT-ID', None)
)
if self.service_provider not in self.enabled_sps:
abort(400)
if not self.project_id and self.service_provider != 'default':
self.project_id = auth.get_projects_at_sp(
self.service_provider,
self.details.token
)[0]
self._forward = self._targeted_forward
elif mapping:
# Which we already know the location of, use that SP.
self.service_provider = mapping.resource_sp
self.project_id = mapping.project_id
self._forward = self._targeted_forward
else:
self._forward = self._forward
LOG.info(format_for_log(title="Request to proxy",
method=self.details.method,
url=self.details.path,
headers=dict(self.details.headers)))
def _do_request_on(self, sp, project_id=None):
headers = self._prepare_headers(self.details.headers)
if self.details.token:
if sp == 'default':
auth_session = auth.get_local_auth(self.details.token)
else:
auth_session = auth.get_sp_auth(sp,
self.details.token,
project_id)
headers['X-AUTH-TOKEN'] = auth_session.get_token()
project_id = auth_session.get_project_id()
else:
project_id = None
url = services.construct_url(
sp,
self.details.service,
self.details.version,
self.details.action,
project_id=project_id if self.details.project_in_path else None
)
request_kwargs = {
'method': self.details.method,
'url': url,
'headers': headers,
'params': self._prepare_args(self.details.args)
}
if self.chunked:
resp = self.session.request(data=chunked_reader(),
**request_kwargs)
else:
resp = self.session.request(data=self.details.body,
stream=self.stream,
**request_kwargs)
LOG.info(format_for_log(title='Request from proxy',
method=self.details.method,
url=url,
headers=headers))
LOG.info(format_for_log(title='Response to proxy',
status_code=resp.status_code,
headers=resp.headers))
return resp
def _finalize(self, response):
if self.stream and not is_json_response(response):
text = flask.stream_with_context(
stream_response(response))
else:
text = response.text
final_response = flask.Response(
text,
response.status_code,
headers=self._prepare_headers(response.headers, fix_case=True)
)
LOG.info(format_for_log(title='Response from proxy',
status_code=final_response.status_code,
headers=dict(final_response.headers)))
return final_response
def _local_forward(self):
return self._finalize(self._do_request_on('default'))
def _targeted_forward(self):
return self._finalize(
self._do_request_on(self.service_provider, self.project_id))
def _forward(self):
if self.fallback_to_local:
return self._local_forward()
responses = {}
errors = collections.defaultdict(lambda: [])
for sp in self.enabled_sps:
if sp == 'default':
r = self._do_request_on('default')
if 200 <= r.status_code < 300:
responses['default'] = r
if not self.aggregate:
return self._finalize(r)
else:
errors[r.status_code].append(r)
else:
for p in auth.get_projects_at_sp(sp, self.details.token):
r = self._do_request_on(sp, p)
if 200 <= r.status_code < 300:
responses[(sp, p)] = r
if not self.aggregate:
return self._finalize(r)
else:
errors[r.status_code].append(r)
# NOTE(knikolla): If we haven't returned yet, either we're aggregating
# or there are errors.
if not errors:
# TODO(knikolla): Plug this into _finalize to have a common path
# for everything that is returned.
return flask.Response(
services.aggregate(responses,
self.details.action[0],
self.details.service,
version=self.details.version,
params=self.details.args,
path=request.base_url,
strip_details=self.strip_details),
200,
content_type='application/json'
)
if six.viewkeys(errors) == {404}:
return self._finalize(errors[404][0])
else:
utils.safe_pop(errors, 404)
if len(errors.keys()) == 1:
return self._finalize(list(errors.values())[0][0])
# TODO(jfreud): log
return flask.Response("Something strange happened.\n", 500)
def _list_api_versions(self):
return services.list_api_versions(self.details.service,
request.base_url)
def forward(self):
return self._forward()
@staticmethod
def _prepare_headers(user_headers, fix_case=False):
# NOTE(jfreud): because this function may be called with either request
# headers or response headers, sometimes the header keys may not be
# already capitalized
if fix_case:
user_headers = {k.upper(): v for k, v in
dict(user_headers).items()}
headers = dict()
headers['ACCEPT'] = user_headers.get('ACCEPT', '')
headers['CONTENT-TYPE'] = user_headers.get('CONTENT-TYPE', '')
accepted_headers = ['OPENSTACK-API-VERSION', 'MM-PROXY-LIST']
for key, value in user_headers.items():
if ((key.startswith('X-') and not is_token_header_key(key)) or
key in accepted_headers):
headers[key] = value
return headers
@staticmethod
def _prepare_args(args):
"""Prepare the GET arguments by removing the limit and marker.
This is because the id of the marker will only be present in one of
the service providers.
"""
if CONF.aggregation:
args.pop('limit', None)
args.pop('marker', None)
return args
@staticmethod
def append_proxy(headers):
proxy_list = headers.get('MM-PROXY-LIST', None)
default_sp_name = service_providers.get(CONF, 'default').sp_name
if proxy_list:
proxy_list = proxy_list + ',' + default_sp_name
else:
proxy_list = default_sp_name
headers['MM-PROXY-LIST'] = proxy_list
@utils.CachedProperty
def chunked(self):
encoding = self.details.headers.get('TRANSFER-ENCODING', '')
return encoding.lower() == 'chunked'
@utils.CachedProperty
def stream(self):
return self.details.method == 'GET'
@utils.CachedProperty
def fallback_to_local(self):
"""Return true if can't aggregate or search."""
return ((self.aggregate and not CONF.aggregation) or
(not self.aggregate and not CONF.search_by_broadcast))
@utils.CachedProperty
def aggregate(self):
"""Return true if this is a case where we should aggregate."""
return (not self.details.resource_id and
self.details.method == 'GET' and
self.details.action[0] in RESOURCES_AGGREGATE)
@utils.CachedProperty
def session(self):
requests_session = requests.Session()
adapter = requests.adapters.HTTPAdapter(
max_retries=retry.Retry(total=3,
read=3,
connect=3,
backoff_factor=0.3,
status_forcelist=[500, 502, 504])
)
requests_session.mount('http://', adapter=adapter)
requests_session.mount('https://', adapter=adapter)
return requests_session
def _set_strip_details(self, details):
# if request is to /volumes, change it
# to /volumes/detail for aggregation
# and set strip_details to true
if (details.service == 'volume' and
details.method == 'GET' and
utils.safe_get(details.action, -1) == 'volumes'):
self.strip_details = True
details.action.insert(len(details.action), 'detail')
else:
self.strip_details = False
@app.route('/', defaults={'path': ''}, methods=METHODS_ACCEPTED)
@app.route('/<path:path>', methods=METHODS_ACCEPTED)
def proxy(path):
k2k_request = RequestHandler(request.method, path, request.headers)
return k2k_request.forward()
def main():
config.configure()
extend.load_extensions()
if __name__ == "__main__":
main()
app.run(port=5001, threaded=True)