Files
deb-python-requestbuilder/requestbuilder/service.py
2013-02-02 23:56:23 -08:00

205 lines
7.6 KiB
Python

# Copyright (c) 2012-2013, Eucalyptus Systems, Inc.
#
# Permission to use, copy, modify, and/or distribute this software for
# any purpose with or without fee is hereby granted, provided that the
# above copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import copy
import os.path
import random
import requests.exceptions
import time
import urlparse
from .exceptions import ClientError, ServiceInitError
from .util import aggregate_subclass_fields
class BaseService(object):
    """
    Base class for a web service client that talks to an endpoint via
    ``requests``.

    Subclasses customize behavior through the class attributes below.  The
    endpoint URL (optionally prefixed with ``user@region::``) is resolved in
    ``configure`` from, in decreasing order of precedence: the constructor's
    ``url``/``userregion`` keyword arguments, the environment variable named
    by ``URL_ENVVAR``, and the config file's ``<NAME>-url`` region option.
    """

    NAME = ''
    DESCRIPTION = ''
    API_VERSION = ''
    MAX_RETRIES = 4  # TODO: check the config file
    AUTH_CLASS = None  # if set, instantiated with this service as self.auth
    URL_ENVVAR = None  # name of the environment variable holding the URL
    ARGS = []

    def __init__(self, config, log, **kwargs):
        """
        :param config: configuration object; this class reads
            ``get_region()``, ``get_region_option()``, ``regions``,
            ``current_region`` and ``current_user`` from it
        :param log: logger used for request/response/retry logging
        :param kwargs: service arguments; ``url`` and ``userregion`` are
            consulted by ``configure``
        """
        self.args = kwargs
        self.config = config
        self.endpoint = None
        self.log = log
        self.session_args = {'verify': False}  # SSL verification is opt-in
        self._session = None
        if self.AUTH_CLASS is not None:
            self.auth = self.AUTH_CLASS(self)
        else:
            self.auth = None

    @property
    def region_name(self):
        """Region name as reported by ``self.config.get_region()``."""
        return self.config.get_region()

    def collect_arg_objs(self):
        """
        Return this service's ARGS (aggregated by
        ``aggregate_subclass_fields`` over the class hierarchy) followed by
        the auth handler's arg objects, if there is an auth handler.
        """
        service_args = aggregate_subclass_fields(self.__class__, 'ARGS')
        if self.auth is not None:
            auth_args = self.auth.collect_arg_objs()
        else:
            auth_args = []
        return service_args + auth_args

    def preprocess_arg_objs(self, arg_objs):
        """Give the auth handler, if any, a chance to adjust *arg_objs*."""
        if self.auth is not None:
            self.auth.preprocess_arg_objs(arg_objs)

    def configure(self):
        """
        Resolve the endpoint and user/region, validate, then configure auth.

        :raises ServiceInitError: if no endpoint could be determined
        """
        # self.args gets highest precedence for self.endpoint and user/region
        self.process_url(self.args.get('url'))
        if self.args.get('userregion'):
            self.process_userregion(self.args['userregion'])
        # Environment comes next.  Only look it up when a variable name is
        # defined: os.getenv(None) raises TypeError on Python 3.
        if self.URL_ENVVAR:
            self.process_url(os.getenv(self.URL_ENVVAR))
        # Finally, try the config file
        if self.endpoint is None:
            self.process_url(self.config.get_region_option(self.NAME + '-url'))
        # Ensure everything is okay and finish up
        self.validate_config()
        if self.auth is not None:
            self.auth.configure()

    @property
    def session(self):
        """
        Lazily create and cache a ``requests`` session that uses ``self.auth``
        and the attributes in ``self.session_args``.
        """
        if self._session is not None:
            return self._session
        if requests.__version__ >= '1.0':
            # requests >= 1.0: session() takes no arguments; set attributes.
            self._session = requests.session()
            self._session.auth = self.auth
            for key, val in self.session_args.items():
                setattr(self._session, key, val)
        else:
            # requests 0.x: session() accepts these as keyword arguments.
            self._session = requests.session(auth=self.auth,
                                             **self.session_args)
        return self._session

    def validate_config(self):
        """
        :raises ServiceInitError: if no endpoint is set; the message lists
            the config's known region names when there are any
        """
        if self.endpoint is None:
            regions = ', '.join(sorted(self.config.regions.keys()))
            errmsg = 'no endpoint to connect to was given'
            if regions:
                errmsg += '. Known regions are ' + regions
            raise ServiceInitError(errmsg)

    def process_url(self, url):
        """
        Parse a URL of the form ``[userregion::]endpoint``.

        The endpoint is stored only if none has been set yet.  Any
        ``userregion`` prefix is passed to ``process_userregion``.  A falsy
        *url* is ignored.
        """
        if not url:
            return
        # NOTE(review): a literal IPv6 host such as http://[::1]/ contains
        # '::' and would be mis-split here; confirm whether that matters.
        if '::' in url:
            userregion, endpoint = url.split('::', 1)
        else:
            userregion, endpoint = None, url
        if self.endpoint is None:
            # Store only the endpoint part, without the user@region prefix.
            self.endpoint = endpoint
        if userregion:
            self.process_userregion(userregion)

    def process_userregion(self, userregion):
        """
        Parse ``[user@]region`` and fill in the config's current region and
        user, but only where they have not already been set.
        """
        if '@' in userregion:
            user, region = userregion.split('@', 1)
        else:
            user = None
            region = userregion
        if region and self.config.current_region is None:
            self.config.current_region = region
        if user and self.config.current_user is None:
            self.config.current_user = user

    def send_request(self, method='GET', path=None, params=None, headers=None,
                     data=None):
        """
        Send a request to the configured endpoint and return the response.

        *path*, if given, is appended to the endpoint with a single '/'
        inserted only when the endpoint does not already end in one.

        :raises ClientError: on any ``requests`` exception, with the message
            'connection error' for connection failures
        """
        # TODO: test url-encoding
        if path:
            # We can't simply use urljoin because a path might start with '/'
            # like it could for S3 keys that start with that character.
            if self.endpoint.endswith('/'):
                url = self.endpoint + path
            else:
                url = self.endpoint + '/' + path
        else:
            url = self.endpoint
        # TODO: replace pre_send and post_request hooks for use with requests 1
        hooks = {'pre_send': _log_request_data(self.log),
                 'response': _log_response_data(self.log),
                 'post_request': RetryOnStatuses((500, 503), self.MAX_RETRIES,
                                                 logger=self.log)}
        try:
            return self.session.request(method=method, url=url, params=params,
                                        data=data, headers=headers,
                                        hooks=hooks)
        # ConnectionError is a RequestException subclass, so it must be
        # handled first.
        except requests.exceptions.ConnectionError:
            raise ClientError('connection error')
        except requests.exceptions.RequestException as exc:
            raise ClientError(exc)
class RetryOnStatuses(object):
    """
    Request hook that re-sends a request whose response status code is in
    *statuses*, at most *max_retries* times per instance.  Between attempts
    it sleeps with randomized exponential backoff.

    NOTE(review): relies on the requests 0.x ``Request`` API
    (``request.response``, ``request.send(anyway=True)``); confirm against
    the requests version in use.
    """

    def __init__(self, statuses, max_retries, logger=None):
        """
        :param statuses: container of status codes that trigger a retry
        :param max_retries: maximum number of re-sends
        :param logger: optional logger; each retry is logged at INFO level
        """
        self.statuses = statuses
        self.max_retries = max_retries
        self.current_try = 0
        self.logger = logger

    def __call__(self, request):
        if (request.response.status_code in self.statuses and
                self.current_try < self.max_retries):
            # Exponential backoff: the base is drawn from [1, 2), raised to
            # the attempt number.
            self.current_try += 1
            delay = (1 + random.random()) ** self.current_try
            if self.logger:
                self.logger.info('Retrying after %.3f seconds', delay)
            # Sleep for exactly the delay that was logged.
            time.sleep(delay)
            orig_response = request.response
            request.send(anyway=True)
            # Chain the earlier response (and its own history) in front of
            # the new response's history.
            request.response.history = (orig_response.history +
                                        [orig_response] +
                                        request.response.history)
def _log_request_data(logger):
def __log_request_data(request):
logger.debug('request method: %s', request.method)
logger.debug('request url: %s', request.url)
if isinstance(request.headers, dict):
for key, val in sorted(request.headers.iteritems()):
logger.debug('request header: %s: %s', key, val)
if isinstance(request.params, dict):
for key, val in sorted(request.params.iteritems()):
logger.debug('request param: %s: %s', key, val)
if isinstance(request.data, dict):
for key, val in sorted(request.data.iteritems()):
logger.debug('request data: %s: %s', key, val)
return __log_request_data
def _log_response_data(logger):
def __log_response_data(response):
if response.status_code >= 400:
logger.error('response status: %i', response.status_code)
elif response.status_code >= 300:
logger.info('response status: %i', response.status_code)
else:
logger.debug('response status: %i', response.status_code)
if isinstance(response.headers, dict):
for key, val in sorted(response.headers.items()):
logger.debug('response header: %s: %s', key, val)
return __log_response_data