Gather args from services and auth handlers
This change allows services and auth handlers to specify their own ARGS lists that requests can then gather and feed into the ArgumentParsers they create. Arg routes are now actual objects (or for declarative use, attrgetters) as opposed to magic values. Requests needing to call other requests can re-use their own service objects to propagate the args that the latter gathered from the command line. Configuration files are parsed much earlier than they used to be and can now cause BaseCommand.__init__() to fail. BaseCommand.run() handles this. Config objects gained the concept of a "current" region and user, which services can set as they parse their command line args. These values are used as defaults when one goes to look up options. BaseCommand.DEFAULT_ROUTE is now an attribute/property, default_route. BaesCommand.print_result is now a noop. Define it yourself if you need a tool to output something during that step. BaseRequest.ACTION is now BaseRequest.NAME. This change breaks a large number of internal APIs. Docstrings are horribly out of date at this point and should be fixed fairly soon in a future commit.
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
# Copyright (c) 2012, Eucalyptus Systems, Inc.
|
||||
# Copyright (c) 2012-2013, Eucalyptus Systems, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and/or distribute this software for
|
||||
# any purpose with or without fee is hereby granted, provided that the
|
||||
@@ -13,6 +13,7 @@
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import argparse
|
||||
import operator
|
||||
|
||||
__version__ = '0.0'
|
||||
|
||||
@@ -111,19 +112,13 @@ class GenericTagFilter(Filter):
|
||||
|
||||
|
||||
########## SINGLETONS ##########
|
||||
# Indicates a parameter that should be sent to the server without a value
|
||||
# Indicates a parameter that should be sent to the server without a value.
|
||||
# Contrast this with empty strings, with are omitted from requests entirely.
|
||||
EMPTY = type('EMPTY', (), {'__repr__': lambda self: "''",
|
||||
'__str__': lambda self: ''})()
|
||||
|
||||
# Constants (enums?) used for arg routing
|
||||
AUTH = type('AUTH', (), {'__repr__': lambda self: 'AUTH'})()
|
||||
PARAMS = type('PARAMS', (), {'__repr__': lambda self: 'PARAMS'})()
|
||||
SERVICE = type('SERVICE', (), {'__repr__': lambda self: 'SERVICE'})()
|
||||
SESSION = type('SESSION', (), {'__repr__': lambda self: 'SESSION'})()
|
||||
|
||||
# Common args for query authentication
|
||||
STD_AUTH_ARGS = [
|
||||
Arg('-I', '--access-key-id', dest='key_id', metavar='KEY_ID',
|
||||
route_to=AUTH),
|
||||
Arg('-S', '--secret-key', dest='secret_key', metavar='KEY',
|
||||
route_to=AUTH)]
|
||||
# Getters used for arg routing
|
||||
AUTH = operator.attrgetter('service.auth.args')
|
||||
PARAMS = operator.attrgetter('params')
|
||||
SERVICE = operator.attrgetter('service.args')
|
||||
SESSION = operator.attrgetter('service.session_args')
|
||||
|
||||
@@ -12,37 +12,96 @@
|
||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import requests.auth
|
||||
from six import text_type
|
||||
import time
|
||||
import urllib
|
||||
import urlparse
|
||||
from . import Arg, AUTH
|
||||
from .exceptions import AuthError
|
||||
from .util import aggregate_subclass_fields
|
||||
|
||||
ISO8601 = '%Y-%m-%dT%H:%M:%SZ'
|
||||
|
||||
class QuerySignatureV2Auth(requests.auth.AuthBase):
|
||||
|
||||
class BaseAuth(requests.auth.AuthBase):
|
||||
ARGS = []
|
||||
|
||||
def __init__(self, service, **kwargs):
|
||||
self.args = kwargs
|
||||
self.config = service.config
|
||||
self.log = service.log.getChild(self.__class__.__name__)
|
||||
self.service = service
|
||||
|
||||
def collect_arg_objs(self):
|
||||
return aggregate_subclass_fields(self.__class__, 'ARGS')
|
||||
|
||||
def preprocess_arg_objs(self, arg_objs):
|
||||
pass
|
||||
|
||||
def configure(self):
|
||||
pass
|
||||
|
||||
def __call__(self, req):
|
||||
pass
|
||||
|
||||
|
||||
class HmacKeyAuth(BaseAuth):
|
||||
ARGS = [Arg('-I', '--access-key-id', dest='key_id', metavar='KEY_ID',
|
||||
default=argparse.SUPPRESS, route_to=AUTH),
|
||||
Arg('-S', '--secret-key', dest='secret_key', metavar='KEY',
|
||||
default=argparse.SUPPRESS, route_to=AUTH)]
|
||||
|
||||
def configure(self):
|
||||
# See if an AWS credential file was given in the environment
|
||||
self.configure_from_aws_credential_file()
|
||||
# Try the requestbuilder config file next
|
||||
self.configure_from_configfile()
|
||||
|
||||
if not self.args.get('key_id'):
|
||||
raise AuthError('missing access key ID')
|
||||
if not self.args.get('secret_key'):
|
||||
raise AuthError('missing secret key')
|
||||
|
||||
def configure_from_aws_credential_file(self):
|
||||
if 'AWS_CREDENTIAL_FILE' in os.environ:
|
||||
path = os.getenv('AWS_CREDENTIAL_FILE')
|
||||
path = os.path.expandvars(path)
|
||||
path = os.path.expanduser(path)
|
||||
with open(path) as credfile:
|
||||
for line in credfile:
|
||||
line = line.split('#', 1)[0]
|
||||
if '=' in line:
|
||||
(key, val) = line.split('=', 1)
|
||||
if key.strip() == 'AWSAccessKeyId':
|
||||
self.args.setdefault('key_id', val.strip())
|
||||
elif key.strip() == 'AWSSecretKey':
|
||||
self.args.setdefault('secret_key', val.strip())
|
||||
|
||||
def configure_from_configfile(self):
|
||||
config_key_id = self.config.get_user_option('key-id')
|
||||
if config_key_id:
|
||||
self.args.setdefault('key_id', config_key_id)
|
||||
config_secret_key = self.config.get_user_option('secret-key',
|
||||
redact=True)
|
||||
if config_secret_key:
|
||||
self.args.setdefault('secret_key', config_secret_key)
|
||||
|
||||
|
||||
class QuerySigV2Auth(HmacKeyAuth):
|
||||
'''
|
||||
AWS signature version 2
|
||||
http://docs.amazonwebservices.com/general/latest/gr/signature-version-2.html
|
||||
'''
|
||||
def __init__(self, service, key_id, secret_key, params_to_data=True):
|
||||
self.service = service
|
||||
if not key_id:
|
||||
raise AuthError('missing access key ID')
|
||||
if not secret_key:
|
||||
raise AuthError('missing secret key')
|
||||
self.key_id = key_id
|
||||
self.hmac = hmac.new(secret_key, digestmod=hashlib.sha256)
|
||||
self.log = self.service.log.getChild(self.__class__.__name__)
|
||||
# Whether to convert params to data if POSTing with only the former
|
||||
self.params_to_data = params_to_data
|
||||
|
||||
def __call__(self, req):
|
||||
# We assume that req.params is a dict
|
||||
req.params['AWSAccessKeyId'] = self.key_id
|
||||
req.params['AWSAccessKeyId'] = self.args['key_id']
|
||||
req.params['SignatureVersion'] = 2
|
||||
req.params['SignatureMethod'] = 'HmacSHA256'
|
||||
req.params['Timestamp'] = time.strftime(ISO8601, time.gmtime())
|
||||
@@ -69,14 +128,13 @@ class QuerySignatureV2Auth(requests.auth.AuthBase):
|
||||
return req
|
||||
|
||||
def convert_params_to_data(self, req):
|
||||
if (self.params_to_data and req.method.upper() == 'POST' and
|
||||
isinstance(req.params, dict)):
|
||||
if req.method.upper() == 'POST' and isinstance(req.params, dict):
|
||||
# POST with params -> use params as form data instead
|
||||
self.log.debug('converting params to POST data')
|
||||
req.data = req.params
|
||||
req.params = None
|
||||
|
||||
def sign_string(self, to_sign):
|
||||
hmac = self.hmac.copy()
|
||||
hmac.update(to_sign)
|
||||
return base64.b64encode(hmac.digest())
|
||||
req_hmac = hmac.new(self.args['secret_key'], digestmod=hashlib.sha256)
|
||||
req_hmac.update(to_sign)
|
||||
return base64.b64encode(req_hmac.digest())
|
||||
|
||||
@@ -29,6 +29,7 @@ except ImportError:
|
||||
from . import __version__, Arg, MutuallyExclusiveArgList
|
||||
from .config import Config
|
||||
from .logging import configure_root_logger
|
||||
from .util import aggregate_subclass_fields
|
||||
|
||||
class BaseCommand(object):
|
||||
'''
|
||||
@@ -59,99 +60,121 @@ class BaseCommand(object):
|
||||
those of their parent classes.
|
||||
'''
|
||||
|
||||
VERSION = 'requestbuilder ' + __version__
|
||||
|
||||
DESCRIPTION = ''
|
||||
|
||||
ARGS = [Arg('-D', '--debug', action='store_true', route_to=None,
|
||||
help='show debugging output'),
|
||||
Arg('--debugger', action='store_true', route_to=None,
|
||||
help='enable interactive debugger on error')]
|
||||
DEFAULT_ROUTE = None
|
||||
CONFIG_FILES = ['/etc/requestbuilder.ini']
|
||||
|
||||
VERSION = 'requestbuilder ' + __version__
|
||||
|
||||
def __init__(self, _do_cli=False, **kwargs):
|
||||
# Note to programmer: when run() is initializing the first object it
|
||||
# can't catch exceptions that may result from accesses to self.config.
|
||||
# To deal with this, run() disables config file parsing during this
|
||||
# process to expose premature access to self.config during testing.
|
||||
self.args = {} # populated later
|
||||
self.args = kwargs
|
||||
self.config = None # created by _process_configfile
|
||||
self.log = None # created by _configure_logging
|
||||
self._allowed_args = None # created by _build_parser
|
||||
self._arg_routes = {}
|
||||
self._cli_parser = None # created by _build_parser
|
||||
self._config = None
|
||||
|
||||
self._configure_logging()
|
||||
self._process_configfiles()
|
||||
if _do_cli:
|
||||
self._configure_global_logging()
|
||||
|
||||
# We need to enforce arg constraints in one location to make this
|
||||
# framework equally useful for chained commands and those driven
|
||||
# directly from the command line. Thus, we do most of the parsing/
|
||||
# validation work in __init__ as opposed to putting it off until
|
||||
# we hit CLI-specific code.
|
||||
# validation work before __init__ returns as opposed to putting it
|
||||
# off until we hit CLI-specific code.
|
||||
#
|
||||
# Derived classes MUST call this method to ensure things stay sane.
|
||||
self.__do_cli = _do_cli
|
||||
self._post_init()
|
||||
|
||||
def _post_init(self):
|
||||
self._build_parser()
|
||||
if self.__do_cli:
|
||||
# Distribute CLI args to the various places that need them
|
||||
self.process_cli_args()
|
||||
self.configure()
|
||||
|
||||
# Come up with a list of args that the arg parser will allow
|
||||
for key, val in kwargs.iteritems():
|
||||
if key in self._allowed_args:
|
||||
self.args[key] = val
|
||||
else:
|
||||
raise TypeError('__init__() got an unexpected keyword '
|
||||
'argument \'{0}\'; allowed arguments are {1}'
|
||||
.format(key, ', '.join(self._allowed_args)))
|
||||
@property
|
||||
def default_route(self):
|
||||
# This is a property so we can return something that references self.
|
||||
return None
|
||||
|
||||
## TODO: AUTH PARAM PASSING (probably involves the service class)
|
||||
if _do_cli:
|
||||
self._process_cli_args()
|
||||
else:
|
||||
# TODO: enforce arg constraints when not pulling from the CLI
|
||||
pass
|
||||
@property
|
||||
def config_files(self):
|
||||
# This list may need to be computed on the fly.
|
||||
return []
|
||||
|
||||
def _configure_logging(self):
|
||||
# Does not have access to self.config
|
||||
self.log = logging.getLogger(self.name)
|
||||
if self.debug:
|
||||
self.log.setLevel(logging.DEBUG)
|
||||
|
||||
def _process_configfiles(self):
|
||||
self.config = Config(self.config_files, log=self.log)
|
||||
# Now that we have a config file we should check to see if it wants
|
||||
# us to turn on debugging
|
||||
if self.__config_enables_debugging():
|
||||
self.log.setLevel(logging.DEBUG)
|
||||
|
||||
def _configure_global_logging(self):
|
||||
if self.config.get_global_option('debug') in ('color', 'colour'):
|
||||
configure_root_logger(use_color=True)
|
||||
else:
|
||||
configure_root_logger()
|
||||
if self.args.get('debugger'):
|
||||
sys.excepthook = _debugger_except_hook(
|
||||
self.args.get('debugger', False),
|
||||
self.args.get('debug', False))
|
||||
|
||||
def _build_parser(self):
|
||||
# Does not have access to self.config
|
||||
description = '\n\n'.join([textwrap.fill(textwrap.dedent(para))
|
||||
for para in self.DESCRIPTION.split('\n\n')])
|
||||
parser = argparse.ArgumentParser(description=description,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
arg_objs = self.collect_arg_objs()
|
||||
## FIXME: _allowed_args is full of argparse args, but __init__ thinks it's full of strings
|
||||
self._allowed_args = self.populate_parser(parser, arg_objs)
|
||||
self.preprocess_arg_objs(arg_objs)
|
||||
self.populate_parser(parser, arg_objs)
|
||||
parser.add_argument('--version', action='version',
|
||||
version=self.VERSION) # doesn't need routing
|
||||
self._cli_parser = parser
|
||||
|
||||
@classmethod
|
||||
def collect_arg_objs(cls):
|
||||
## TODO: leave notes on how to override this
|
||||
return aggregate_subclass_fields(cls, 'ARGS')
|
||||
def collect_arg_objs(self):
|
||||
return aggregate_subclass_fields(self.__class__, 'ARGS')
|
||||
|
||||
def preprocess_arg_objs(self, arg_objs):
|
||||
pass
|
||||
|
||||
def populate_parser(self, parser, arg_objs):
|
||||
# Returns the args the parser was populated with <-- FIXME (the docs)
|
||||
# Does not have access to self.config
|
||||
args = []
|
||||
for arg_obj in arg_objs:
|
||||
args.extend(self.__add_arg_to_cli_parser(arg_obj, parser))
|
||||
return args
|
||||
self.__add_arg_to_cli_parser(arg_obj, parser)
|
||||
|
||||
def _process_cli_args(self):
|
||||
def process_cli_args(self):
|
||||
'''
|
||||
Process CLI args to fill in missing parts of self.args and enable
|
||||
debugging if necessary.
|
||||
'''
|
||||
# Does not have access to self.config
|
||||
|
||||
cli_args = self._cli_parser.parse_args()
|
||||
for (key, val) in vars(cli_args).iteritems():
|
||||
self.args.setdefault(key, val)
|
||||
for key, val in vars(cli_args).iteritems():
|
||||
# Everything goes in self.args
|
||||
self.args[key] = val
|
||||
|
||||
# If a location to route this to was supplied, put it there, too.
|
||||
route = self._arg_routes[key]
|
||||
if route is not None:
|
||||
if callable(route):
|
||||
# If it's callable, call it to get the actual destination
|
||||
# dict. This is needed to allow Arg objects to refer to
|
||||
# instance attributes from the context of the class.
|
||||
route = route(self)
|
||||
# At this point we had better have a dict.
|
||||
route[key] = val
|
||||
|
||||
def __add_arg_to_cli_parser(self, arglike_obj, parser):
|
||||
# Returns the args the parser was populated with
|
||||
# Does not have access to self.config
|
||||
if isinstance(arglike_obj, Arg):
|
||||
if arglike_obj.kwargs.get('dest') is argparse.SUPPRESS:
|
||||
# Treat it like it doesn't exist at all
|
||||
@@ -159,9 +182,8 @@ class BaseCommand(object):
|
||||
else:
|
||||
arg = parser.add_argument(*arglike_obj.pargs,
|
||||
**arglike_obj.kwargs)
|
||||
route = getattr(arglike_obj, 'route', self.DEFAULT_ROUTE)
|
||||
self._arg_routes.setdefault(route, [])
|
||||
self._arg_routes[route].append(arg.dest)
|
||||
route = getattr(arglike_obj, 'route', self.default_route)
|
||||
self._arg_routes[arg.dest] = route
|
||||
return [arg]
|
||||
elif isinstance(arglike_obj, MutuallyExclusiveArgList):
|
||||
exgroup = parser.add_mutually_exclusive_group(
|
||||
@@ -174,42 +196,29 @@ class BaseCommand(object):
|
||||
raise TypeError('Unknown argument type ' +
|
||||
arglike_obj.__class__.__name__)
|
||||
|
||||
def configure(self):
|
||||
# TODO: Come up with something that can enforce arg constraints based
|
||||
# on the info we can get from self._cli_parser
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def run(cls):
|
||||
BaseCommand.__INHIBIT_CONFIG_PARSING = True
|
||||
## TODO: document: command line entry point
|
||||
cmd = cls(_do_cli=True)
|
||||
BaseCommand.__INHIBIT_CONFIG_PARSING = False
|
||||
try:
|
||||
cmd.configure_global_logging()
|
||||
cmd = cls(_do_cli=True)
|
||||
except Exception as err:
|
||||
print >> sys.stderr, 'error: {0}'.format(err)
|
||||
# Since we don't even have a config file to consult our options for
|
||||
# determining when debugging is on are limited to what we got at
|
||||
# the command line.
|
||||
if any(arg in sys.argv for arg in ('--debug', '-D', '--debugger')):
|
||||
raise
|
||||
sys.exit(1)
|
||||
try:
|
||||
result = cmd.main()
|
||||
cmd.print_result(result)
|
||||
except Exception as err:
|
||||
cmd.handle_cli_exception(err)
|
||||
|
||||
def configure_global_logging(self):
|
||||
if self.config.get_global_option('debug') in ('color', 'colour'):
|
||||
configure_root_logger(use_color=True)
|
||||
else:
|
||||
configure_root_logger()
|
||||
if self.args.get('debugger'):
|
||||
sys.excepthook = _debugger_except_hook(
|
||||
self.args.get('debugger', False),
|
||||
self.args.get('debug', False))
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
if not self._config:
|
||||
if getattr(BaseCommand, '__INHIBIT_CONFIG_PARSING', False):
|
||||
raise AssertionError(
|
||||
'config files may not be parsed during __init__')
|
||||
self._config = Config(self.CONFIG_FILES, log=self.log)
|
||||
# Now that we have a config file we should check to see if it wants
|
||||
# us to turn on debugging
|
||||
if self.__config_enables_debugging():
|
||||
self.log.setLevel(logging.DEBUG)
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
@@ -226,7 +235,7 @@ class BaseCommand(object):
|
||||
|
||||
@property
|
||||
def debug(self):
|
||||
if self._config and self.__config_enables_debugging():
|
||||
if self.__config_enables_debugging():
|
||||
return True
|
||||
if self.args.get('debug') or self.args.get('debugger'):
|
||||
return True
|
||||
@@ -242,21 +251,12 @@ class BaseCommand(object):
|
||||
sys.exit(1)
|
||||
|
||||
def __config_enables_debugging(self):
|
||||
if self._config.get_global_option('debug') in ('color', 'colour'):
|
||||
if self.config is None:
|
||||
return False
|
||||
if self.config.get_global_option('debug') in ('color', 'colour'):
|
||||
# It isn't boolean, but still counts as true.
|
||||
return True
|
||||
return self._config.get_global_option_bool('debug', False)
|
||||
|
||||
|
||||
def aggregate_subclass_fields(cls, field_name):
|
||||
values = []
|
||||
# pylint doesn't know about classes' built-in mro() method
|
||||
# pylint: disable-msg=E1101
|
||||
for m_class in cls.mro():
|
||||
# pylint: enable-msg=E1101
|
||||
if field_name in vars(m_class):
|
||||
values.extend(getattr(m_class, field_name))
|
||||
return values
|
||||
return self.config.get_global_option_bool('debug', False)
|
||||
|
||||
|
||||
def _debugger_except_hook(debugger_enabled, debug_enabled):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright (c) 2012, Eucalyptus Systems, Inc.
|
||||
# Copyright (c) 2012-2013, Eucalyptus Systems, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and/or distribute this software for
|
||||
# any purpose with or without fee is hereby granted, provided that the
|
||||
@@ -21,9 +21,11 @@ class Config(object):
|
||||
self.log = log.getChild('config')
|
||||
else:
|
||||
self.log = _FakeLogger()
|
||||
self.globals = {}
|
||||
self.regions = {}
|
||||
self.users = {}
|
||||
self.globals = {}
|
||||
self.__current_region = None
|
||||
self.__current_user = None
|
||||
self._memo = {}
|
||||
self._parse_config(filenames)
|
||||
|
||||
@@ -58,6 +60,45 @@ class Config(object):
|
||||
self.users[user] = dict(parser.items(section))
|
||||
# Ignore unrecognized sections for forward compatibility
|
||||
|
||||
@property
|
||||
def current_region(self):
|
||||
# This is a property so we can log when it is set.
|
||||
return self.__current_region
|
||||
|
||||
@current_region.setter
|
||||
def current_region(self, val):
|
||||
self.log.debug('current region set to %s', repr(val))
|
||||
self.__current_region = val
|
||||
|
||||
def get_region(self):
|
||||
if self.current_region is not None:
|
||||
return self.current_region
|
||||
if 'default-region' in self.globals:
|
||||
return self.globals['default-region']
|
||||
raise KeyError('no region was chosen')
|
||||
|
||||
@property
|
||||
def current_user(self):
|
||||
# This is a property so we can log when it is set.
|
||||
return self.__current_user
|
||||
|
||||
@current_user.setter
|
||||
def current_user(self, val):
|
||||
self.log.debug('current user set to %s', repr(val))
|
||||
self.__current_user = val
|
||||
|
||||
def get_user(self):
|
||||
if self.current_user is not None:
|
||||
return self.current_user
|
||||
if self.get_region() is not None:
|
||||
# Try to pull it from the current region
|
||||
region_user = self.get_region_option('user')
|
||||
if region_user is not None:
|
||||
return region_user
|
||||
if 'default-user' in self.globals:
|
||||
return self.globals['default-user']
|
||||
raise KeyError('no user was chosen')
|
||||
|
||||
def get_global_option(self, option):
|
||||
return self.globals.get(option)
|
||||
|
||||
@@ -65,51 +106,29 @@ class Config(object):
|
||||
value = self.get_global_option(option)
|
||||
return convert_to_bool(value, default=default)
|
||||
|
||||
def get_user_option(self, regionspec, option):
|
||||
user = None
|
||||
region = None
|
||||
if regionspec:
|
||||
if '@' in regionspec:
|
||||
user, region = regionspec.split('@', 1)
|
||||
else:
|
||||
region = regionspec
|
||||
if not region:
|
||||
region = self.globals.get('default-region')
|
||||
if not user and region:
|
||||
user = self._lookup_recursively(self.regions, region,
|
||||
'default-user')
|
||||
if not user and self.globals.get('default-user'):
|
||||
user = self.globals['default-user']
|
||||
if not user:
|
||||
self.log.debug('no user to find')
|
||||
return None
|
||||
return self._lookup_recursively(self.users, user, option,
|
||||
redact=['secret-key'])
|
||||
def get_user_option(self, option, user=None, redact=False):
|
||||
if user is None:
|
||||
user = self.get_user()
|
||||
return self._lookup_recursively('users', self.users, user, option,
|
||||
redact=redact)
|
||||
|
||||
def get_user_option_bool(self, regionspec, option, default=None):
|
||||
value = self.get_user_option(regionspec, option)
|
||||
def get_user_option_bool(self, option, user=None, default=None):
|
||||
value = self.get_user_option(option, user=user)
|
||||
return convert_to_bool(value, default=default)
|
||||
|
||||
def get_region_option(self, regionspec, option):
|
||||
if regionspec:
|
||||
if '@' in regionspec:
|
||||
region = regionspec.split('@', 1)[1]
|
||||
else:
|
||||
region = regionspec
|
||||
elif self.globals.get('default-region'):
|
||||
region = self.globals['default-region']
|
||||
else:
|
||||
self.log.debug('no region to find')
|
||||
return None
|
||||
return self._lookup_recursively(self.regions, region, option)
|
||||
def get_region_option(self, option, region=None, redact=False):
|
||||
if region is None:
|
||||
region = self.get_region()
|
||||
return self._lookup_recursively('regions', self.regions, region,
|
||||
option, redact=redact)
|
||||
|
||||
def get_region_option_bool(self, regionspec, option, default=None):
|
||||
value = self.get_region_option(regionspec, option)
|
||||
def get_region_option_bool(self, option, region=None, default=None):
|
||||
value = self.get_region_option(option, region=region)
|
||||
return convert_to_bool(value, default=default)
|
||||
|
||||
def _lookup_recursively(self, confdict, section, option, redact=None,
|
||||
cont_reason=None):
|
||||
## TODO: detect loops
|
||||
def _lookup_recursively(self, confdict_name, confdict, section, option,
|
||||
redact=None, cont_reason=None):
|
||||
# TODO: detect loops
|
||||
self._memo.setdefault(id(confdict), {})
|
||||
if (section, option) in self._memo[id(confdict)]:
|
||||
return self._memo[id(confdict)][(section, option)]
|
||||
@@ -119,7 +138,8 @@ class Config(object):
|
||||
|
||||
section_bits = section.split(':')
|
||||
if not cont_reason:
|
||||
self.log.debug('searching for option %s', repr(option))
|
||||
self.log.debug('searching %s for option %s', confdict_name,
|
||||
repr(option))
|
||||
for prd in itertools.product((True, False), repeat=len(section_bits)):
|
||||
prd_section = ':'.join(section_bits[i] if prd[i] else '*'
|
||||
for i in range(len(section_bits)))
|
||||
@@ -143,11 +163,11 @@ class Config(object):
|
||||
new_option = value_chunks[2]
|
||||
else:
|
||||
new_option = option
|
||||
return memoize(self._lookup_recursively(
|
||||
return memoize(self._lookup_recursively(confdict_name,
|
||||
confdict, new_section, new_option,
|
||||
cont_reason='deferred'))
|
||||
# We're done!
|
||||
if redact and option in redact:
|
||||
if redact:
|
||||
print_value = '<redacted>'
|
||||
else:
|
||||
print_value = repr(value)
|
||||
@@ -166,7 +186,7 @@ class Config(object):
|
||||
if count > len(section_bits):
|
||||
matches = c_counts[count]
|
||||
if len(matches) == 1:
|
||||
return memoize(self._lookup_recursively(
|
||||
return memoize(self._lookup_recursively(confdict_name,
|
||||
confdict, matches[0], option,
|
||||
cont_reason=('from ' + repr(section))))
|
||||
elif len(matches) > 1:
|
||||
|
||||
@@ -20,10 +20,11 @@ import platform
|
||||
import sys
|
||||
import textwrap
|
||||
|
||||
from . import __version__, EMPTY, AUTH, PARAMS, SERVICE, SESSION
|
||||
from . import __version__, EMPTY
|
||||
from .command import BaseCommand
|
||||
from .exceptions import ClientError, ServerError
|
||||
from .service import BaseService
|
||||
from .util import aggregate_subclass_fields
|
||||
from .xmlparse import parse_listdelimited_aws_xml
|
||||
|
||||
class BaseRequest(BaseCommand):
|
||||
@@ -52,7 +53,7 @@ class BaseRequest(BaseCommand):
|
||||
- API_VERSION: the API version to send along with the request. This is
|
||||
only necessary to override the service class's API
|
||||
version for a specific request.
|
||||
- ACTION: a string containing the Action query parameter. This
|
||||
- NAME: a string containing the Action query parameter. This
|
||||
defaults to the class's name.
|
||||
- DESCRIPTION: a string describing the tool. This becomes part of the
|
||||
command line help string.
|
||||
@@ -69,59 +70,65 @@ class BaseRequest(BaseCommand):
|
||||
|
||||
SERVICE_CLASS = BaseService
|
||||
API_VERSION = None
|
||||
ACTION = None
|
||||
NAME = None
|
||||
|
||||
FILTERS = []
|
||||
DEFAULT_ROUTE = PARAMS
|
||||
|
||||
LIST_MARKERS = []
|
||||
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
BaseCommand.__init__(self, **kwargs)
|
||||
|
||||
def __init__(self, service=None, **kwargs):
|
||||
self.service = service
|
||||
# Parts of the HTTP request to be sent to the server.
|
||||
# Note that self.serialize_params will update self.params for each
|
||||
# entry in self.args that routes to PARAMS.
|
||||
self.headers = None
|
||||
self.params = None
|
||||
self.headers = {}
|
||||
self.params = {}
|
||||
self.post_data = None
|
||||
self.method = 'GET'
|
||||
|
||||
# HTTP response obtained from the server
|
||||
self.response = None
|
||||
|
||||
self._service = None
|
||||
self.__user_agent = None
|
||||
|
||||
@classmethod
|
||||
def collect_arg_objs(cls):
|
||||
request_args = super(BaseRequest, cls).collect_arg_objs()
|
||||
service_args = cls.SERVICE_CLASS.collect_arg_objs()
|
||||
# Note that the service is likely to include auth args of its own.
|
||||
BaseCommand.__init__(self, **kwargs)
|
||||
|
||||
def _post_init(self):
|
||||
if self.service is None:
|
||||
self.service = self.SERVICE_CLASS(self.config, self.log)
|
||||
BaseCommand._post_init(self)
|
||||
|
||||
@property
|
||||
def default_route(self):
|
||||
return self.params
|
||||
|
||||
def collect_arg_objs(self):
|
||||
request_args = BaseCommand.collect_arg_objs(self)
|
||||
service_args = self.service.collect_arg_objs()
|
||||
# Note that the service is likely to include auth args as well.
|
||||
return request_args + service_args
|
||||
|
||||
def populate_parser(self, parser, arg_objs):
|
||||
# Does not have access to self.config
|
||||
args = BaseCommand.populate_parser(self, parser, arg_objs)
|
||||
if self.FILTERS:
|
||||
args.append(parser.add_argument('--filter', metavar='NAME=VALUE',
|
||||
action='append', dest='filters',
|
||||
help='restrict results to those that meet criteria',
|
||||
type=partial(_parse_filter, filter_objs=self.FILTERS)))
|
||||
parser.epilog = self.__build_filter_help()
|
||||
self._arg_routes.setdefault(None, [])
|
||||
self._arg_routes[None].append('filters')
|
||||
## TODO: service args
|
||||
return args
|
||||
def preprocess_arg_objs(self, arg_objs):
|
||||
self.service.preprocess_arg_objs(arg_objs)
|
||||
|
||||
def _process_cli_args(self):
|
||||
# Does not have access to self.config
|
||||
BaseCommand._process_cli_args(self)
|
||||
def populate_parser(self, parser, arg_objs):
|
||||
BaseCommand.populate_parser(self, parser, arg_objs)
|
||||
if self.FILTERS:
|
||||
parser.add_argument('--filter', metavar='NAME=VALUE',
|
||||
action='append', dest='filters',
|
||||
help='restrict results to those that meet criteria',
|
||||
type=partial(_parse_filter, filter_objs=self.FILTERS))
|
||||
parser.epilog = self.__build_filter_help()
|
||||
self._arg_routes['filters'] = None
|
||||
|
||||
def process_cli_args(self):
|
||||
BaseCommand.process_cli_args(self)
|
||||
if 'filters' in self.args:
|
||||
self.args['Filter'] = _process_filters(self.args.pop('filters'))
|
||||
self._arg_routes.setdefault(self.DEFAULT_ROUTE, [])
|
||||
self._arg_routes[self.DEFAULT_ROUTE].append('Filter')
|
||||
self._arg_routes['Filter'] = self.params
|
||||
|
||||
def configure(self):
|
||||
self.service.configure()
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@@ -129,7 +136,7 @@ class BaseRequest(BaseCommand):
|
||||
The name of this action. Used when choosing what to supply for the
|
||||
Action query parameter.
|
||||
'''
|
||||
return self.ACTION or self.__class__.__name__
|
||||
return self.NAME or self.__class__.__name__
|
||||
|
||||
@property
|
||||
def user_agent(self):
|
||||
@@ -144,22 +151,6 @@ class BaseRequest(BaseCommand):
|
||||
pyver=platform.python_version())
|
||||
return self.__user_agent
|
||||
|
||||
@property
|
||||
def service(self):
|
||||
if self._service is None:
|
||||
service_args = {'auth_args': {},
|
||||
'session_args': {}}
|
||||
for (key, val) in self.args.iteritems():
|
||||
if key in self._arg_routes.get(SERVICE, []):
|
||||
service_args[key] = val
|
||||
elif key in self._arg_routes.get(AUTH, []):
|
||||
service_args['auth_args'][key] = val
|
||||
elif key in self._arg_routes.get(SESSION, []):
|
||||
service_args['session_args'][key] = val
|
||||
self._service = self.SERVICE_CLASS(self.config, self.log,
|
||||
**service_args)
|
||||
return self._service
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
if self.response is not None:
|
||||
@@ -167,7 +158,7 @@ class BaseRequest(BaseCommand):
|
||||
else:
|
||||
return None
|
||||
|
||||
def serialize_params(self, args, route, prefix=None):
|
||||
def serialize_params(self, args, prefix=None):
|
||||
'''
|
||||
Given a possibly-nested dict of args and an arg routing destination,
|
||||
transform each element in the dict that matches the corresponding
|
||||
@@ -207,14 +198,13 @@ class BaseRequest(BaseCommand):
|
||||
elif isinstance(args, dict):
|
||||
for (key, val) in args.iteritems():
|
||||
# Prefix.Key1, Prefix.Key2, ...
|
||||
if key in self._arg_routes.get(route, []) or route is _ALWAYS:
|
||||
if prefix:
|
||||
prefixed_key = prefix + '.' + str(key)
|
||||
else:
|
||||
prefixed_key = str(key)
|
||||
|
||||
if isinstance(val, dict) or isinstance(val, list):
|
||||
flattened.update(self.serialize_params(val, _ALWAYS,
|
||||
flattened.update(self.serialize_params(val,
|
||||
prefixed_key))
|
||||
elif isinstance(val, file):
|
||||
flattened[prefixed_key] = val.read()
|
||||
@@ -231,8 +221,7 @@ class BaseRequest(BaseCommand):
|
||||
prefixed_key = str(i_item)
|
||||
|
||||
if isinstance(item, dict) or isinstance(item, list):
|
||||
flattened.update(self.serialize_params(item, _ALWAYS,
|
||||
prefixed_key))
|
||||
flattened.update(self.serialize_params(item, prefixed_key))
|
||||
elif isinstance(item, file):
|
||||
flattened[prefixed_key] = item.read()
|
||||
elif item or item == 0:
|
||||
@@ -259,8 +248,7 @@ class BaseRequest(BaseCommand):
|
||||
4. If the response's status code does not indicate success, log an
|
||||
error and raise a ServerError.
|
||||
'''
|
||||
params = self.serialize_params(self.args, PARAMS)
|
||||
params.update(self.serialize_params(self.params, _ALWAYS))
|
||||
params = self.serialize_params(self.params)
|
||||
headers = dict(self.headers or {})
|
||||
headers.setdefault('User-Agent', self.user_agent)
|
||||
self.log.info('parameters: %s', params)
|
||||
@@ -470,6 +458,3 @@ class _ReadLoggingFileWrapper(object):
|
||||
chunk = self.fileobj.read(size)
|
||||
self.logger.log(self.level, chunk, extra={'append': True})
|
||||
return chunk
|
||||
|
||||
|
||||
_ALWAYS = type('_ALWAYS', (), {'__repr__': lambda self: '_ALWAYS'})()
|
||||
|
||||
@@ -19,116 +19,113 @@ import requests.exceptions
|
||||
import time
|
||||
import urlparse
|
||||
|
||||
from .auth import QuerySignatureV2Auth
|
||||
from .auth import QuerySigV2Auth
|
||||
from .exceptions import ClientError, ServiceInitError
|
||||
from .util import aggregate_subclass_fields
|
||||
|
||||
class BaseService(object):
|
||||
NAME = ''
|
||||
DESCRIPTION = ''
|
||||
API_VERSION = ''
|
||||
MAX_RETRIES = 4
|
||||
MAX_RETRIES = 4 ## TODO: check the config file
|
||||
|
||||
AUTH_CLASS = QuerySignatureV2Auth
|
||||
ENV_URL = 'AWS_URL' # endpoint URL
|
||||
AUTH_CLASS = None
|
||||
ENV_URL = None
|
||||
|
||||
def __init__(self, config, log, url=None, regionspec=None, auth_args=None,
|
||||
session_args=None):
|
||||
self.log = log
|
||||
# The region name currently only matters for sigv4.
|
||||
## FIXME: It also won't work with every config source yet.
|
||||
## TODO: DOCUMENT: if url contains :: it will be split into
|
||||
## regionspec::endpoint
|
||||
## FIXME: Is the above info true any more?
|
||||
self.config = config
|
||||
self.endpoint_url = None
|
||||
self.regionspec = regionspec ## TODO: rename this
|
||||
self._auth_args = auth_args or {}
|
||||
self._session_args = session_args or {}
|
||||
ARGS = []
|
||||
|
||||
# SSL verification is opt-in
|
||||
self._session_args.setdefault('verify', False)
|
||||
def __init__(self, config, log, **kwargs):
|
||||
self.args = kwargs
|
||||
self.config = config
|
||||
self.endpoint = None
|
||||
self.log = log
|
||||
self.session_args = {'verify': False} # SSL verification is opt-in
|
||||
self._session = None
|
||||
|
||||
# Set self.endpoint_url and self.regionspec from __init__ args
|
||||
self._set_url_vars(url)
|
||||
if self.AUTH_CLASS is not None:
|
||||
self.auth = self.AUTH_CLASS(self)
|
||||
else:
|
||||
self.auth = None
|
||||
|
||||
# Grab info from the command line or service-specific config
|
||||
self.read_config()
|
||||
@property
|
||||
def region_name(self):
|
||||
return self.config.get_region()
|
||||
|
||||
if not self.endpoint_url:
|
||||
def collect_arg_objs(self):
|
||||
service_args = aggregate_subclass_fields(self.__class__, 'ARGS')
|
||||
if self.auth is not None:
|
||||
auth_args = self.auth.collect_arg_objs()
|
||||
else:
|
||||
auth_args = []
|
||||
return service_args + auth_args
|
||||
|
||||
def preprocess_arg_objs(self, arg_objs):
|
||||
if self.auth is not None:
|
||||
self.auth.preprocess_arg_objs(arg_objs)
|
||||
|
||||
def configure(self):
|
||||
# self.args gets highest precedence for self.endpoint and user/region
|
||||
self.process_url(self.args.get('url'))
|
||||
if self.args.get('userregion'):
|
||||
self.process_userregion(self.args['userregion'])
|
||||
# Environment comes next
|
||||
self.process_url(os.getenv(self.ENV_URL))
|
||||
# Finally, try the config file
|
||||
self.process_url(self.config.get_region_option(self.NAME + '-url'))
|
||||
|
||||
# Ensure everything is okay and finish up
|
||||
self.validate_config()
|
||||
if self.auth is not None:
|
||||
self.auth.configure()
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
if self._session is not None:
|
||||
return self._session
|
||||
if requests.__version__ >= '1.0':
|
||||
self._session = requests.session()
|
||||
self._session.auth = self.auth
|
||||
for key, val in self.session_args.iteritems():
|
||||
setattr(self._session, key, val)
|
||||
else:
|
||||
self._session = requests.session(auth=self.auth,
|
||||
**self.session_args)
|
||||
return self._session
|
||||
|
||||
def validate_config(self):
|
||||
if self.endpoint is None:
|
||||
regions = ', '.join(sorted(self.config.regions.keys()))
|
||||
errmsg = 'no endpoint to connect to was given'
|
||||
if regions:
|
||||
errmsg += '. Known regions are '
|
||||
errmsg += ', '.join(sorted(self.config.regions.keys()))
|
||||
errmsg += '. Known regions are ' + regions
|
||||
raise ServiceInitError(errmsg)
|
||||
|
||||
auth = self.AUTH_CLASS(self, **self._auth_args)
|
||||
self.session = requests.session(auth=auth, **self._session_args)
|
||||
def process_url(self, url):
|
||||
if url:
|
||||
if '::' in url:
|
||||
userregion, endpoint = url.split('::', 1)
|
||||
else:
|
||||
endpoint = url
|
||||
userregion = None
|
||||
if self.endpoint is None:
|
||||
self.endpoint = url
|
||||
if userregion:
|
||||
self.process_userregion(userregion)
|
||||
|
||||
@classmethod
|
||||
def collect_arg_objs(cls):
|
||||
## TODO: implement this
|
||||
return []
|
||||
|
||||
def read_config(self):
|
||||
'''
|
||||
Read configuration from the environment, files, and so on and use them
|
||||
to populate self.endpoint_url, self.regionspec, and self._auth_args.
|
||||
|
||||
This method's configuration sources are, in order:
|
||||
- An environment variable with the same name as self.ENV_URL
|
||||
- An AWS credential file, from the path given in the
|
||||
AWS_CREDENTIAL_FILE environment variable
|
||||
- Requestbuilder configuration files, from paths given in
|
||||
self.CONFIG_FILES
|
||||
|
||||
Of these, earlier sources take precedence over later sources.
|
||||
|
||||
Subclasses may override this method to add or rearrange configuration
|
||||
sources.
|
||||
'''
|
||||
# Try the environment first
|
||||
if self.ENV_URL in os.environ:
|
||||
self._set_url_vars(os.getenv(self.ENV_URL, None))
|
||||
# Read config files from their default locations
|
||||
self.read_aws_credential_file()
|
||||
self.read_requestbuilder_config()
|
||||
|
||||
def read_requestbuilder_config(self):
|
||||
self._set_url_vars(self.config.get_region_option(self.regionspec,
|
||||
self.NAME + '-url'))
|
||||
secret_key = self.config.get_user_option(self.regionspec, 'secret-key')
|
||||
if secret_key and not self._auth_args.get('secret_key'):
|
||||
self._auth_args['secret_key'] = secret_key
|
||||
key_id = self.config.get_user_option(self.regionspec, 'key-id')
|
||||
if key_id and not self._auth_args.get('key_id'):
|
||||
self._auth_args['key_id'] = key_id
|
||||
|
||||
if self.config.get_region_option_bool(self.regionspec, 'verify-ssl'):
|
||||
self._session_args['verify'] = True
|
||||
|
||||
def read_aws_credential_file(self):
|
||||
'''
|
||||
If the 'AWS_CREDENTIAL_FILE' environment variable exists, parse that
|
||||
file for access keys and use them if keys were not already supplied to
|
||||
__init__.
|
||||
'''
|
||||
if 'AWS_CREDENTIAL_FILE' in os.environ:
|
||||
path = os.getenv('AWS_CREDENTIAL_FILE')
|
||||
path = os.path.expandvars(path)
|
||||
path = os.path.expanduser(path)
|
||||
with open(path) as credfile:
|
||||
for line in credfile:
|
||||
line = line.split('#', 1)[0]
|
||||
if '=' in line:
|
||||
(key, val) = line.split('=', 1)
|
||||
if (key.strip() == 'AWSAccessKeyId' and
|
||||
not self._auth_args.get('key_id')):
|
||||
self._auth_args['key_id'] = val.strip()
|
||||
elif (key.strip() == 'AWSSecretKey' and
|
||||
not self._auth_args.get('secret_key')):
|
||||
self._auth_args['secret_key'] = val.strip()
|
||||
def process_userregion(self, userregion):
|
||||
if '@' in userregion:
|
||||
user, region = userregion.split('@', 1)
|
||||
else:
|
||||
user = None
|
||||
region = userregion
|
||||
if region and self.config.current_region is None:
|
||||
self.config.current_region = region
|
||||
if user and self.config.current_user is None:
|
||||
self.config.current_user = user
|
||||
|
||||
## TODO: nuke Action; the request should make it a param instead
|
||||
## TODO: the same should probably happen with API versions, but the
|
||||
## request would have to deal with service.API_VERSION, too
|
||||
def make_request(self, action, method='GET', path=None, params=None,
|
||||
headers=None, data=None, api_version=None):
|
||||
params = params or {}
|
||||
@@ -143,12 +140,12 @@ class BaseService(object):
|
||||
if path:
|
||||
# We can't simply use urljoin because a path might start with '/'
|
||||
# like it could for S3 keys that start with that character.
|
||||
if self.endpoint_url.endswith('/'):
|
||||
url = self.endpoint_url + path
|
||||
if self.endpoint.endswith('/'):
|
||||
url = self.endpoint + path
|
||||
else:
|
||||
url = self.endpoint_url + '/' + path
|
||||
url = self.endpoint + '/' + path
|
||||
else:
|
||||
url = self.endpoint_url
|
||||
url = self.endpoint
|
||||
|
||||
## TODO: replace pre_send and post_request hooks for use with requests 1
|
||||
hooks = {'pre_send': _log_request_data(self.log),
|
||||
@@ -165,15 +162,6 @@ class BaseService(object):
|
||||
except requests.exceptions.RequestException as exc:
|
||||
raise ClientError(exc)
|
||||
|
||||
def _set_url_vars(self, url):
|
||||
if url:
|
||||
if '::' in url:
|
||||
regionspec, endpoint_url = url.split('::', 1)
|
||||
else:
|
||||
regionspec = None
|
||||
endpoint_url = url
|
||||
self.regionspec = regionspec or self.regionspec
|
||||
self.endpoint_url = endpoint_url or self.endpoint_url
|
||||
|
||||
class RetryOnStatuses(object):
|
||||
def __init__(self, statuses, max_retries, logger=None):
|
||||
|
||||
23
requestbuilder/util.py
Normal file
23
requestbuilder/util.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# Copyright (c) 2013, Eucalyptus Systems, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and/or distribute this software for
|
||||
# any purpose with or without fee is hereby granted, provided that the
|
||||
# above copyright notice and this permission notice appear in all copies.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
def aggregate_subclass_fields(cls, field_name):
|
||||
values = []
|
||||
# pylint doesn't know about classes' built-in mro() method
|
||||
# pylint: disable-msg=E1101
|
||||
for m_class in cls.mro():
|
||||
# pylint: enable-msg=E1101
|
||||
if field_name in vars(m_class):
|
||||
values.extend(getattr(m_class, field_name))
|
||||
return values
|
||||
Reference in New Issue
Block a user