# Copyright (c) 2012-2013, Eucalyptus Systems, Inc.
#
# Permission to use, copy, modify, and/or distribute this software for
# any purpose with or without fee is hereby granted, provided that the
# above copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import copy
import os.path
import random
import requests.exceptions
import time
import urlparse

from .exceptions import ClientError, ServiceInitError
from .util import aggregate_subclass_fields


class BaseService(object):
    """
    Base class for a client of one web service endpoint.

    Subclasses describe a concrete service through the class attributes
    below.  configure() resolves the endpoint URL (and, optionally, a user
    and region) from, in decreasing precedence: keyword arguments given to
    the constructor, the environment variable named by URL_ENVVAR, and the
    region option "<NAME>-url" from the config.  send_request() then issues
    HTTP requests against that endpoint through a lazily created
    ``requests`` session.
    """

    NAME = ''
    DESCRIPTION = ''
    API_VERSION = ''
    MAX_RETRIES = 4  ## TODO: check the config file

    # Class instantiated as AUTH_CLASS(service); the resulting object is used
    # as the session's auth and must provide collect_arg_objs,
    # preprocess_arg_objs and configure.
    AUTH_CLASS = None
    # Name of an environment variable that may hold the endpoint URL.
    URL_ENVVAR = None

    # Argument objects for this service; merged across the class hierarchy
    # by aggregate_subclass_fields in collect_arg_objs.
    ARGS = []

    def __init__(self, config, log, **kwargs):
        """
        :param config: configuration object.  NOTE(review): assumed to
            provide get_region(), get_region_option(name), a ``regions``
            mapping, and writable ``current_region``/``current_user``
            attributes -- that is everything this module uses.
        :param log: logger used for request/response tracing.
        :param kwargs: parsed arguments; ``url`` and ``userregion`` are
            consulted by configure().
        """
        self.args = kwargs
        self.config = config
        self.endpoint = None
        self.log = log
        self.session_args = {'verify': False}  # SSL verification is opt-in
        self._session = None
        if self.AUTH_CLASS is not None:
            self.auth = self.AUTH_CLASS(self)
        else:
            self.auth = None

    @property
    def region_name(self):
        """The region currently selected in the config."""
        return self.config.get_region()

    def collect_arg_objs(self):
        """Return this service's argument objects followed by its auth's."""
        service_args = aggregate_subclass_fields(self.__class__, 'ARGS')
        if self.auth is not None:
            auth_args = self.auth.collect_arg_objs()
        else:
            auth_args = []
        return service_args + auth_args

    def preprocess_arg_objs(self, arg_objs):
        """Give the auth handler (if any) a chance to adjust ``arg_objs``."""
        if self.auth is not None:
            self.auth.preprocess_arg_objs(arg_objs)

    def configure(self):
        """
        Resolve the endpoint/user/region, validate, and configure auth.

        Each source only fills in values that are still unset, so earlier
        sources win.

        :raises ServiceInitError: if no endpoint could be determined.
        """
        # self.args gets highest precedence for self.endpoint and user/region
        self.process_url(self.args.get('url'))
        if self.args.get('userregion'):
            self.process_userregion(self.args['userregion'])
        # Environment comes next.  Only consult it when the subclass names a
        # variable: os.getenv(None) raises TypeError on Python 3.
        if self.URL_ENVVAR:
            self.process_url(os.getenv(self.URL_ENVVAR))
        # Finally, try the config file
        if self.endpoint is None:
            self.process_url(
                self.config.get_region_option(self.NAME + '-url'))
        # Ensure everything is okay and finish up
        self.validate_config()
        if self.auth is not None:
            self.auth.configure()

    @property
    def session(self):
        """
        The ``requests`` session used for all requests, created on first use.

        Supports both the pre-1.0 ``requests.session(**kwargs)`` API and the
        1.x API, where options are set as attributes on the session.
        """
        if self._session is not None:
            return self._session
        if requests.__version__ >= '1.0':
            self._session = requests.session()
            self._session.auth = self.auth
            for key, val in self.session_args.iteritems():
                setattr(self._session, key, val)
        else:
            self._session = requests.session(auth=self.auth,
                                             **self.session_args)
        return self._session

    def validate_config(self):
        """
        Check that an endpoint was found.

        :raises ServiceInitError: if self.endpoint is None; the message
            lists the config's known regions when there are any.
        """
        if self.endpoint is None:
            regions = ', '.join(sorted(self.config.regions.keys()))
            errmsg = 'no endpoint to connect to was given'
            if regions:
                errmsg += '. Known regions are ' + regions
            raise ServiceInitError(errmsg)

    def process_url(self, url):
        """
        Record an endpoint given as ``URL`` or ``[USER@]REGION::URL``.

        The endpoint is only stored if none has been set yet; any
        user/region prefix is passed on to process_userregion().  Empty or
        None ``url`` values are ignored.
        """
        if not url:
            return
        if '::' in url:
            userregion, endpoint = url.split('::', 1)
        else:
            userregion = None
            endpoint = url
        if self.endpoint is None:
            # Store only the URL part, never the user/region prefix.
            self.endpoint = endpoint
        if userregion:
            self.process_userregion(userregion)

    def process_userregion(self, userregion):
        """
        Apply a ``[USER@]REGION`` string to the config.

        The region and user are only set on the config when it does not
        already have one, so earlier (higher-precedence) values are kept.
        """
        if '@' in userregion:
            user, region = userregion.split('@', 1)
        else:
            user = None
            region = userregion
        if region and self.config.current_region is None:
            self.config.current_region = region
        if user and self.config.current_user is None:
            self.config.current_user = user

    def send_request(self, method='GET', path=None, params=None,
                     headers=None, data=None):
        """
        Send an HTTP request to the configured endpoint.

        :param path: optional path appended to the endpoint with exactly one
            '/' separator inserted when the endpoint lacks a trailing one.
        :returns: whatever ``session.request`` returns.
        :raises ClientError: on any ``requests`` exception.
        """
        ## TODO: test url-encoding
        if path:
            # We can't simply use urljoin because a path might start with '/'
            # like it could for S3 keys that start with that character.
            if self.endpoint.endswith('/'):
                url = self.endpoint + path
            else:
                url = self.endpoint + '/' + path
        else:
            url = self.endpoint
        ## TODO: replace pre_send and post_request hooks for use with requests 1
        # NOTE(review): 'pre_send' and 'post_request' are pre-1.0 requests
        # hook names; with requests 1.x presumably only 'response' fires, so
        # retries and request logging would not happen -- confirm.
        hooks = {'pre_send': _log_request_data(self.log),
                 'response': _log_response_data(self.log),
                 'post_request': RetryOnStatuses((500, 503), self.MAX_RETRIES,
                                                 logger=self.log)}
        try:
            return self.session.request(method=method, url=url,
                                        params=params, data=data,
                                        headers=headers, hooks=hooks)
        except requests.exceptions.ConnectionError:
            raise ClientError('connection error')
        except requests.exceptions.RequestException as exc:
            raise ClientError(exc)


class RetryOnStatuses(object):
    """
    Request hook that re-sends a request whose response status is in
    ``statuses``, up to ``max_retries`` times, with randomized exponential
    backoff.

    Written against the pre-1.0 ``requests`` hook API: it is called with the
    request object and uses ``request.response`` and
    ``request.send(anyway=True)``.  One instance tracks the retry count for
    one request, so a fresh instance is needed per request.
    """

    def __init__(self, statuses, max_retries, logger=None):
        self.statuses = statuses
        self.max_retries = max_retries
        self.current_try = 0
        self.logger = logger

    def __call__(self, request):
        if (request.response.status_code in self.statuses and
                self.current_try < self.max_retries):
            # Exponential backoff: on try n, wait (1 + U[0, 1)) ** n seconds,
            # i.e. somewhere in [1, 2 ** n).
            self.current_try += 1
            delay = (1 + random.random()) ** self.current_try
            if self.logger:
                self.logger.info('Retrying after %.3f seconds', delay)
            # Sleep for the same delay that was logged.
            time.sleep(delay)
            orig_response = request.response
            request.send(anyway=True)
            # Keep the failed response(s) visible in the new response's
            # history, oldest first.
            request.response.history = (orig_response.history +
                                        [orig_response] +
                                        request.response.history)


def _log_request_data(logger):
    """
    Return a hook that debug-logs a request's method, URL, and its headers,
    params and data (each only when it is a plain dict, sorted by key).
    """
    def __log_request_data(request):
        logger.debug('request method: %s', request.method)
        logger.debug('request url:    %s', request.url)
        if isinstance(request.headers, dict):
            for key, val in sorted(request.headers.iteritems()):
                logger.debug('request header: %s: %s', key, val)
        if isinstance(request.params, dict):
            for key, val in sorted(request.params.iteritems()):
                logger.debug('request param:  %s: %s', key, val)
        if isinstance(request.data, dict):
            for key, val in sorted(request.data.iteritems()):
                logger.debug('request data:   %s: %s', key, val)
    return __log_request_data


def _log_response_data(logger):
    """
    Return a hook that logs a response's status code (error for >= 400,
    info for 3xx, debug otherwise) and debug-logs its headers when they are
    a plain dict.
    """
    def __log_response_data(response):
        if response.status_code >= 400:
            logger.error('response status: %i', response.status_code)
        elif response.status_code >= 300:
            logger.info('response status: %i', response.status_code)
        else:
            logger.debug('response status: %i', response.status_code)
        if isinstance(response.headers, dict):
            for key, val in sorted(response.headers.items()):
                logger.debug('response header: %s: %s', key, val)
    return __log_response_data