Gather args from services and auth handlers

This change allows services and auth handlers to specify their own ARGS
lists that requests can then gather and feed into the ArgumentParsers
they create.  Arg routes are now actual objects (or for declarative use,
attrgetters) as opposed to magic values.  Requests needing to call other
requests can re-use their own service objects to propagate the args that
the latter gathered from the command line.

Configuration files are parsed much earlier than they used to be and can
now cause BaseCommand.__init__() to fail.  BaseCommand.run() handles
this.

Config objects gained the concept of a "current" region and user, which
services can set as they parse their command line args.  These values
are used as defaults when one goes to look up options.

BaseCommand.DEFAULT_ROUTE is now an attribute/property, default_route.

BaesCommand.print_result is now a noop.  Define it yourself if you need
a tool to output something during that step.

BaseRequest.ACTION is now BaseRequest.NAME.

This change breaks a large number of internal APIs.  Docstrings are
horribly out of date at this point and should be fixed fairly soon in a
future commit.
This commit is contained in:
Garrett Holmstrom
2013-02-01 22:39:52 -08:00
parent 90669fbf87
commit a4f8ed4277
7 changed files with 402 additions and 333 deletions

View File

@@ -1,4 +1,4 @@
# Copyright (c) 2012, Eucalyptus Systems, Inc.
# Copyright (c) 2012-2013, Eucalyptus Systems, Inc.
#
# Permission to use, copy, modify, and/or distribute this software for
# any purpose with or without fee is hereby granted, provided that the
@@ -13,6 +13,7 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import argparse
import operator
__version__ = '0.0'
@@ -111,19 +112,13 @@ class GenericTagFilter(Filter):
########## SINGLETONS ##########
# Indicates a parameter that should be sent to the server without a value
# Indicates a parameter that should be sent to the server without a value.
# Contrast this with empty strings, with are omitted from requests entirely.
EMPTY = type('EMPTY', (), {'__repr__': lambda self: "''",
'__str__': lambda self: ''})()
# Constants (enums?) used for arg routing
AUTH = type('AUTH', (), {'__repr__': lambda self: 'AUTH'})()
PARAMS = type('PARAMS', (), {'__repr__': lambda self: 'PARAMS'})()
SERVICE = type('SERVICE', (), {'__repr__': lambda self: 'SERVICE'})()
SESSION = type('SESSION', (), {'__repr__': lambda self: 'SESSION'})()
# Common args for query authentication
STD_AUTH_ARGS = [
Arg('-I', '--access-key-id', dest='key_id', metavar='KEY_ID',
route_to=AUTH),
Arg('-S', '--secret-key', dest='secret_key', metavar='KEY',
route_to=AUTH)]
# Getters used for arg routing
AUTH = operator.attrgetter('service.auth.args')
PARAMS = operator.attrgetter('params')
SERVICE = operator.attrgetter('service.args')
SESSION = operator.attrgetter('service.session_args')

View File

@@ -12,37 +12,96 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import argparse
import base64
import hashlib
import hmac
import os
import requests.auth
from six import text_type
import time
import urllib
import urlparse
from . import Arg, AUTH
from .exceptions import AuthError
from .util import aggregate_subclass_fields
ISO8601 = '%Y-%m-%dT%H:%M:%SZ'
class QuerySignatureV2Auth(requests.auth.AuthBase):
class BaseAuth(requests.auth.AuthBase):
ARGS = []
def __init__(self, service, **kwargs):
self.args = kwargs
self.config = service.config
self.log = service.log.getChild(self.__class__.__name__)
self.service = service
def collect_arg_objs(self):
return aggregate_subclass_fields(self.__class__, 'ARGS')
def preprocess_arg_objs(self, arg_objs):
pass
def configure(self):
pass
def __call__(self, req):
pass
class HmacKeyAuth(BaseAuth):
ARGS = [Arg('-I', '--access-key-id', dest='key_id', metavar='KEY_ID',
default=argparse.SUPPRESS, route_to=AUTH),
Arg('-S', '--secret-key', dest='secret_key', metavar='KEY',
default=argparse.SUPPRESS, route_to=AUTH)]
def configure(self):
# See if an AWS credential file was given in the environment
self.configure_from_aws_credential_file()
# Try the requestbuilder config file next
self.configure_from_configfile()
if not self.args.get('key_id'):
raise AuthError('missing access key ID')
if not self.args.get('secret_key'):
raise AuthError('missing secret key')
def configure_from_aws_credential_file(self):
if 'AWS_CREDENTIAL_FILE' in os.environ:
path = os.getenv('AWS_CREDENTIAL_FILE')
path = os.path.expandvars(path)
path = os.path.expanduser(path)
with open(path) as credfile:
for line in credfile:
line = line.split('#', 1)[0]
if '=' in line:
(key, val) = line.split('=', 1)
if key.strip() == 'AWSAccessKeyId':
self.args.setdefault('key_id', val.strip())
elif key.strip() == 'AWSSecretKey':
self.args.setdefault('secret_key', val.strip())
def configure_from_configfile(self):
config_key_id = self.config.get_user_option('key-id')
if config_key_id:
self.args.setdefault('key_id', config_key_id)
config_secret_key = self.config.get_user_option('secret-key',
redact=True)
if config_secret_key:
self.args.setdefault('secret_key', config_secret_key)
class QuerySigV2Auth(HmacKeyAuth):
'''
AWS signature version 2
http://docs.amazonwebservices.com/general/latest/gr/signature-version-2.html
'''
def __init__(self, service, key_id, secret_key, params_to_data=True):
self.service = service
if not key_id:
raise AuthError('missing access key ID')
if not secret_key:
raise AuthError('missing secret key')
self.key_id = key_id
self.hmac = hmac.new(secret_key, digestmod=hashlib.sha256)
self.log = self.service.log.getChild(self.__class__.__name__)
# Whether to convert params to data if POSTing with only the former
self.params_to_data = params_to_data
def __call__(self, req):
# We assume that req.params is a dict
req.params['AWSAccessKeyId'] = self.key_id
req.params['AWSAccessKeyId'] = self.args['key_id']
req.params['SignatureVersion'] = 2
req.params['SignatureMethod'] = 'HmacSHA256'
req.params['Timestamp'] = time.strftime(ISO8601, time.gmtime())
@@ -69,14 +128,13 @@ class QuerySignatureV2Auth(requests.auth.AuthBase):
return req
def convert_params_to_data(self, req):
if (self.params_to_data and req.method.upper() == 'POST' and
isinstance(req.params, dict)):
if req.method.upper() == 'POST' and isinstance(req.params, dict):
# POST with params -> use params as form data instead
self.log.debug('converting params to POST data')
req.data = req.params
req.params = None
def sign_string(self, to_sign):
hmac = self.hmac.copy()
hmac.update(to_sign)
return base64.b64encode(hmac.digest())
req_hmac = hmac.new(self.args['secret_key'], digestmod=hashlib.sha256)
req_hmac.update(to_sign)
return base64.b64encode(req_hmac.digest())

View File

@@ -29,6 +29,7 @@ except ImportError:
from . import __version__, Arg, MutuallyExclusiveArgList
from .config import Config
from .logging import configure_root_logger
from .util import aggregate_subclass_fields
class BaseCommand(object):
'''
@@ -59,99 +60,121 @@ class BaseCommand(object):
those of their parent classes.
'''
VERSION = 'requestbuilder ' + __version__
DESCRIPTION = ''
ARGS = [Arg('-D', '--debug', action='store_true', route_to=None,
help='show debugging output'),
Arg('--debugger', action='store_true', route_to=None,
help='enable interactive debugger on error')]
DEFAULT_ROUTE = None
CONFIG_FILES = ['/etc/requestbuilder.ini']
VERSION = 'requestbuilder ' + __version__
def __init__(self, _do_cli=False, **kwargs):
# Note to programmer: when run() is initializing the first object it
# can't catch exceptions that may result from accesses to self.config.
# To deal with this, run() disables config file parsing during this
# process to expose premature access to self.config during testing.
self.args = {} # populated later
self.args = kwargs
self.config = None # created by _process_configfile
self.log = None # created by _configure_logging
self._allowed_args = None # created by _build_parser
self._arg_routes = {}
self._cli_parser = None # created by _build_parser
self._config = None
self._configure_logging()
self._process_configfiles()
if _do_cli:
self._configure_global_logging()
# We need to enforce arg constraints in one location to make this
# framework equally useful for chained commands and those driven
# directly from the command line. Thus, we do most of the parsing/
# validation work in __init__ as opposed to putting it off until
# we hit CLI-specific code.
# validation work before __init__ returns as opposed to putting it
# off until we hit CLI-specific code.
#
# Derived classes MUST call this method to ensure things stay sane.
self.__do_cli = _do_cli
self._post_init()
def _post_init(self):
self._build_parser()
if self.__do_cli:
# Distribute CLI args to the various places that need them
self.process_cli_args()
self.configure()
# Come up with a list of args that the arg parser will allow
for key, val in kwargs.iteritems():
if key in self._allowed_args:
self.args[key] = val
else:
raise TypeError('__init__() got an unexpected keyword '
'argument \'{0}\'; allowed arguments are {1}'
.format(key, ', '.join(self._allowed_args)))
@property
def default_route(self):
# This is a property so we can return something that references self.
return None
## TODO: AUTH PARAM PASSING (probably involves the service class)
if _do_cli:
self._process_cli_args()
else:
# TODO: enforce arg constraints when not pulling from the CLI
pass
@property
def config_files(self):
# This list may need to be computed on the fly.
return []
def _configure_logging(self):
# Does not have access to self.config
self.log = logging.getLogger(self.name)
if self.debug:
self.log.setLevel(logging.DEBUG)
def _process_configfiles(self):
self.config = Config(self.config_files, log=self.log)
# Now that we have a config file we should check to see if it wants
# us to turn on debugging
if self.__config_enables_debugging():
self.log.setLevel(logging.DEBUG)
def _configure_global_logging(self):
if self.config.get_global_option('debug') in ('color', 'colour'):
configure_root_logger(use_color=True)
else:
configure_root_logger()
if self.args.get('debugger'):
sys.excepthook = _debugger_except_hook(
self.args.get('debugger', False),
self.args.get('debug', False))
def _build_parser(self):
# Does not have access to self.config
description = '\n\n'.join([textwrap.fill(textwrap.dedent(para))
for para in self.DESCRIPTION.split('\n\n')])
parser = argparse.ArgumentParser(description=description,
formatter_class=argparse.RawDescriptionHelpFormatter)
arg_objs = self.collect_arg_objs()
## FIXME: _allowed_args is full of argparse args, but __init__ thinks it's full of strings
self._allowed_args = self.populate_parser(parser, arg_objs)
self.preprocess_arg_objs(arg_objs)
self.populate_parser(parser, arg_objs)
parser.add_argument('--version', action='version',
version=self.VERSION) # doesn't need routing
self._cli_parser = parser
@classmethod
def collect_arg_objs(cls):
## TODO: leave notes on how to override this
return aggregate_subclass_fields(cls, 'ARGS')
def collect_arg_objs(self):
return aggregate_subclass_fields(self.__class__, 'ARGS')
def preprocess_arg_objs(self, arg_objs):
pass
def populate_parser(self, parser, arg_objs):
# Returns the args the parser was populated with <-- FIXME (the docs)
# Does not have access to self.config
args = []
for arg_obj in arg_objs:
args.extend(self.__add_arg_to_cli_parser(arg_obj, parser))
return args
self.__add_arg_to_cli_parser(arg_obj, parser)
def _process_cli_args(self):
def process_cli_args(self):
'''
Process CLI args to fill in missing parts of self.args and enable
debugging if necessary.
'''
# Does not have access to self.config
cli_args = self._cli_parser.parse_args()
for (key, val) in vars(cli_args).iteritems():
self.args.setdefault(key, val)
for key, val in vars(cli_args).iteritems():
# Everything goes in self.args
self.args[key] = val
# If a location to route this to was supplied, put it there, too.
route = self._arg_routes[key]
if route is not None:
if callable(route):
# If it's callable, call it to get the actual destination
# dict. This is needed to allow Arg objects to refer to
# instance attributes from the context of the class.
route = route(self)
# At this point we had better have a dict.
route[key] = val
def __add_arg_to_cli_parser(self, arglike_obj, parser):
# Returns the args the parser was populated with
# Does not have access to self.config
if isinstance(arglike_obj, Arg):
if arglike_obj.kwargs.get('dest') is argparse.SUPPRESS:
# Treat it like it doesn't exist at all
@@ -159,9 +182,8 @@ class BaseCommand(object):
else:
arg = parser.add_argument(*arglike_obj.pargs,
**arglike_obj.kwargs)
route = getattr(arglike_obj, 'route', self.DEFAULT_ROUTE)
self._arg_routes.setdefault(route, [])
self._arg_routes[route].append(arg.dest)
route = getattr(arglike_obj, 'route', self.default_route)
self._arg_routes[arg.dest] = route
return [arg]
elif isinstance(arglike_obj, MutuallyExclusiveArgList):
exgroup = parser.add_mutually_exclusive_group(
@@ -174,42 +196,29 @@ class BaseCommand(object):
raise TypeError('Unknown argument type ' +
arglike_obj.__class__.__name__)
def configure(self):
# TODO: Come up with something that can enforce arg constraints based
# on the info we can get from self._cli_parser
pass
@classmethod
def run(cls):
BaseCommand.__INHIBIT_CONFIG_PARSING = True
## TODO: document: command line entry point
cmd = cls(_do_cli=True)
BaseCommand.__INHIBIT_CONFIG_PARSING = False
try:
cmd.configure_global_logging()
cmd = cls(_do_cli=True)
except Exception as err:
print >> sys.stderr, 'error: {0}'.format(err)
# Since we don't even have a config file to consult our options for
# determining when debugging is on are limited to what we got at
# the command line.
if any(arg in sys.argv for arg in ('--debug', '-D', '--debugger')):
raise
sys.exit(1)
try:
result = cmd.main()
cmd.print_result(result)
except Exception as err:
cmd.handle_cli_exception(err)
def configure_global_logging(self):
if self.config.get_global_option('debug') in ('color', 'colour'):
configure_root_logger(use_color=True)
else:
configure_root_logger()
if self.args.get('debugger'):
sys.excepthook = _debugger_except_hook(
self.args.get('debugger', False),
self.args.get('debug', False))
@property
def config(self):
if not self._config:
if getattr(BaseCommand, '__INHIBIT_CONFIG_PARSING', False):
raise AssertionError(
'config files may not be parsed during __init__')
self._config = Config(self.CONFIG_FILES, log=self.log)
# Now that we have a config file we should check to see if it wants
# us to turn on debugging
if self.__config_enables_debugging():
self.log.setLevel(logging.DEBUG)
return self._config
@property
def name(self):
return self.__class__.__name__
@@ -226,7 +235,7 @@ class BaseCommand(object):
@property
def debug(self):
if self._config and self.__config_enables_debugging():
if self.__config_enables_debugging():
return True
if self.args.get('debug') or self.args.get('debugger'):
return True
@@ -242,21 +251,12 @@ class BaseCommand(object):
sys.exit(1)
def __config_enables_debugging(self):
if self._config.get_global_option('debug') in ('color', 'colour'):
if self.config is None:
return False
if self.config.get_global_option('debug') in ('color', 'colour'):
# It isn't boolean, but still counts as true.
return True
return self._config.get_global_option_bool('debug', False)
def aggregate_subclass_fields(cls, field_name):
values = []
# pylint doesn't know about classes' built-in mro() method
# pylint: disable-msg=E1101
for m_class in cls.mro():
# pylint: enable-msg=E1101
if field_name in vars(m_class):
values.extend(getattr(m_class, field_name))
return values
return self.config.get_global_option_bool('debug', False)
def _debugger_except_hook(debugger_enabled, debug_enabled):

View File

@@ -1,4 +1,4 @@
# Copyright (c) 2012, Eucalyptus Systems, Inc.
# Copyright (c) 2012-2013, Eucalyptus Systems, Inc.
#
# Permission to use, copy, modify, and/or distribute this software for
# any purpose with or without fee is hereby granted, provided that the
@@ -21,9 +21,11 @@ class Config(object):
self.log = log.getChild('config')
else:
self.log = _FakeLogger()
self.globals = {}
self.regions = {}
self.users = {}
self.globals = {}
self.__current_region = None
self.__current_user = None
self._memo = {}
self._parse_config(filenames)
@@ -58,6 +60,45 @@ class Config(object):
self.users[user] = dict(parser.items(section))
# Ignore unrecognized sections for forward compatibility
@property
def current_region(self):
# This is a property so we can log when it is set.
return self.__current_region
@current_region.setter
def current_region(self, val):
self.log.debug('current region set to %s', repr(val))
self.__current_region = val
def get_region(self):
if self.current_region is not None:
return self.current_region
if 'default-region' in self.globals:
return self.globals['default-region']
raise KeyError('no region was chosen')
@property
def current_user(self):
# This is a property so we can log when it is set.
return self.__current_user
@current_user.setter
def current_user(self, val):
self.log.debug('current user set to %s', repr(val))
self.__current_user = val
def get_user(self):
if self.current_user is not None:
return self.current_user
if self.get_region() is not None:
# Try to pull it from the current region
region_user = self.get_region_option('user')
if region_user is not None:
return region_user
if 'default-user' in self.globals:
return self.globals['default-user']
raise KeyError('no user was chosen')
def get_global_option(self, option):
return self.globals.get(option)
@@ -65,51 +106,29 @@ class Config(object):
value = self.get_global_option(option)
return convert_to_bool(value, default=default)
def get_user_option(self, regionspec, option):
user = None
region = None
if regionspec:
if '@' in regionspec:
user, region = regionspec.split('@', 1)
else:
region = regionspec
if not region:
region = self.globals.get('default-region')
if not user and region:
user = self._lookup_recursively(self.regions, region,
'default-user')
if not user and self.globals.get('default-user'):
user = self.globals['default-user']
if not user:
self.log.debug('no user to find')
return None
return self._lookup_recursively(self.users, user, option,
redact=['secret-key'])
def get_user_option(self, option, user=None, redact=False):
if user is None:
user = self.get_user()
return self._lookup_recursively('users', self.users, user, option,
redact=redact)
def get_user_option_bool(self, regionspec, option, default=None):
value = self.get_user_option(regionspec, option)
def get_user_option_bool(self, option, user=None, default=None):
value = self.get_user_option(option, user=user)
return convert_to_bool(value, default=default)
def get_region_option(self, regionspec, option):
if regionspec:
if '@' in regionspec:
region = regionspec.split('@', 1)[1]
else:
region = regionspec
elif self.globals.get('default-region'):
region = self.globals['default-region']
else:
self.log.debug('no region to find')
return None
return self._lookup_recursively(self.regions, region, option)
def get_region_option(self, option, region=None, redact=False):
if region is None:
region = self.get_region()
return self._lookup_recursively('regions', self.regions, region,
option, redact=redact)
def get_region_option_bool(self, regionspec, option, default=None):
value = self.get_region_option(regionspec, option)
def get_region_option_bool(self, option, region=None, default=None):
value = self.get_region_option(option, region=region)
return convert_to_bool(value, default=default)
def _lookup_recursively(self, confdict, section, option, redact=None,
cont_reason=None):
## TODO: detect loops
def _lookup_recursively(self, confdict_name, confdict, section, option,
redact=None, cont_reason=None):
# TODO: detect loops
self._memo.setdefault(id(confdict), {})
if (section, option) in self._memo[id(confdict)]:
return self._memo[id(confdict)][(section, option)]
@@ -119,7 +138,8 @@ class Config(object):
section_bits = section.split(':')
if not cont_reason:
self.log.debug('searching for option %s', repr(option))
self.log.debug('searching %s for option %s', confdict_name,
repr(option))
for prd in itertools.product((True, False), repeat=len(section_bits)):
prd_section = ':'.join(section_bits[i] if prd[i] else '*'
for i in range(len(section_bits)))
@@ -143,11 +163,11 @@ class Config(object):
new_option = value_chunks[2]
else:
new_option = option
return memoize(self._lookup_recursively(
return memoize(self._lookup_recursively(confdict_name,
confdict, new_section, new_option,
cont_reason='deferred'))
# We're done!
if redact and option in redact:
if redact:
print_value = '<redacted>'
else:
print_value = repr(value)
@@ -166,7 +186,7 @@ class Config(object):
if count > len(section_bits):
matches = c_counts[count]
if len(matches) == 1:
return memoize(self._lookup_recursively(
return memoize(self._lookup_recursively(confdict_name,
confdict, matches[0], option,
cont_reason=('from ' + repr(section))))
elif len(matches) > 1:

View File

@@ -20,10 +20,11 @@ import platform
import sys
import textwrap
from . import __version__, EMPTY, AUTH, PARAMS, SERVICE, SESSION
from . import __version__, EMPTY
from .command import BaseCommand
from .exceptions import ClientError, ServerError
from .service import BaseService
from .util import aggregate_subclass_fields
from .xmlparse import parse_listdelimited_aws_xml
class BaseRequest(BaseCommand):
@@ -52,7 +53,7 @@ class BaseRequest(BaseCommand):
- API_VERSION: the API version to send along with the request. This is
only necessary to override the service class's API
version for a specific request.
- ACTION: a string containing the Action query parameter. This
- NAME: a string containing the Action query parameter. This
defaults to the class's name.
- DESCRIPTION: a string describing the tool. This becomes part of the
command line help string.
@@ -69,59 +70,65 @@ class BaseRequest(BaseCommand):
SERVICE_CLASS = BaseService
API_VERSION = None
ACTION = None
NAME = None
FILTERS = []
DEFAULT_ROUTE = PARAMS
LIST_MARKERS = []
def __init__(self, **kwargs):
BaseCommand.__init__(self, **kwargs)
def __init__(self, service=None, **kwargs):
self.service = service
# Parts of the HTTP request to be sent to the server.
# Note that self.serialize_params will update self.params for each
# entry in self.args that routes to PARAMS.
self.headers = None
self.params = None
self.headers = {}
self.params = {}
self.post_data = None
self.method = 'GET'
# HTTP response obtained from the server
self.response = None
self._service = None
self.__user_agent = None
@classmethod
def collect_arg_objs(cls):
request_args = super(BaseRequest, cls).collect_arg_objs()
service_args = cls.SERVICE_CLASS.collect_arg_objs()
# Note that the service is likely to include auth args of its own.
BaseCommand.__init__(self, **kwargs)
def _post_init(self):
if self.service is None:
self.service = self.SERVICE_CLASS(self.config, self.log)
BaseCommand._post_init(self)
@property
def default_route(self):
return self.params
def collect_arg_objs(self):
request_args = BaseCommand.collect_arg_objs(self)
service_args = self.service.collect_arg_objs()
# Note that the service is likely to include auth args as well.
return request_args + service_args
def populate_parser(self, parser, arg_objs):
# Does not have access to self.config
args = BaseCommand.populate_parser(self, parser, arg_objs)
if self.FILTERS:
args.append(parser.add_argument('--filter', metavar='NAME=VALUE',
action='append', dest='filters',
help='restrict results to those that meet criteria',
type=partial(_parse_filter, filter_objs=self.FILTERS)))
parser.epilog = self.__build_filter_help()
self._arg_routes.setdefault(None, [])
self._arg_routes[None].append('filters')
## TODO: service args
return args
def preprocess_arg_objs(self, arg_objs):
self.service.preprocess_arg_objs(arg_objs)
def _process_cli_args(self):
# Does not have access to self.config
BaseCommand._process_cli_args(self)
def populate_parser(self, parser, arg_objs):
BaseCommand.populate_parser(self, parser, arg_objs)
if self.FILTERS:
parser.add_argument('--filter', metavar='NAME=VALUE',
action='append', dest='filters',
help='restrict results to those that meet criteria',
type=partial(_parse_filter, filter_objs=self.FILTERS))
parser.epilog = self.__build_filter_help()
self._arg_routes['filters'] = None
def process_cli_args(self):
BaseCommand.process_cli_args(self)
if 'filters' in self.args:
self.args['Filter'] = _process_filters(self.args.pop('filters'))
self._arg_routes.setdefault(self.DEFAULT_ROUTE, [])
self._arg_routes[self.DEFAULT_ROUTE].append('Filter')
self._arg_routes['Filter'] = self.params
def configure(self):
self.service.configure()
@property
def name(self):
@@ -129,7 +136,7 @@ class BaseRequest(BaseCommand):
The name of this action. Used when choosing what to supply for the
Action query parameter.
'''
return self.ACTION or self.__class__.__name__
return self.NAME or self.__class__.__name__
@property
def user_agent(self):
@@ -144,22 +151,6 @@ class BaseRequest(BaseCommand):
pyver=platform.python_version())
return self.__user_agent
@property
def service(self):
if self._service is None:
service_args = {'auth_args': {},
'session_args': {}}
for (key, val) in self.args.iteritems():
if key in self._arg_routes.get(SERVICE, []):
service_args[key] = val
elif key in self._arg_routes.get(AUTH, []):
service_args['auth_args'][key] = val
elif key in self._arg_routes.get(SESSION, []):
service_args['session_args'][key] = val
self._service = self.SERVICE_CLASS(self.config, self.log,
**service_args)
return self._service
@property
def status(self):
if self.response is not None:
@@ -167,7 +158,7 @@ class BaseRequest(BaseCommand):
else:
return None
def serialize_params(self, args, route, prefix=None):
def serialize_params(self, args, prefix=None):
'''
Given a possibly-nested dict of args and an arg routing destination,
transform each element in the dict that matches the corresponding
@@ -207,14 +198,13 @@ class BaseRequest(BaseCommand):
elif isinstance(args, dict):
for (key, val) in args.iteritems():
# Prefix.Key1, Prefix.Key2, ...
if key in self._arg_routes.get(route, []) or route is _ALWAYS:
if prefix:
prefixed_key = prefix + '.' + str(key)
else:
prefixed_key = str(key)
if isinstance(val, dict) or isinstance(val, list):
flattened.update(self.serialize_params(val, _ALWAYS,
flattened.update(self.serialize_params(val,
prefixed_key))
elif isinstance(val, file):
flattened[prefixed_key] = val.read()
@@ -231,8 +221,7 @@ class BaseRequest(BaseCommand):
prefixed_key = str(i_item)
if isinstance(item, dict) or isinstance(item, list):
flattened.update(self.serialize_params(item, _ALWAYS,
prefixed_key))
flattened.update(self.serialize_params(item, prefixed_key))
elif isinstance(item, file):
flattened[prefixed_key] = item.read()
elif item or item == 0:
@@ -259,8 +248,7 @@ class BaseRequest(BaseCommand):
4. If the response's status code does not indicate success, log an
error and raise a ServerError.
'''
params = self.serialize_params(self.args, PARAMS)
params.update(self.serialize_params(self.params, _ALWAYS))
params = self.serialize_params(self.params)
headers = dict(self.headers or {})
headers.setdefault('User-Agent', self.user_agent)
self.log.info('parameters: %s', params)
@@ -470,6 +458,3 @@ class _ReadLoggingFileWrapper(object):
chunk = self.fileobj.read(size)
self.logger.log(self.level, chunk, extra={'append': True})
return chunk
_ALWAYS = type('_ALWAYS', (), {'__repr__': lambda self: '_ALWAYS'})()

View File

@@ -19,116 +19,113 @@ import requests.exceptions
import time
import urlparse
from .auth import QuerySignatureV2Auth
from .auth import QuerySigV2Auth
from .exceptions import ClientError, ServiceInitError
from .util import aggregate_subclass_fields
class BaseService(object):
NAME = ''
DESCRIPTION = ''
API_VERSION = ''
MAX_RETRIES = 4
MAX_RETRIES = 4 ## TODO: check the config file
AUTH_CLASS = QuerySignatureV2Auth
ENV_URL = 'AWS_URL' # endpoint URL
AUTH_CLASS = None
ENV_URL = None
def __init__(self, config, log, url=None, regionspec=None, auth_args=None,
session_args=None):
self.log = log
# The region name currently only matters for sigv4.
## FIXME: It also won't work with every config source yet.
## TODO: DOCUMENT: if url contains :: it will be split into
## regionspec::endpoint
## FIXME: Is the above info true any more?
self.config = config
self.endpoint_url = None
self.regionspec = regionspec ## TODO: rename this
self._auth_args = auth_args or {}
self._session_args = session_args or {}
ARGS = []
# SSL verification is opt-in
self._session_args.setdefault('verify', False)
def __init__(self, config, log, **kwargs):
self.args = kwargs
self.config = config
self.endpoint = None
self.log = log
self.session_args = {'verify': False} # SSL verification is opt-in
self._session = None
# Set self.endpoint_url and self.regionspec from __init__ args
self._set_url_vars(url)
if self.AUTH_CLASS is not None:
self.auth = self.AUTH_CLASS(self)
else:
self.auth = None
# Grab info from the command line or service-specific config
self.read_config()
@property
def region_name(self):
return self.config.get_region()
if not self.endpoint_url:
def collect_arg_objs(self):
service_args = aggregate_subclass_fields(self.__class__, 'ARGS')
if self.auth is not None:
auth_args = self.auth.collect_arg_objs()
else:
auth_args = []
return service_args + auth_args
def preprocess_arg_objs(self, arg_objs):
if self.auth is not None:
self.auth.preprocess_arg_objs(arg_objs)
def configure(self):
# self.args gets highest precedence for self.endpoint and user/region
self.process_url(self.args.get('url'))
if self.args.get('userregion'):
self.process_userregion(self.args['userregion'])
# Environment comes next
self.process_url(os.getenv(self.ENV_URL))
# Finally, try the config file
self.process_url(self.config.get_region_option(self.NAME + '-url'))
# Ensure everything is okay and finish up
self.validate_config()
if self.auth is not None:
self.auth.configure()
@property
def session(self):
if self._session is not None:
return self._session
if requests.__version__ >= '1.0':
self._session = requests.session()
self._session.auth = self.auth
for key, val in self.session_args.iteritems():
setattr(self._session, key, val)
else:
self._session = requests.session(auth=self.auth,
**self.session_args)
return self._session
def validate_config(self):
if self.endpoint is None:
regions = ', '.join(sorted(self.config.regions.keys()))
errmsg = 'no endpoint to connect to was given'
if regions:
errmsg += '. Known regions are '
errmsg += ', '.join(sorted(self.config.regions.keys()))
errmsg += '. Known regions are ' + regions
raise ServiceInitError(errmsg)
auth = self.AUTH_CLASS(self, **self._auth_args)
self.session = requests.session(auth=auth, **self._session_args)
def process_url(self, url):
if url:
if '::' in url:
userregion, endpoint = url.split('::', 1)
else:
endpoint = url
userregion = None
if self.endpoint is None:
self.endpoint = url
if userregion:
self.process_userregion(userregion)
@classmethod
def collect_arg_objs(cls):
## TODO: implement this
return []
def read_config(self):
'''
Read configuration from the environment, files, and so on and use them
to populate self.endpoint_url, self.regionspec, and self._auth_args.
This method's configuration sources are, in order:
- An environment variable with the same name as self.ENV_URL
- An AWS credential file, from the path given in the
AWS_CREDENTIAL_FILE environment variable
- Requestbuilder configuration files, from paths given in
self.CONFIG_FILES
Of these, earlier sources take precedence over later sources.
Subclasses may override this method to add or rearrange configuration
sources.
'''
# Try the environment first
if self.ENV_URL in os.environ:
self._set_url_vars(os.getenv(self.ENV_URL, None))
# Read config files from their default locations
self.read_aws_credential_file()
self.read_requestbuilder_config()
def read_requestbuilder_config(self):
self._set_url_vars(self.config.get_region_option(self.regionspec,
self.NAME + '-url'))
secret_key = self.config.get_user_option(self.regionspec, 'secret-key')
if secret_key and not self._auth_args.get('secret_key'):
self._auth_args['secret_key'] = secret_key
key_id = self.config.get_user_option(self.regionspec, 'key-id')
if key_id and not self._auth_args.get('key_id'):
self._auth_args['key_id'] = key_id
if self.config.get_region_option_bool(self.regionspec, 'verify-ssl'):
self._session_args['verify'] = True
def read_aws_credential_file(self):
'''
If the 'AWS_CREDENTIAL_FILE' environment variable exists, parse that
file for access keys and use them if keys were not already supplied to
__init__.
'''
if 'AWS_CREDENTIAL_FILE' in os.environ:
path = os.getenv('AWS_CREDENTIAL_FILE')
path = os.path.expandvars(path)
path = os.path.expanduser(path)
with open(path) as credfile:
for line in credfile:
line = line.split('#', 1)[0]
if '=' in line:
(key, val) = line.split('=', 1)
if (key.strip() == 'AWSAccessKeyId' and
not self._auth_args.get('key_id')):
self._auth_args['key_id'] = val.strip()
elif (key.strip() == 'AWSSecretKey' and
not self._auth_args.get('secret_key')):
self._auth_args['secret_key'] = val.strip()
def process_userregion(self, userregion):
if '@' in userregion:
user, region = userregion.split('@', 1)
else:
user = None
region = userregion
if region and self.config.current_region is None:
self.config.current_region = region
if user and self.config.current_user is None:
self.config.current_user = user
## TODO: nuke Action; the request should make it a param instead
## TODO: the same should probably happen with API versions, but the
## request would have to deal with service.API_VERSION, too
def make_request(self, action, method='GET', path=None, params=None,
headers=None, data=None, api_version=None):
params = params or {}
@@ -143,12 +140,12 @@ class BaseService(object):
if path:
# We can't simply use urljoin because a path might start with '/'
# like it could for S3 keys that start with that character.
if self.endpoint_url.endswith('/'):
url = self.endpoint_url + path
if self.endpoint.endswith('/'):
url = self.endpoint + path
else:
url = self.endpoint_url + '/' + path
url = self.endpoint + '/' + path
else:
url = self.endpoint_url
url = self.endpoint
## TODO: replace pre_send and post_request hooks for use with requests 1
hooks = {'pre_send': _log_request_data(self.log),
@@ -165,15 +162,6 @@ class BaseService(object):
except requests.exceptions.RequestException as exc:
raise ClientError(exc)
def _set_url_vars(self, url):
if url:
if '::' in url:
regionspec, endpoint_url = url.split('::', 1)
else:
regionspec = None
endpoint_url = url
self.regionspec = regionspec or self.regionspec
self.endpoint_url = endpoint_url or self.endpoint_url
class RetryOnStatuses(object):
def __init__(self, statuses, max_retries, logger=None):

23
requestbuilder/util.py Normal file
View File

@@ -0,0 +1,23 @@
# Copyright (c) 2013, Eucalyptus Systems, Inc.
#
# Permission to use, copy, modify, and/or distribute this software for
# any purpose with or without fee is hereby granted, provided that the
# above copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
def aggregate_subclass_fields(cls, field_name):
values = []
# pylint doesn't know about classes' built-in mro() method
# pylint: disable-msg=E1101
for m_class in cls.mro():
# pylint: enable-msg=E1101
if field_name in vars(m_class):
values.extend(getattr(m_class, field_name))
return values