Apply ruff, ruff-format

We disable the E203 (whitespace before ':') and E501 (line too long)
linter rules since these conflict with ruff-format. We also rework a
statement in 'keystoneauth1/tests/unit/test_session.py' since it's
triggering a bug in flake8 [1] that is currently (at time of authoring)
unresolved.

[1] https://github.com/PyCQA/flake8/issues/1948

Signed-off-by: Stephen Finucane <stephenfin@redhat.com>
Change-Id: Ief5c1c57d1e72db9fc881063d4c7e1030e76da43
This commit is contained in:
Stephen Finucane 2024-08-02 11:33:43 +01:00
parent c05e237a8a
commit 127d7be2b7
137 changed files with 8953 additions and 6346 deletions

View File

@ -35,7 +35,7 @@ class ListAuthPluginsDirective(rst.Directive):
has_content = True has_content = True
def report_load_failure(mgr, ep, err): def report_load_failure(mgr, ep, err):
LOG.warning(u'Failed to load %s: %s' % (ep.module_name, err)) LOG.warning(f'Failed to load {ep.module_name}: {err}')
def display_plugin(self, ext): def display_plugin(self, ext):
overline_style = self.options.get('overline-style', '') overline_style = self.options.get('overline-style', '')
@ -60,7 +60,7 @@ class ListAuthPluginsDirective(rst.Directive):
yield "\n" yield "\n"
for opt in ext.obj.get_options(): for opt in ext.obj.get_options():
yield ":%s: %s" % (opt.name, opt.help) yield f":{opt.name}: {opt.help}"
yield "\n" yield "\n"

View File

@ -81,7 +81,7 @@ latex_documents = [
'Openstack Developers', 'Openstack Developers',
'manual', 'manual',
True, True,
), )
] ]
# Disable usage of xindy https://bugzilla.redhat.com/show_bug.cgi?id=1643664 # Disable usage of xindy https://bugzilla.redhat.com/show_bug.cgi?id=1643664

View File

@ -15,7 +15,7 @@ import threading
import time import time
class FairSemaphore(object): class FairSemaphore:
"""Semaphore class that notifies in order of request. """Semaphore class that notifies in order of request.
We cannot use a normal Semaphore because it doesn't give any ordering, We cannot use a normal Semaphore because it doesn't give any ordering,

View File

@ -78,7 +78,7 @@ def before_utcnow(**timedelta_kwargs):
# Detect if running on the Windows Subsystem for Linux # Detect if running on the Windows Subsystem for Linux
try: try:
with open('/proc/version', 'r') as f: with open('/proc/version') as f:
is_windows_linux_subsystem = 'microsoft' in f.read().lower() is_windows_linux_subsystem = 'microsoft' in f.read().lower()
except IOError: except OSError:
is_windows_linux_subsystem = False is_windows_linux_subsystem = False

View File

@ -13,7 +13,9 @@
from keystoneauth1.access.access import * # noqa from keystoneauth1.access.access import * # noqa
__all__ = ('AccessInfo', # noqa: F405 __all__ = ( # noqa: F405
'AccessInfoV2', # noqa: F405 'AccessInfo',
'AccessInfoV3', # noqa: F405 'AccessInfoV2',
'create') # noqa: F405 'AccessInfoV3',
'create',
)

View File

@ -25,10 +25,7 @@ from keystoneauth1.access import service_providers
STALE_TOKEN_DURATION = 30 STALE_TOKEN_DURATION = 30
__all__ = ('AccessInfo', __all__ = ('AccessInfo', 'AccessInfoV2', 'AccessInfoV3', 'create')
'AccessInfoV2',
'AccessInfoV3',
'create')
def create(resp=None, body=None, auth_token=None): def create(resp=None, body=None, auth_token=None):
@ -47,7 +44,6 @@ def create(resp=None, body=None, auth_token=None):
def _missingproperty(f): def _missingproperty(f):
@functools.wraps(f) @functools.wraps(f)
def inner(self): def inner(self):
try: try:
@ -58,7 +54,7 @@ def _missingproperty(f):
return property(inner) return property(inner)
class AccessInfo(object): class AccessInfo:
"""Encapsulates a raw authentication token from keystone. """Encapsulates a raw authentication token from keystone.
Provides helper methods for extracting useful values from that token. Provides helper methods for extracting useful values from that token.
@ -77,7 +73,8 @@ class AccessInfo(object):
def service_catalog(self): def service_catalog(self):
if not self._service_catalog: if not self._service_catalog:
self._service_catalog = self._service_catalog_class.from_token( self._service_catalog = self._service_catalog_class.from_token(
self._data) self._data
)
return self._service_catalog return self._service_catalog
@ -422,7 +419,7 @@ class AccessInfoV2(AccessInfo):
@_missingproperty @_missingproperty
def auth_token(self): def auth_token(self):
set_token = super(AccessInfoV2, self).auth_token set_token = super().auth_token
return set_token or self._data['access']['token']['id'] return set_token or self._data['access']['token']['id']
@property @property
@ -775,7 +772,8 @@ class AccessInfoV3(AccessInfo):
def service_providers(self): def service_providers(self):
if not self._service_providers: if not self._service_providers:
self._service_providers = ( self._service_providers = (
service_providers.ServiceProviders.from_token(self._data)) service_providers.ServiceProviders.from_token(self._data)
)
return self._service_providers return self._service_providers

View File

@ -114,7 +114,8 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
service.setdefault('id', None) service.setdefault('id', None)
service['endpoints'] = self._normalize_endpoints( service['endpoints'] = self._normalize_endpoints(
service.get('endpoints', [])) service.get('endpoints', [])
)
for endpoint in service['endpoints']: for endpoint in service['endpoints']:
endpoint['region_name'] = self._get_endpoint_region(endpoint) endpoint['region_name'] = self._get_endpoint_region(endpoint)
@ -129,9 +130,15 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
interface = [interface] interface = [interface]
return [self.normalize_interface(i) for i in interface] return [self.normalize_interface(i) for i in interface]
def get_endpoints_data(self, service_type=None, interface=None, def get_endpoints_data(
region_name=None, service_name=None, self,
service_id=None, endpoint_id=None): service_type=None,
interface=None,
region_name=None,
service_name=None,
service_id=None,
endpoint_id=None,
):
"""Fetch and filter endpoint data for the specified service(s). """Fetch and filter endpoint data for the specified service(s).
Returns endpoints for the specified service (or all) containing Returns endpoints for the specified service (or all) containing
@ -164,17 +171,19 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
matching_endpoints = {} matching_endpoints = {}
for service in self.normalize_catalog(): for service in self.normalize_catalog():
if service_type and not discover._SERVICE_TYPES.is_match( if service_type and not discover._SERVICE_TYPES.is_match(
service_type, service['type']): service_type, service['type']
):
continue continue
if (service_name and service['name'] and if (
service_name != service['name']): service_name
and service['name']
and service_name != service['name']
):
continue continue
if (service_id and service['id'] and if service_id and service['id'] and service_id != service['id']:
service_id != service['id']):
continue continue
matching_endpoints.setdefault(service['type'], []) matching_endpoints.setdefault(service['type'], [])
@ -198,7 +207,9 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
interface=endpoint['interface'], interface=endpoint['interface'],
region_name=endpoint['region_name'], region_name=endpoint['region_name'],
endpoint_id=endpoint['id'], endpoint_id=endpoint['id'],
raw_endpoint=endpoint['raw_endpoint'])) raw_endpoint=endpoint['raw_endpoint'],
)
)
if not interfaces: if not interfaces:
return self._endpoints_by_type(service_type, matching_endpoints) return self._endpoints_by_type(service_type, matching_endpoints)
@ -212,8 +223,9 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
for endpoint in endpoints: for endpoint in endpoints:
matches_by_interface.setdefault(endpoint.interface, []) matches_by_interface.setdefault(endpoint.interface, [])
matches_by_interface[endpoint.interface].append(endpoint) matches_by_interface[endpoint.interface].append(endpoint)
best_interface = [i for i in interfaces best_interface = [
if i in matches_by_interface.keys()][0] i for i in interfaces if i in matches_by_interface.keys()
][0]
ret[matched_service_type] = matches_by_interface[best_interface] ret[matched_service_type] = matches_by_interface[best_interface]
return self._endpoints_by_type(service_type, ret) return self._endpoints_by_type(service_type, ret)
@ -279,9 +291,15 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
# part if we do. Raise this so that we can panic in unit tests. # part if we do. Raise this so that we can panic in unit tests.
raise ValueError("Programming error choosing an endpoint.") raise ValueError("Programming error choosing an endpoint.")
def get_endpoints(self, service_type=None, interface=None, def get_endpoints(
region_name=None, service_name=None, self,
service_id=None, endpoint_id=None): service_type=None,
interface=None,
region_name=None,
service_name=None,
service_id=None,
endpoint_id=None,
):
"""Fetch and filter endpoint data for the specified service(s). """Fetch and filter endpoint data for the specified service(s).
Returns endpoints for the specified service (or all) containing Returns endpoints for the specified service (or all) containing
@ -294,17 +312,27 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
Returns a dict keyed by service_type with a list of endpoint dicts Returns a dict keyed by service_type with a list of endpoint dicts
""" """
endpoints_data = self.get_endpoints_data( endpoints_data = self.get_endpoints_data(
service_type=service_type, interface=interface, service_type=service_type,
region_name=region_name, service_name=service_name, interface=interface,
service_id=service_id, endpoint_id=endpoint_id) region_name=region_name,
service_name=service_name,
service_id=service_id,
endpoint_id=endpoint_id,
)
endpoints = {} endpoints = {}
for service_type, data in endpoints_data.items(): for service_type, data in endpoints_data.items():
endpoints[service_type] = self._denormalize_endpoints(data) endpoints[service_type] = self._denormalize_endpoints(data)
return endpoints return endpoints
def get_endpoint_data_list(self, service_type=None, interface='public', def get_endpoint_data_list(
region_name=None, service_name=None, self,
service_id=None, endpoint_id=None): service_type=None,
interface='public',
region_name=None,
service_name=None,
service_id=None,
endpoint_id=None,
):
"""Fetch a flat list of matching EndpointData objects. """Fetch a flat list of matching EndpointData objects.
Fetch the endpoints from the service catalog for a particular Fetch the endpoints from the service catalog for a particular
@ -327,17 +355,25 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
:returns: a list of matching EndpointData objects :returns: a list of matching EndpointData objects
:rtype: list(`keystoneauth1.discover.EndpointData`) :rtype: list(`keystoneauth1.discover.EndpointData`)
""" """
endpoints = self.get_endpoints_data(service_type=service_type, endpoints = self.get_endpoints_data(
service_type=service_type,
interface=interface, interface=interface,
region_name=region_name, region_name=region_name,
service_name=service_name, service_name=service_name,
service_id=service_id, service_id=service_id,
endpoint_id=endpoint_id) endpoint_id=endpoint_id,
)
return [endpoint for data in endpoints.values() for endpoint in data] return [endpoint for data in endpoints.values() for endpoint in data]
def get_urls(self, service_type=None, interface='public', def get_urls(
region_name=None, service_name=None, self,
service_id=None, endpoint_id=None): service_type=None,
interface='public',
region_name=None,
service_name=None,
service_id=None,
endpoint_id=None,
):
"""Fetch endpoint urls from the service catalog. """Fetch endpoint urls from the service catalog.
Fetch the urls of endpoints from the service catalog for a particular Fetch the urls of endpoints from the service catalog for a particular
@ -359,17 +395,25 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
:returns: tuple of urls :returns: tuple of urls
""" """
endpoints = self.get_endpoint_data_list(service_type=service_type, endpoints = self.get_endpoint_data_list(
service_type=service_type,
interface=interface, interface=interface,
region_name=region_name, region_name=region_name,
service_name=service_name, service_name=service_name,
service_id=service_id, service_id=service_id,
endpoint_id=endpoint_id) endpoint_id=endpoint_id,
)
return tuple([endpoint.url for endpoint in endpoints]) return tuple([endpoint.url for endpoint in endpoints])
def url_for(self, service_type=None, interface='public', def url_for(
region_name=None, service_name=None, self,
service_id=None, endpoint_id=None): service_type=None,
interface='public',
region_name=None,
service_name=None,
service_id=None,
endpoint_id=None,
):
"""Fetch an endpoint from the service catalog. """Fetch an endpoint from the service catalog.
Fetch the specified endpoint from the service catalog for Fetch the specified endpoint from the service catalog for
@ -389,16 +433,24 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
:param string service_id: The identifier of a service. :param string service_id: The identifier of a service.
:param string endpoint_id: The identifier of an endpoint. :param string endpoint_id: The identifier of an endpoint.
""" """
return self.endpoint_data_for(service_type=service_type, return self.endpoint_data_for(
service_type=service_type,
interface=interface, interface=interface,
region_name=region_name, region_name=region_name,
service_name=service_name, service_name=service_name,
service_id=service_id, service_id=service_id,
endpoint_id=endpoint_id).url endpoint_id=endpoint_id,
).url
def endpoint_data_for(self, service_type=None, interface='public', def endpoint_data_for(
region_name=None, service_name=None, self,
service_id=None, endpoint_id=None): service_type=None,
interface='public',
region_name=None,
service_name=None,
service_id=None,
endpoint_id=None,
):
"""Fetch endpoint data from the service catalog. """Fetch endpoint data from the service catalog.
Fetch the specified endpoint data from the service catalog for Fetch the specified endpoint data from the service catalog for
@ -427,34 +479,30 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
region_name=region_name, region_name=region_name,
service_name=service_name, service_name=service_name,
service_id=service_id, service_id=service_id,
endpoint_id=endpoint_id) endpoint_id=endpoint_id,
)
if endpoint_data_list: if endpoint_data_list:
return endpoint_data_list[0] return endpoint_data_list[0]
if service_name and region_name: if service_name and region_name:
msg = ('%(interface)s endpoint for %(service_type)s service ' msg = (
'named %(service_name)s in %(region_name)s region not ' f'{interface} endpoint for {service_type} service '
'found' % f'named {service_name} in {region_name} region not '
{'interface': interface, 'found'
'service_type': service_type, 'service_name': service_name, )
'region_name': region_name})
elif service_name: elif service_name:
msg = ('%(interface)s endpoint for %(service_type)s service ' msg = (
'named %(service_name)s not found' % f'{interface} endpoint for {service_type} service '
{'interface': interface, f'named {service_name} not found'
'service_type': service_type, )
'service_name': service_name})
elif region_name: elif region_name:
msg = ('%(interface)s endpoint for %(service_type)s service ' msg = (
'in %(region_name)s region not found' % f'{interface} endpoint for {service_type} service '
{'interface': interface, f'in {region_name} region not found'
'service_type': service_type, 'region_name': region_name}) )
else: else:
msg = ('%(interface)s endpoint for %(service_type)s service ' msg = f'{interface} endpoint for {service_type} service not found'
'not found' %
{'interface': interface,
'service_type': service_type})
raise exceptions.EndpointNotFound(msg) raise exceptions.EndpointNotFound(msg)
@ -498,8 +546,9 @@ class ServiceCatalogV2(ServiceCatalog):
for endpoint in endpoints: for endpoint in endpoints:
raw_endpoint = endpoint.copy() raw_endpoint = endpoint.copy()
interface_urls = {} interface_urls = {}
interface_keys = [key for key in endpoint.keys() interface_keys = [
if key.endswith('URL')] key for key in endpoint.keys() if key.endswith('URL')
]
for key in interface_keys: for key in interface_keys:
interface = self.normalize_interface(key) interface = self.normalize_interface(key)
interface_urls[interface] = endpoint.pop(key) interface_urls[interface] = endpoint.pop(key)
@ -522,8 +571,7 @@ class ServiceCatalogV2(ServiceCatalog):
:returns: List of endpoint description dicts in original catalog format :returns: List of endpoint description dicts in original catalog format
""" """
raw_endpoints = super(ServiceCatalogV2, self)._denormalize_endpoints( raw_endpoints = super()._denormalize_endpoints(endpoints)
endpoints)
# The same raw endpoint content will be in the list once for each # The same raw endpoint content will be in the list once for each
# v2 endpoint_type entry. We only need one of them in the resulting # v2 endpoint_type entry. We only need one of them in the resulting
# list. So keep a list of the string versions. # list. So keep a list of the string versions.

View File

@ -13,22 +13,24 @@
from keystoneauth1 import exceptions from keystoneauth1 import exceptions
class ServiceProviders(object): class ServiceProviders:
"""Helper methods for dealing with Service Providers.""" """Helper methods for dealing with Service Providers."""
@classmethod @classmethod
def from_token(cls, token): def from_token(cls, token):
if 'token' not in token: if 'token' not in token:
raise ValueError('Token format does not support service' raise ValueError(
'providers.') 'Token format does not support service providers.'
)
return cls(token['token'].get('service_providers', [])) return cls(token['token'].get('service_providers', []))
def __init__(self, service_providers): def __init__(self, service_providers):
def normalize(service_providers_list): def normalize(service_providers_list):
return dict((sp['id'], sp) for sp in service_providers_list return {
if 'id' in sp) sp['id']: sp for sp in service_providers_list if 'id' in sp
}
self._service_providers = normalize(service_providers) self._service_providers = normalize(service_providers)
def _get_service_provider(self, sp_id): def _get_service_provider(self, sp_id):

View File

@ -19,7 +19,7 @@ from keystoneauth1 import _fair_semaphore
from keystoneauth1 import session from keystoneauth1 import session
class Adapter(object): class Adapter:
"""An instance of a session with local variables. """An instance of a session with local variables.
A session is a global object that is shared around amongst many clients. It A session is a global object that is shared around amongst many clients. It
@ -118,23 +118,41 @@ class Adapter(object):
client_name = None client_name = None
client_version = None client_version = None
def __init__(self, session, service_type=None, service_name=None, def __init__(
interface=None, region_name=None, endpoint_override=None, self,
version=None, auth=None, user_agent=None, session,
connect_retries=None, logger=None, allow=None, service_type=None,
additional_headers=None, client_name=None, service_name=None,
client_version=None, allow_version_hack=None, interface=None,
region_name=None,
endpoint_override=None,
version=None,
auth=None,
user_agent=None,
connect_retries=None,
logger=None,
allow=None,
additional_headers=None,
client_name=None,
client_version=None,
allow_version_hack=None,
global_request_id=None, global_request_id=None,
min_version=None, max_version=None, min_version=None,
default_microversion=None, status_code_retries=None, max_version=None,
retriable_status_codes=None, raise_exc=None, default_microversion=None,
rate_limit=None, concurrency=None, status_code_retries=None,
connect_retry_delay=None, status_code_retry_delay=None, retriable_status_codes=None,
raise_exc=None,
rate_limit=None,
concurrency=None,
connect_retry_delay=None,
status_code_retry_delay=None,
): ):
if version and (min_version or max_version): if version and (min_version or max_version):
raise TypeError( raise TypeError(
"version is mutually exclusive with min_version and" "version is mutually exclusive with min_version and"
" max_version") " max_version"
)
# NOTE(jamielennox): when adding new parameters to adapter please also # NOTE(jamielennox): when adding new parameters to adapter please also
# add them to the adapter call in httpclient.HTTPClient.__init__ as # add them to the adapter call in httpclient.HTTPClient.__init__ as
# well as to load_adapter_from_argparse below if the argument is # well as to load_adapter_from_argparse below if the argument is
@ -177,7 +195,8 @@ class Adapter(object):
rate_delay = 1.0 / rate_limit rate_delay = 1.0 / rate_limit
self._rate_semaphore = _fair_semaphore.FairSemaphore( self._rate_semaphore = _fair_semaphore.FairSemaphore(
concurrency, rate_delay) concurrency, rate_delay
)
def _set_endpoint_filter_kwargs(self, kwargs): def _set_endpoint_filter_kwargs(self, kwargs):
if self.service_type: if self.service_type:
@ -205,7 +224,8 @@ class Adapter(object):
# case insensitive. # case insensitive.
if kwargs.get('headers'): if kwargs.get('headers'):
kwargs['headers'] = requests.structures.CaseInsensitiveDict( kwargs['headers'] = requests.structures.CaseInsensitiveDict(
kwargs['headers']) kwargs['headers']
)
else: else:
kwargs['headers'] = requests.structures.CaseInsensitiveDict() kwargs['headers'] = requests.structures.CaseInsensitiveDict()
if self.endpoint_override: if self.endpoint_override:
@ -215,9 +235,13 @@ class Adapter(object):
kwargs.setdefault('auth', self.auth) kwargs.setdefault('auth', self.auth)
if self.user_agent: if self.user_agent:
kwargs.setdefault('user_agent', self.user_agent) kwargs.setdefault('user_agent', self.user_agent)
for arg in ('connect_retries', 'status_code_retries', for arg in (
'connect_retry_delay', 'status_code_retry_delay', 'connect_retries',
'retriable_status_codes'): 'status_code_retries',
'connect_retry_delay',
'status_code_retry_delay',
'retriable_status_codes',
):
if getattr(self, arg) is not None: if getattr(self, arg) is not None:
kwargs.setdefault(arg, getattr(self, arg)) kwargs.setdefault(arg, getattr(self, arg))
if self.logger: if self.logger:
@ -239,15 +263,18 @@ class Adapter(object):
kwargs.setdefault('rate_semaphore', self._rate_semaphore) kwargs.setdefault('rate_semaphore', self._rate_semaphore)
else: else:
warnings.warn('Using keystoneclient sessions has been deprecated. ' warnings.warn(
'Please update your software to use keystoneauth1.') 'Using keystoneclient sessions has been deprecated. '
'Please update your software to use keystoneauth1.'
)
for k, v in self.additional_headers.items(): for k, v in self.additional_headers.items():
kwargs.setdefault('headers', {}).setdefault(k, v) kwargs.setdefault('headers', {}).setdefault(k, v)
if self.global_request_id is not None: if self.global_request_id is not None:
kwargs.setdefault('headers', {}).setdefault( kwargs.setdefault('headers', {}).setdefault(
"X-OpenStack-Request-ID", self.global_request_id) "X-OpenStack-Request-ID", self.global_request_id
)
if self.raise_exc is not None: if self.raise_exc is not None:
kwargs.setdefault('raise_exc', self.raise_exc) kwargs.setdefault('raise_exc', self.raise_exc)
@ -309,10 +336,7 @@ class Adapter(object):
return self.session.get_endpoint_data(auth or self.auth, **kwargs) return self.session.get_endpoint_data(auth or self.auth, **kwargs)
def get_all_version_data( def get_all_version_data(self, interface='public', region_name=None):
self,
interface='public',
region_name=None):
"""Get data about all versions of a service. """Get data about all versions of a service.
:param interface: :param interface:
@ -330,7 +354,8 @@ class Adapter(object):
return self.session.get_all_version_data( return self.session.get_all_version_data(
interface=interface, interface=interface,
region_name=region_name, region_name=region_name,
service_type=self.service_type) service_type=self.service_type,
)
def get_api_major_version(self, auth=None, **kwargs): def get_api_major_version(self, auth=None, **kwargs):
"""Get the major API version as provided by the auth plugin. """Get the major API version as provided by the auth plugin.
@ -419,43 +444,50 @@ class Adapter(object):
adapter_group = parser.add_argument_group( adapter_group = parser.add_argument_group(
'Service Options', 'Service Options',
'Options controlling the specialization of the API' 'Options controlling the specialization of the API'
' Connection from information found in the catalog') ' Connection from information found in the catalog',
)
adapter_group.add_argument( adapter_group.add_argument(
'--os-service-type', '--os-service-type',
metavar='<name>', metavar='<name>',
default=os.environ.get('OS_SERVICE_TYPE', service_type), default=os.environ.get('OS_SERVICE_TYPE', service_type),
help='Service type to request from the catalog') help='Service type to request from the catalog',
)
adapter_group.add_argument( adapter_group.add_argument(
'--os-service-name', '--os-service-name',
metavar='<name>', metavar='<name>',
default=os.environ.get('OS_SERVICE_NAME', None), default=os.environ.get('OS_SERVICE_NAME', None),
help='Service name to request from the catalog') help='Service name to request from the catalog',
)
adapter_group.add_argument( adapter_group.add_argument(
'--os-interface', '--os-interface',
metavar='<name>', metavar='<name>',
default=os.environ.get('OS_INTERFACE', 'public'), default=os.environ.get('OS_INTERFACE', 'public'),
help='API Interface to use [public, internal, admin]') help='API Interface to use [public, internal, admin]',
)
adapter_group.add_argument( adapter_group.add_argument(
'--os-region-name', '--os-region-name',
metavar='<name>', metavar='<name>',
default=os.environ.get('OS_REGION_NAME', None), default=os.environ.get('OS_REGION_NAME', None),
help='Region of the cloud to use') help='Region of the cloud to use',
)
adapter_group.add_argument( adapter_group.add_argument(
'--os-endpoint-override', '--os-endpoint-override',
metavar='<name>', metavar='<name>',
default=os.environ.get('OS_ENDPOINT_OVERRIDE', None), default=os.environ.get('OS_ENDPOINT_OVERRIDE', None),
help='Endpoint to use instead of the endpoint in the catalog') help='Endpoint to use instead of the endpoint in the catalog',
)
adapter_group.add_argument( adapter_group.add_argument(
'--os-api-version', '--os-api-version',
metavar='<name>', metavar='<name>',
default=os.environ.get('OS_API_VERSION', None), default=os.environ.get('OS_API_VERSION', None),
help='Which version of the service API to use') help='Which version of the service API to use',
)
# TODO(efried): Move this to loading.adapter.Adapter # TODO(efried): Move this to loading.adapter.Adapter
@classmethod @classmethod
@ -469,66 +501,62 @@ class Adapter(object):
""" """
service_env = service_type.upper().replace('-', '_') service_env = service_type.upper().replace('-', '_')
adapter_group = parser.add_argument_group( adapter_group = parser.add_argument_group(
'{service_type} Service Options'.format( f'{service_type.title()} Service Options',
service_type=service_type.title()), f'Options controlling the specialization of the {service_type.title()}'
'Options controlling the specialization of the {service_type}' ' API Connection from information found in the catalog',
' API Connection from information found in the catalog'.format( )
service_type=service_type.title()))
adapter_group.add_argument( adapter_group.add_argument(
'--os-{service_type}-service-type'.format( f'--os-{service_type}-service-type',
service_type=service_type),
metavar='<name>', metavar='<name>',
default=os.environ.get( default=os.environ.get(f'OS_{service_env}_SERVICE_TYPE', None),
'OS_{service_type}_SERVICE_TYPE'.format( help=(
service_type=service_env), None), 'Service type to request from the catalog for the'
help=('Service type to request from the catalog for the' f' {service_type} service'
' {service_type} service'.format( ),
service_type=service_type))) )
adapter_group.add_argument( adapter_group.add_argument(
'--os-{service_type}-service-name'.format( f'--os-{service_type}-service-name',
service_type=service_type),
metavar='<name>', metavar='<name>',
default=os.environ.get( default=os.environ.get(f'OS_{service_env}_SERVICE_NAME', None),
'OS_{service_type}_SERVICE_NAME'.format( help=(
service_type=service_env), None), 'Service name to request from the catalog for the'
help=('Service name to request from the catalog for the' f' {service_type} service'
' {service_type} service'.format( ),
service_type=service_type))) )
adapter_group.add_argument( adapter_group.add_argument(
'--os-{service_type}-interface'.format( f'--os-{service_type}-interface',
service_type=service_type),
metavar='<name>', metavar='<name>',
default=os.environ.get( default=os.environ.get(f'OS_{service_env}_INTERFACE', None),
'OS_{service_type}_INTERFACE'.format( help=(
service_type=service_env), None), f'API Interface to use for the {service_type} service'
help=('API Interface to use for the {service_type} service' ' [public, internal, admin]'
' [public, internal, admin]'.format( ),
service_type=service_type))) )
adapter_group.add_argument( adapter_group.add_argument(
'--os-{service_type}-api-version'.format( f'--os-{service_type}-api-version',
service_type=service_type),
metavar='<name>', metavar='<name>',
default=os.environ.get( default=os.environ.get(f'OS_{service_env}_API_VERSION', None),
'OS_{service_type}_API_VERSION'.format( help=(
service_type=service_env), None), 'Which version of the service API to use for'
help=('Which version of the service API to use for' f' the {service_type} service'
' the {service_type} service'.format( ),
service_type=service_type))) )
adapter_group.add_argument( adapter_group.add_argument(
'--os-{service_type}-endpoint-override'.format( f'--os-{service_type}-endpoint-override',
service_type=service_type),
metavar='<name>', metavar='<name>',
default=os.environ.get( default=os.environ.get(
'OS_{service_type}_ENDPOINT_OVERRIDE'.format( f'OS_{service_env}_ENDPOINT_OVERRIDE', None
service_type=service_env), None), ),
help=('Endpoint to use for the {service_type} service' help=(
' instead of the endpoint in the catalog'.format( f'Endpoint to use for the {service_type} service'
service_type=service_type))) ' instead of the endpoint in the catalog'
),
)
class LegacyJsonAdapter(Adapter): class LegacyJsonAdapter(Adapter):
@ -549,7 +577,7 @@ class LegacyJsonAdapter(Adapter):
except KeyError: except KeyError:
pass pass
resp = super(LegacyJsonAdapter, self).request(*args, **kwargs) resp = super().request(*args, **kwargs)
try: try:
body = resp.json() body = resp.json()

View File

@ -62,18 +62,13 @@ def get_version_data(session, url, authenticated=None, version_header=None):
The return is a list of dicts of the form:: The return is a list of dicts of the form::
[{ [
{
'status': 'STABLE', 'status': 'STABLE',
'id': 'v2.3', 'id': 'v2.3',
'links': [ 'links': [
{ {'href': 'http://network.example.com/v2.3', 'rel': 'self'},
'href': 'http://network.example.com/v2.3', {'href': 'http://network.example.com/', 'rel': 'collection'},
'rel': 'self',
},
{
'href': 'http://network.example.com/',
'rel': 'collection',
},
], ],
'min_version': '2.0', 'min_version': '2.0',
'max_version': '2.7', 'max_version': '2.7',
@ -117,7 +112,8 @@ def get_version_data(session, url, authenticated=None, version_header=None):
# it's the only thing returning a [] here - and that's ok. # it's the only thing returning a [] here - and that's ok.
if isinstance(body_resp, list): if isinstance(body_resp, list):
raise exceptions.DiscoveryFailure( raise exceptions.DiscoveryFailure(
'Invalid Response - List returned instead of dict') 'Invalid Response - List returned instead of dict'
)
# In the event of querying a root URL we will get back a list of # In the event of querying a root URL we will get back a list of
# available versions. # available versions.
@ -163,8 +159,9 @@ def get_version_data(session, url, authenticated=None, version_header=None):
return [body_resp] return [body_resp]
err_text = resp.text[:50] + '...' if len(resp.text) > 50 else resp.text err_text = resp.text[:50] + '...' if len(resp.text) > 50 else resp.text
raise exceptions.DiscoveryFailure('Invalid Response - Bad version data ' raise exceptions.DiscoveryFailure(
'returned: %s' % err_text) 'Invalid Response - Bad version data ' f'returned: {err_text}'
)
def normalize_version_number(version): def normalize_version_number(version):
@ -253,11 +250,12 @@ def normalize_version_number(version):
except (TypeError, ValueError): except (TypeError, ValueError):
pass pass
raise TypeError('Invalid version specified: %s' % version) raise TypeError(f'Invalid version specified: {version}')
def _normalize_version_args( def _normalize_version_args(
version, min_version, max_version, service_type=None): version, min_version, max_version, service_type=None
):
# The sins of our fathers become the blood on our hands. # The sins of our fathers become the blood on our hands.
# If a user requests an old-style service type such as volumev2, then they # If a user requests an old-style service type such as volumev2, then they
# are inherently requesting the major API version 2. It's not a good # are inherently requesting the major API version 2. It's not a good
@ -270,17 +268,20 @@ def _normalize_version_args(
# as this, but in order to move forward without breaking people, we have # as this, but in order to move forward without breaking people, we have
# to just cry in the corner while striking ourselves with thorned branches. # to just cry in the corner while striking ourselves with thorned branches.
# That said, for sure only do this hack for officially known service_types. # That said, for sure only do this hack for officially known service_types.
if (service_type and if (
_SERVICE_TYPES.is_known(service_type) and service_type
service_type[-1].isdigit() and and _SERVICE_TYPES.is_known(service_type)
service_type[-2] == 'v'): and service_type[-1].isdigit()
and service_type[-2] == 'v'
):
implied_version = normalize_version_number(service_type[-1]) implied_version = normalize_version_number(service_type[-1])
else: else:
implied_version = None implied_version = None
if version and (min_version or max_version): if version and (min_version or max_version):
raise ValueError( raise ValueError(
"version is mutually exclusive with min_version and max_version") "version is mutually exclusive with min_version and max_version"
)
if version: if version:
# Explode this into min_version and max_version # Explode this into min_version and max_version
@ -291,15 +292,16 @@ def _normalize_version_args(
raise exceptions.ImpliedVersionMismatch( raise exceptions.ImpliedVersionMismatch(
service_type=service_type, service_type=service_type,
implied=implied_version, implied=implied_version,
given=version_to_string(version)) given=version_to_string(version),
)
return min_version, max_version return min_version, max_version
if min_version == 'latest': if min_version == 'latest':
if max_version not in (None, 'latest'): if max_version not in (None, 'latest'):
raise ValueError( raise ValueError(
"min_version is 'latest' and max_version is {max_version}" f"min_version is 'latest' and max_version is {max_version}"
" but is only allowed to be 'latest' or None".format( " but is only allowed to be 'latest' or None"
max_version=max_version)) )
max_version = 'latest' max_version = 'latest'
# Normalize e.g. empty string to None # Normalize e.g. empty string to None
@ -326,7 +328,8 @@ def _normalize_version_args(
raise exceptions.ImpliedMinVersionMismatch( raise exceptions.ImpliedMinVersionMismatch(
service_type=service_type, service_type=service_type,
implied=implied_version, implied=implied_version,
given=version_to_string(min_version)) given=version_to_string(min_version),
)
else: else:
min_version = implied_version min_version = implied_version
@ -338,7 +341,8 @@ def _normalize_version_args(
raise exceptions.ImpliedMaxVersionMismatch( raise exceptions.ImpliedMaxVersionMismatch(
service_type=service_type, service_type=service_type,
implied=implied_version, implied=implied_version,
given=version_to_string(max_version)) given=version_to_string(max_version),
)
else: else:
max_version = (implied_version[0], LATEST) max_version = (implied_version[0], LATEST)
return min_version, max_version return min_version, max_version
@ -477,7 +481,8 @@ def _combine_relative_url(discovery_url, version_url):
path, path,
parsed_version_url.params, parsed_version_url.params,
parsed_version_url.query, parsed_version_url.query,
parsed_version_url.fragment).geturl() parsed_version_url.fragment,
).geturl()
def _version_from_url(url): def _version_from_url(url):
@ -499,7 +504,7 @@ def _version_from_url(url):
return None return None
class Status(object): class Status:
CURRENT = 'CURRENT' CURRENT = 'CURRENT'
SUPPORTED = 'SUPPORTED' SUPPORTED = 'SUPPORTED'
DEPRECATED = 'DEPRECATED' DEPRECATED = 'DEPRECATED'
@ -528,19 +533,23 @@ class Status(object):
return status return status
class Discover(object): class Discover:
CURRENT_STATUSES = ('stable', 'current', 'supported') CURRENT_STATUSES = ('stable', 'current', 'supported')
DEPRECATED_STATUSES = ('deprecated',) DEPRECATED_STATUSES = ('deprecated',)
EXPERIMENTAL_STATUSES = ('experimental',) EXPERIMENTAL_STATUSES = ('experimental',)
def __init__(self, session, url, authenticated=None): def __init__(self, session, url, authenticated=None):
self._url = url self._url = url
self._data = get_version_data(session, url, self._data = get_version_data(
authenticated=authenticated) session, url, authenticated=authenticated
)
def raw_version_data(self, allow_experimental=False, def raw_version_data(
allow_deprecated=True, allow_unknown=False): self,
allow_experimental=False,
allow_deprecated=True,
allow_unknown=False,
):
"""Get raw version information from URL. """Get raw version information from URL.
Raw data indicates that only minimal validation processing is performed Raw data indicates that only minimal validation processing is performed
@ -560,8 +569,10 @@ class Discover(object):
try: try:
status = v['status'] status = v['status']
except KeyError: except KeyError:
_LOGGER.warning('Skipping over invalid version data. ' _LOGGER.warning(
'No stability status in version.') 'Skipping over invalid version data. '
'No stability status in version.'
)
continue continue
status = status.lower() status = status.lower()
@ -633,8 +644,10 @@ class Discover(object):
rel = link['rel'] rel = link['rel']
url = _combine_relative_url(self._url, link['href']) url = _combine_relative_url(self._url, link['href'])
except (KeyError, TypeError): except (KeyError, TypeError):
_LOGGER.info('Skipping invalid version link. ' _LOGGER.info(
'Missing link URL or relationship.') 'Skipping invalid version link. '
'Missing link URL or relationship.'
)
continue continue
if rel.lower() == 'self': if rel.lower() == 'self':
@ -642,12 +655,15 @@ class Discover(object):
elif rel.lower() == 'collection': elif rel.lower() == 'collection':
collection_url = url collection_url = url
if not self_url: if not self_url:
_LOGGER.info('Skipping invalid version data. ' _LOGGER.info(
'Missing link to endpoint.') 'Skipping invalid version data. '
'Missing link to endpoint.'
)
continue continue
versions.append( versions.append(
VersionData(version=version_number, VersionData(
version=version_number,
url=self_url, url=self_url,
collection=collection_url, collection=collection_url,
min_microversion=min_microversion, min_microversion=min_microversion,
@ -655,7 +671,9 @@ class Discover(object):
next_min_version=next_min_version, next_min_version=next_min_version,
not_before=not_before, not_before=not_before,
status=Status.normalize(v['status']), status=Status.normalize(v['status']),
raw_status=v['status'])) raw_status=v['status'],
)
)
versions.sort(key=lambda v: v['version'], reverse=reverse) versions.sort(key=lambda v: v['version'], reverse=reverse)
return versions return versions
@ -723,9 +741,9 @@ class Discover(object):
data = self.data_for(version, **kwargs) data = self.data_for(version, **kwargs)
return data['url'] if data else None return data['url'] if data else None
def versioned_data_for(self, url=None, def versioned_data_for(
min_version=None, max_version=None, self, url=None, min_version=None, max_version=None, **kwargs
**kwargs): ):
"""Return endpoint data for the service at a url. """Return endpoint data for the service at a url.
min_version and max_version can be given either as strings or tuples. min_version and max_version can be given either as strings or tuples.
@ -747,15 +765,17 @@ class Discover(object):
:rtype: dict :rtype: dict
""" """
min_version, max_version = _normalize_version_args( min_version, max_version = _normalize_version_args(
None, min_version, max_version) None, min_version, max_version
)
no_version = not max_version and not min_version no_version = not max_version and not min_version
version_data = self.version_data(reverse=True, **kwargs) version_data = self.version_data(reverse=True, **kwargs)
# If we don't have to check a min_version, we can short # If we don't have to check a min_version, we can short
# circuit anything else # circuit anything else
if (max_version == (LATEST, LATEST) and if max_version == (LATEST, LATEST) and (
(not min_version or min_version == (LATEST, LATEST))): not min_version or min_version == (LATEST, LATEST)
):
# because we reverse we can just take the first entry # because we reverse we can just take the first entry
return version_data[0] return version_data[0]
@ -774,8 +794,11 @@ class Discover(object):
if _latest_soft_match(min_version, data['version']): if _latest_soft_match(min_version, data['version']):
return data return data
# Only validate version bounds if versions were specified # Only validate version bounds if versions were specified
if min_version and max_version and version_between( if (
min_version, max_version, data['version']): min_version
and max_version
and version_between(min_version, max_version, data['version'])
):
return data return data
# If there is no version requested and we could not find a matching # If there is no version requested and we could not find a matching
@ -805,8 +828,9 @@ class Discover(object):
:returns: The url for the specified version or None if no match. :returns: The url for the specified version or None if no match.
:rtype: str :rtype: str
""" """
data = self.versioned_data_for(min_version=min_version, data = self.versioned_data_for(
max_version=max_version, **kwargs) min_version=min_version, max_version=max_version, **kwargs
)
return data['url'] if data else None return data['url'] if data else None
@ -823,8 +847,9 @@ class VersionData(dict):
next_min_version=None, next_min_version=None,
not_before=None, not_before=None,
status='CURRENT', status='CURRENT',
raw_status=None): raw_status=None,
super(VersionData, self).__init__() ):
super().__init__()
self['version'] = version self['version'] = version
self['url'] = url self['url'] = url
self['collection'] = collection self['collection'] = collection
@ -883,7 +908,7 @@ class VersionData(dict):
return self.get('raw_status') return self.get('raw_status')
class EndpointData(object): class EndpointData:
"""Normalized information about a discovered endpoint. """Normalized information about a discovered endpoint.
Contains url, version, microversion, interface and region information. Contains url, version, microversion, interface and region information.
@ -894,7 +919,8 @@ class EndpointData(object):
possibilities. possibilities.
""" """
def __init__(self, def __init__(
self,
catalog_url=None, catalog_url=None,
service_url=None, service_url=None,
service_type=None, service_type=None,
@ -910,7 +936,8 @@ class EndpointData(object):
max_microversion=None, max_microversion=None,
next_min_version=None, next_min_version=None,
not_before=None, not_before=None,
status=None): status=None,
):
self.catalog_url = catalog_url self.catalog_url = catalog_url
self.service_url = service_url self.service_url = service_url
self.service_type = service_type self.service_type = service_type
@ -962,19 +989,35 @@ class EndpointData(object):
def __str__(self): def __str__(self):
"""Produce a string like EndpointData{key=val, ...}, for debugging.""" """Produce a string like EndpointData{key=val, ...}, for debugging."""
str_attrs = ( str_attrs = (
'api_version', 'catalog_url', 'endpoint_id', 'interface', 'api_version',
'major_version', 'max_microversion', 'min_microversion', 'catalog_url',
'next_min_version', 'not_before', 'raw_endpoint', 'region_name', 'endpoint_id',
'service_id', 'service_name', 'service_type', 'service_url', 'url') 'interface',
return "%s{%s}" % (self.__class__.__name__, ', '.join( 'major_version',
["%s=%s" % (attr, getattr(self, attr)) for attr in str_attrs])) 'max_microversion',
'min_microversion',
'next_min_version',
'not_before',
'raw_endpoint',
'region_name',
'service_id',
'service_name',
'service_type',
'service_url',
'url',
)
return "{}{{{}}}".format(
self.__class__.__name__,
', '.join([f"{attr}={getattr(self, attr)}" for attr in str_attrs]),
)
@property @property
def url(self): def url(self):
return self.service_url or self.catalog_url return self.service_url or self.catalog_url
def get_current_versioned_data(self, session, allow=None, cache=None, def get_current_versioned_data(
project_id=None): self, session, allow=None, cache=None, project_id=None
):
"""Run version discovery on the current endpoint. """Run version discovery on the current endpoint.
A simplified version of get_versioned_data, get_current_versioned_data A simplified version of get_versioned_data, get_current_versioned_data
@ -1001,16 +1044,29 @@ class EndpointData(object):
could not be discovered. could not be discovered.
""" """
min_version, max_version = _normalize_version_args( min_version, max_version = _normalize_version_args(
self.api_version, None, None) self.api_version, None, None
)
return self.get_versioned_data( return self.get_versioned_data(
session=session, allow=allow, cache=cache, allow_version_hack=True, session=session,
allow=allow,
cache=cache,
allow_version_hack=True,
discover_versions=True, discover_versions=True,
min_version=min_version, max_version=max_version) min_version=min_version,
max_version=max_version,
)
def get_versioned_data(self, session, allow=None, cache=None, def get_versioned_data(
allow_version_hack=True, project_id=None, self,
session,
allow=None,
cache=None,
allow_version_hack=True,
project_id=None,
discover_versions=True, discover_versions=True,
min_version=None, max_version=None): min_version=None,
max_version=None,
):
"""Run version discovery for the service described. """Run version discovery for the service described.
Performs Version Discovery and returns a new EndpointData object with Performs Version Discovery and returns a new EndpointData object with
@ -1050,7 +1106,8 @@ class EndpointData(object):
could not be discovered. could not be discovered.
""" """
min_version, max_version = _normalize_version_args( min_version, max_version = _normalize_version_args(
None, min_version, max_version) None, min_version, max_version
)
if not allow: if not allow:
allow = {} allow = {}
@ -1059,10 +1116,15 @@ class EndpointData(object):
new_data = copy.copy(self) new_data = copy.copy(self)
new_data._set_version_info( new_data._set_version_info(
session=session, allow=allow, cache=cache, session=session,
allow_version_hack=allow_version_hack, project_id=project_id, allow=allow,
discover_versions=discover_versions, min_version=min_version, cache=cache,
max_version=max_version) allow_version_hack=allow_version_hack,
project_id=project_id,
discover_versions=discover_versions,
min_version=min_version,
max_version=max_version,
)
return new_data return new_data
def get_all_version_string_data(self, session, project_id=None): def get_all_version_string_data(self, session, project_id=None):
@ -1082,7 +1144,8 @@ class EndpointData(object):
# Ignore errors here - we're just searching for one of the # Ignore errors here - we're just searching for one of the
# URLs that will give us data. # URLs that will give us data.
_LOGGER.debug( _LOGGER.debug(
"Failed attempt at discovery on %s: %s", vers_url, str(e)) "Failed attempt at discovery on %s: %s", vers_url, str(e)
)
continue continue
for version in d.version_string_data(): for version in d.version_string_data():
versions.append(version) versions.append(version)
@ -1109,10 +1172,17 @@ class EndpointData(object):
return [VersionData(url=url, version=version)] return [VersionData(url=url, version=version)]
def _set_version_info(self, session, allow=None, cache=None, def _set_version_info(
allow_version_hack=True, project_id=None, self,
session,
allow=None,
cache=None,
allow_version_hack=True,
project_id=None,
discover_versions=False, discover_versions=False,
min_version=None, max_version=None): min_version=None,
max_version=None,
):
match_url = None match_url = None
no_version = not max_version and not min_version no_version = not max_version and not min_version
@ -1134,38 +1204,48 @@ class EndpointData(object):
# satisfy the request without further work # satisfy the request without further work
if self._disc: if self._disc:
discovered_data = self._disc.versioned_data_for( discovered_data = self._disc.versioned_data_for(
min_version=min_version, max_version=max_version, min_version=min_version,
url=match_url, **allow) max_version=max_version,
url=match_url,
**allow,
)
if not discovered_data: if not discovered_data:
self._run_discovery( self._run_discovery(
session=session, cache=cache, session=session,
min_version=min_version, max_version=max_version, cache=cache,
project_id=project_id, allow_version_hack=allow_version_hack, min_version=min_version,
discover_versions=discover_versions) max_version=max_version,
project_id=project_id,
allow_version_hack=allow_version_hack,
discover_versions=discover_versions,
)
if not self._disc: if not self._disc:
return return
discovered_data = self._disc.versioned_data_for( discovered_data = self._disc.versioned_data_for(
min_version=min_version, max_version=max_version, min_version=min_version,
url=match_url, **allow) max_version=max_version,
url=match_url,
**allow,
)
if not discovered_data: if not discovered_data:
if min_version and not max_version: if min_version and not max_version:
raise exceptions.DiscoveryFailure( raise exceptions.DiscoveryFailure(
"Minimum version {min_version} was not found".format( f"Minimum version {version_to_string(min_version)} was not found"
min_version=version_to_string(min_version))) )
elif max_version and not min_version: elif max_version and not min_version:
raise exceptions.DiscoveryFailure( raise exceptions.DiscoveryFailure(
"Maximum version {max_version} was not found".format( f"Maximum version {version_to_string(max_version)} was not found"
max_version=version_to_string(max_version))) )
elif min_version and max_version: elif min_version and max_version:
raise exceptions.DiscoveryFailure( raise exceptions.DiscoveryFailure(
"No version found between {min_version}" f"No version found between {version_to_string(min_version)}"
" and {max_version}".format( f" and {version_to_string(max_version)}"
min_version=version_to_string(min_version), )
max_version=version_to_string(max_version)))
else: else:
raise exceptions.DiscoveryFailure( raise exceptions.DiscoveryFailure(
"No version data found remotely at all") "No version data found remotely at all"
)
self.min_microversion = discovered_data['min_microversion'] self.min_microversion = discovered_data['min_microversion']
self.max_microversion = discovered_data['max_microversion'] self.max_microversion = discovered_data['max_microversion']
@ -1184,25 +1264,35 @@ class EndpointData(object):
# for example a "v2" path from http://host/admin should resolve as # for example a "v2" path from http://host/admin should resolve as
# http://host/admin/v2 where it would otherwise be host/v2. # http://host/admin/v2 where it would otherwise be host/v2.
# This has no effect on absolute urls returned from url_for. # This has no effect on absolute urls returned from url_for.
url = urllib.parse.urljoin(self._disc._url.rstrip('/') + '/', url = urllib.parse.urljoin(
discovered_url) self._disc._url.rstrip('/') + '/', discovered_url
)
# If we had to pop a project_id from the catalog_url, put it back on # If we had to pop a project_id from the catalog_url, put it back on
if self._saved_project_id: if self._saved_project_id:
url = urllib.parse.urljoin(url.rstrip('/') + '/', url = urllib.parse.urljoin(
self._saved_project_id) url.rstrip('/') + '/', self._saved_project_id
)
self.service_url = url self.service_url = url
def _run_discovery(self, session, cache, min_version, max_version, def _run_discovery(
project_id, allow_version_hack, discover_versions): self,
session,
cache,
min_version,
max_version,
project_id,
allow_version_hack,
discover_versions,
):
tried = set() tried = set()
for vers_url in self._get_discovery_url_choices( for vers_url in self._get_discovery_url_choices(
project_id=project_id, project_id=project_id,
allow_version_hack=allow_version_hack, allow_version_hack=allow_version_hack,
min_version=min_version, min_version=min_version,
max_version=max_version): max_version=max_version,
):
if self._catalog_matches_exactly and not discover_versions: if self._catalog_matches_exactly and not discover_versions:
# The version we started with is correct, and we don't want # The version we started with is correct, and we don't want
# new data # new data
@ -1214,13 +1304,14 @@ class EndpointData(object):
try: try:
self._disc = get_discovery( self._disc = get_discovery(
session, vers_url, session, vers_url, cache=cache, authenticated=False
cache=cache, )
authenticated=False)
break break
except (exceptions.DiscoveryFailure, except (
exceptions.DiscoveryFailure,
exceptions.HttpError, exceptions.HttpError,
exceptions.ConnectionError) as exc: exceptions.ConnectionError,
) as exc:
_LOGGER.debug('No version document at %s: %s', vers_url, exc) _LOGGER.debug('No version document at %s: %s', vers_url, exc)
continue continue
if not self._disc: if not self._disc:
@ -1242,7 +1333,9 @@ class EndpointData(object):
_LOGGER.warning( _LOGGER.warning(
'Failed to contact the endpoint at %s for ' 'Failed to contact the endpoint at %s for '
'discovery. Fallback to using that endpoint as ' 'discovery. Fallback to using that endpoint as '
'the base url.', self.url) 'the base url.',
self.url,
)
return return
else: else:
@ -1252,14 +1345,20 @@ class EndpointData(object):
# date enough to properly specify a version and keystoneauth # date enough to properly specify a version and keystoneauth
# can't deliver. # can't deliver.
raise exceptions.DiscoveryFailure( raise exceptions.DiscoveryFailure(
"Unable to find a version discovery document at %s, " "Unable to find a version discovery document at {}, "
"the service is unavailable or misconfigured. " "the service is unavailable or misconfigured. "
"Required version range (%s - %s), version hack disabled." "Required version range ({} - {}), version hack disabled.".format(
% (self.url, min_version or "any", max_version or "any")) self.url, min_version or "any", max_version or "any"
)
)
def _get_discovery_url_choices( def _get_discovery_url_choices(
self, project_id=None, allow_version_hack=True, self,
min_version=None, max_version=None): project_id=None,
allow_version_hack=True,
min_version=None,
max_version=None,
):
"""Find potential locations for version discovery URLs. """Find potential locations for version discovery URLs.
min_version and max_version are already normalized, so will either be min_version and max_version are already normalized, so will either be
@ -1295,19 +1394,27 @@ class EndpointData(object):
'/'.join(url_parts), '/'.join(url_parts),
url.params, url.params,
url.query, url.query,
url.fragment).geturl() url.fragment,
).geturl()
except TypeError: except TypeError:
pass pass
else: else:
# `is_between` means version bounds were specified *and* the URL # `is_between` means version bounds were specified *and* the URL
# version is between them. # version is between them.
is_between = min_version and max_version and version_between( is_between = (
min_version, max_version, url_version) min_version
exact_match = (is_between and max_version and and max_version
max_version[0] == url_version[0]) and version_between(min_version, max_version, url_version)
high_match = (is_between and max_version and )
max_version[1] != LATEST and exact_match = (
version_match(max_version, url_version)) is_between and max_version and max_version[0] == url_version[0]
)
high_match = (
is_between
and max_version
and max_version[1] != LATEST
and version_match(max_version, url_version)
)
if exact_match or is_between: if exact_match or is_between:
self._catalog_matches_version = True self._catalog_matches_version = True
self._catalog_matches_exactly = exact_match self._catalog_matches_exactly = exact_match
@ -1316,13 +1423,19 @@ class EndpointData(object):
# return it just yet. It's a good option, but unless we # return it just yet. It's a good option, but unless we
# have an exact match or match the max requested, we want # have an exact match or match the max requested, we want
# to try for an unversioned endpoint first. # to try for an unversioned endpoint first.
catalog_discovery = urllib.parse.ParseResult( catalog_discovery = (
urllib.parse.ParseResult(
url.scheme, url.scheme,
url.netloc, url.netloc,
'/'.join(url_parts), '/'.join(url_parts),
url.params, url.params,
url.query, url.query,
url.fragment).geturl().rstrip('/') + '/' url.fragment,
)
.geturl()
.rstrip('/')
+ '/'
)
# If we found a viable catalog endpoint and it's # If we found a viable catalog endpoint and it's
# an exact match or matches the max, go ahead and give # an exact match or matches the max, go ahead and give
@ -1342,7 +1455,8 @@ class EndpointData(object):
'/'.join(url_parts), '/'.join(url_parts),
url.params, url.params,
url.query, url.query,
url.fragment).geturl() url.fragment,
).geturl()
# Since this is potentially us constructing a base URL from the # Since this is potentially us constructing a base URL from the
# versioned URL - we need to make sure it has a trailing /. But # versioned URL - we need to make sure it has a trailing /. But
# we only want to do that if we have built a new URL - not if # we only want to do that if we have built a new URL - not if
@ -1448,7 +1562,8 @@ def get_discovery(session, url, cache=None, authenticated=False):
'', '',
parsed_url.params, parsed_url.params,
parsed_url.query, parsed_url.query,
parsed_url.fragment).geturl() parsed_url.fragment,
).geturl()
for cache in caches: for cache in caches:
disc = cache.get(url) disc = cache.get(url)
@ -1468,7 +1583,7 @@ def get_discovery(session, url, cache=None, authenticated=False):
return disc return disc
class _VersionHacks(object): class _VersionHacks:
"""A container to abstract the list of version hacks. """A container to abstract the list of version hacks.
This could be done as simply a dictionary but is abstracted like this to This could be done as simply a dictionary but is abstracted like this to

View File

@ -28,5 +28,5 @@ class MissingAuthMethods(base.ClientException):
self.methods = body['receipt']['methods'] self.methods = body['receipt']['methods']
self.required_auth_methods = body['required_auth_methods'] self.required_auth_methods = body['required_auth_methods']
self.expires_at = utils.parse_isotime(body['receipt']['expires_at']) self.expires_at = utils.parse_isotime(body['receipt']['expires_at'])
message = "%s: %s" % (self.message, self.required_auth_methods) message = f"{self.message}: {self.required_auth_methods}"
super(MissingAuthMethods, self).__init__(message) super().__init__(message)

View File

@ -13,12 +13,14 @@
from keystoneauth1.exceptions import base from keystoneauth1.exceptions import base
__all__ = ('AuthPluginException', __all__ = (
'AuthPluginException',
'MissingAuthPlugin', 'MissingAuthPlugin',
'NoMatchingPlugin', 'NoMatchingPlugin',
'UnsupportedParameters', 'UnsupportedParameters',
'OptionError', 'OptionError',
'MissingRequiredOptions') 'MissingRequiredOptions',
)
class AuthPluginException(base.ClientException): class AuthPluginException(base.ClientException):
@ -41,8 +43,8 @@ class NoMatchingPlugin(AuthPluginException):
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
msg = 'The plugin %s could not be found' % name msg = f'The plugin {name} could not be found'
super(NoMatchingPlugin, self).__init__(msg) super().__init__(msg)
class UnsupportedParameters(AuthPluginException): class UnsupportedParameters(AuthPluginException):
@ -59,7 +61,7 @@ class UnsupportedParameters(AuthPluginException):
self.names = names self.names = names
m = 'The following parameters were given that are unsupported: %s' m = 'The following parameters were given that are unsupported: %s'
super(UnsupportedParameters, self).__init__(m % ', '.join(self.names)) super().__init__(m % ', '.join(self.names))
class OptionError(AuthPluginException): class OptionError(AuthPluginException):
@ -90,4 +92,4 @@ class MissingRequiredOptions(OptionError):
names = ", ".join(o.dest for o in options) names = ", ".join(o.dest for o in options)
m = 'Auth plugin requires parameters which were not given: %s' m = 'Auth plugin requires parameters which were not given: %s'
super(MissingRequiredOptions, self).__init__(m % names) super().__init__(m % names)

View File

@ -21,4 +21,4 @@ class ClientException(Exception):
def __init__(self, message=None): def __init__(self, message=None):
self.message = message or self.message self.message = message or self.message
super(ClientException, self).__init__(self.message) super().__init__(self.message)

View File

@ -13,9 +13,7 @@
from keystoneauth1.exceptions import base from keystoneauth1.exceptions import base
__all__ = ('CatalogException', __all__ = ('CatalogException', 'EmptyCatalog', 'EndpointNotFound')
'EmptyCatalog',
'EndpointNotFound')
class CatalogException(base.ClientException): class CatalogException(base.ClientException):

View File

@ -13,12 +13,14 @@
from keystoneauth1.exceptions import base from keystoneauth1.exceptions import base
__all__ = ('ConnectionError', __all__ = (
'ConnectionError',
'ConnectTimeout', 'ConnectTimeout',
'ConnectFailure', 'ConnectFailure',
'SSLError', 'SSLError',
'RetriableConnectionFailure', 'RetriableConnectionFailure',
'UnknownConnectionError') 'UnknownConnectionError',
)
class RetriableConnectionFailure(Exception): class RetriableConnectionFailure(Exception):
@ -47,5 +49,5 @@ class UnknownConnectionError(ConnectionError):
"""An error was encountered but we don't know what it is.""" """An error was encountered but we don't know what it is."""
def __init__(self, msg, original): def __init__(self, msg, original):
super(UnknownConnectionError, self).__init__(msg) super().__init__(msg)
self.original = original self.original = original

View File

@ -17,11 +17,13 @@ from keystoneauth1.exceptions import base
_SERVICE_TYPES = os_service_types.ServiceTypes() _SERVICE_TYPES = os_service_types.ServiceTypes()
__all__ = ('DiscoveryFailure', __all__ = (
'DiscoveryFailure',
'ImpliedVersionMismatch', 'ImpliedVersionMismatch',
'ImpliedMinVersionMismatch', 'ImpliedMinVersionMismatch',
'ImpliedMaxVersionMismatch', 'ImpliedMaxVersionMismatch',
'VersionNotAvailable') 'VersionNotAvailable',
)
class DiscoveryFailure(base.ClientException): class DiscoveryFailure(base.ClientException):
@ -36,17 +38,12 @@ class ImpliedVersionMismatch(ValueError):
label = 'version' label = 'version'
def __init__(self, service_type, implied, given): def __init__(self, service_type, implied, given):
super(ImpliedVersionMismatch, self).__init__( super().__init__(
"service_type {service_type} was given which implies" f"service_type {service_type} was given which implies"
" major API version {implied} but {label} of" f" major API version {str(implied[0])} but {self.label} of"
" {given} was also given. Please update your code" f" {given} was also given. Please update your code"
" to use the official service_type {official_type}.".format( f" to use the official service_type {_SERVICE_TYPES.get_service_type(service_type)}."
service_type=service_type, )
implied=str(implied[0]),
given=given,
label=self.label,
official_type=_SERVICE_TYPES.get_service_type(service_type),
))
class ImpliedMinVersionMismatch(ImpliedVersionMismatch): class ImpliedMinVersionMismatch(ImpliedVersionMismatch):

View File

@ -25,8 +25,8 @@ from keystoneauth1.exceptions import auth
from keystoneauth1.exceptions import base from keystoneauth1.exceptions import base
__all__ = ('HttpError', __all__ = (
'HttpError',
'HTTPClientError', 'HTTPClientError',
'BadRequest', 'BadRequest',
'Unauthorized', 'Unauthorized',
@ -47,7 +47,6 @@ __all__ = ('HttpError',
'RequestedRangeNotSatisfiable', 'RequestedRangeNotSatisfiable',
'ExpectationFailed', 'ExpectationFailed',
'UnprocessableEntity', 'UnprocessableEntity',
'HttpServerError', 'HttpServerError',
'InternalServerError', 'InternalServerError',
'HttpNotImplemented', 'HttpNotImplemented',
@ -55,8 +54,8 @@ __all__ = ('HttpError',
'ServiceUnavailable', 'ServiceUnavailable',
'GatewayTimeout', 'GatewayTimeout',
'HttpVersionNotSupported', 'HttpVersionNotSupported',
'from_response',
'from_response') )
class HttpError(base.ClientException): class HttpError(base.ClientException):
@ -65,10 +64,17 @@ class HttpError(base.ClientException):
http_status = 0 http_status = 0
message = "HTTP Error" message = "HTTP Error"
def __init__(self, message=None, details=None, def __init__(
response=None, request_id=None, self,
url=None, method=None, http_status=None, message=None,
retry_after=0): details=None,
response=None,
request_id=None,
url=None,
method=None,
http_status=None,
retry_after=0,
):
self.http_status = http_status or self.http_status self.http_status = http_status or self.http_status
self.message = message or self.message self.message = message or self.message
self.details = details self.details = details
@ -76,11 +82,11 @@ class HttpError(base.ClientException):
self.response = response self.response = response
self.url = url self.url = url
self.method = method self.method = method
formatted_string = "%s (HTTP %s)" % (self.message, self.http_status) formatted_string = f"{self.message} (HTTP {self.http_status})"
self.retry_after = retry_after self.retry_after = retry_after
if request_id: if request_id:
formatted_string += " (Request-ID: %s)" % request_id formatted_string += f" (Request-ID: {request_id})"
super(HttpError, self).__init__(formatted_string) super().__init__(formatted_string)
class HTTPClientError(HttpError): class HTTPClientError(HttpError):
@ -256,7 +262,7 @@ class RequestEntityTooLarge(HTTPClientError):
except (KeyError, ValueError): except (KeyError, ValueError):
self.retry_after = 0 self.retry_after = 0
super(RequestEntityTooLarge, self).__init__(*args, **kwargs) super().__init__(*args, **kwargs)
class RequestUriTooLong(HTTPClientError): class RequestUriTooLong(HTTPClientError):
@ -377,11 +383,11 @@ class HttpVersionNotSupported(HttpServerError):
# _code_map contains all the classes that have http_status attribute. # _code_map contains all the classes that have http_status attribute.
_code_map = dict( _code_map = {
(getattr(obj, 'http_status', None), obj) getattr(obj, 'http_status', None): obj
for name, obj in vars(sys.modules[__name__]).items() for name, obj in vars(sys.modules[__name__]).items()
if inspect.isclass(obj) and getattr(obj, 'http_status', False) if inspect.isclass(obj) and getattr(obj, 'http_status', False)
) }
def from_response(response, method, url): def from_response(response, method, url):
@ -414,8 +420,9 @@ def from_response(response, method, url):
error = body["error"] error = body["error"]
kwargs["message"] = error.get("message") kwargs["message"] = error.get("message")
kwargs["details"] = error.get("details") kwargs["details"] = error.get("details")
elif (isinstance(body, dict) and elif isinstance(body, dict) and isinstance(
isinstance(body.get("errors"), list)): body.get("errors"), list
):
# if the error response follows the API SIG guidelines, it # if the error response follows the API SIG guidelines, it
# will return a list of errors. in this case, only the first # will return a list of errors. in this case, only the first
# error is shown, but if there are multiple the user will be # error is shown, but if there are multiple the user will be
@ -429,13 +436,15 @@ def from_response(response, method, url):
if len(errors) > 1: if len(errors) > 1:
# if there is more than one error, let the user know # if there is more than one error, let the user know
# that multiple were seen. # that multiple were seen.
msg_hdr = ("Multiple error responses, " msg_hdr = (
"showing first only: ") "Multiple error responses, showing first only: "
)
else: else:
msg_hdr = "" msg_hdr = ""
kwargs["message"] = "{}{}".format(msg_hdr, kwargs["message"] = "{}{}".format(
errors[0].get("title")) msg_hdr, errors[0].get("title")
)
kwargs["details"] = errors[0].get("detail") kwargs["details"] = errors[0].get("detail")
else: else:
kwargs["message"] = "Unrecognized schema in response body." kwargs["message"] = "Unrecognized schema in response body."
@ -444,8 +453,10 @@ def from_response(response, method, url):
kwargs["details"] = response.text kwargs["details"] = response.text
# we check explicity for 401 in case of auth receipts # we check explicity for 401 in case of auth receipts
if (response.status_code == 401 if (
and "Openstack-Auth-Receipt" in response.headers): response.status_code == 401
and "Openstack-Auth-Receipt" in response.headers
):
return auth.MissingAuthMethods(response) return auth.MissingAuthMethods(response)
try: try:

View File

@ -14,18 +14,21 @@
from keystoneauth1.exceptions import auth_plugins from keystoneauth1.exceptions import auth_plugins
__all__ = ( __all__ = (
'InvalidDiscoveryEndpoint', 'InvalidOidcDiscoveryDocument', 'InvalidDiscoveryEndpoint',
'OidcAccessTokenEndpointNotFound', 'OidcAuthorizationEndpointNotFound', 'InvalidOidcDiscoveryDocument',
'OidcGrantTypeMissmatch', 'OidcPluginNotSupported', 'OidcAccessTokenEndpointNotFound',
'OidcAuthorizationEndpointNotFound',
'OidcGrantTypeMissmatch',
'OidcPluginNotSupported',
) )
class InvalidDiscoveryEndpoint(auth_plugins.AuthPluginException): class InvalidDiscoveryEndpoint(auth_plugins.AuthPluginException):
message = "OpenID Connect Discovery Document endpoint not set.""" message = "OpenID Connect Discovery Document endpoint not set."
class InvalidOidcDiscoveryDocument(auth_plugins.AuthPluginException): class InvalidOidcDiscoveryDocument(auth_plugins.AuthPluginException):
message = "OpenID Connect Discovery Document is not valid JSON.""" message = "OpenID Connect Discovery Document is not valid JSON."
class OidcAccessTokenEndpointNotFound(auth_plugins.AuthPluginException): class OidcAccessTokenEndpointNotFound(auth_plugins.AuthPluginException):
@ -37,7 +40,8 @@ class OidcAuthorizationEndpointNotFound(auth_plugins.AuthPluginException):
class OidcDeviceAuthorizationEndpointNotFound( class OidcDeviceAuthorizationEndpointNotFound(
auth_plugins.AuthPluginException): auth_plugins.AuthPluginException
):
message = "OpenID Connect device authorization endpoint not provided." message = "OpenID Connect device authorization endpoint not provided."

View File

@ -21,5 +21,5 @@ class InvalidResponse(base.ClientException):
message = "Invalid response from server." message = "Invalid response from server."
def __init__(self, response): def __init__(self, response):
super(InvalidResponse, self).__init__() super().__init__()
self.response = response self.response = response

View File

@ -20,5 +20,5 @@ class ServiceProviderNotFound(base.ClientException):
def __init__(self, sp_id): def __init__(self, sp_id):
self.sp_id = sp_id self.sp_id = sp_id
msg = 'The Service Provider %(sp)s could not be found' % {'sp': sp_id} msg = f'The Service Provider {sp_id} could not be found'
super(ServiceProviderNotFound, self).__init__(msg) super().__init__(msg)

View File

@ -15,7 +15,6 @@ from keystoneauth1 import loading
class Saml2Password(loading.BaseFederationLoader): class Saml2Password(loading.BaseFederationLoader):
@property @property
def plugin_class(self): def plugin_class(self):
return _saml2.V3Saml2Password return _saml2.V3Saml2Password
@ -25,25 +24,29 @@ class Saml2Password(loading.BaseFederationLoader):
return _saml2._V3_SAML2_AVAILABLE return _saml2._V3_SAML2_AVAILABLE
def get_options(self): def get_options(self):
options = super(Saml2Password, self).get_options() options = super().get_options()
options.extend([ options.extend(
loading.Opt('identity-provider-url', [
loading.Opt(
'identity-provider-url',
required=True, required=True,
help=('An Identity Provider URL, where the SAML2 ' help=(
'authentication request will be sent.')), 'An Identity Provider URL, where the SAML2 '
'authentication request will be sent.'
),
),
loading.Opt('username', help='Username', required=True), loading.Opt('username', help='Username', required=True),
loading.Opt('password', loading.Opt(
secret=True, 'password', secret=True, help='Password', required=True
help='Password', ),
required=True) ]
]) )
return options return options
class ADFSPassword(loading.BaseFederationLoader): class ADFSPassword(loading.BaseFederationLoader):
@property @property
def plugin_class(self): def plugin_class(self):
return _saml2.V3ADFSPassword return _saml2.V3ADFSPassword
@ -53,24 +56,33 @@ class ADFSPassword(loading.BaseFederationLoader):
return _saml2._V3_ADFS_AVAILABLE return _saml2._V3_ADFS_AVAILABLE
def get_options(self): def get_options(self):
options = super(ADFSPassword, self).get_options() options = super().get_options()
options.extend([ options.extend(
loading.Opt('identity-provider-url', [
loading.Opt(
'identity-provider-url',
required=True, required=True,
help=('An Identity Provider URL, where the SAML ' help=(
'authentication request will be sent.')), 'An Identity Provider URL, where the SAML '
loading.Opt('service-provider-endpoint', 'authentication request will be sent.'
),
),
loading.Opt(
'service-provider-endpoint',
required=True, required=True,
help="Service Provider's Endpoint"), help="Service Provider's Endpoint",
loading.Opt('service-provider-entity-id', ),
loading.Opt(
'service-provider-entity-id',
required=True, required=True,
help="Service Provider's SAML Entity ID"), help="Service Provider's SAML Entity ID",
),
loading.Opt('username', help='Username', required=True), loading.Opt('username', help='Username', required=True),
loading.Opt('password', loading.Opt(
secret=True, 'password', secret=True, required=True, help='Password'
required=True, ),
help='Password') ]
]) )
return options return options

View File

@ -35,21 +35,34 @@ class Password(base.BaseSAMLPlugin):
NAMESPACES = { NAMESPACES = {
's': 'http://www.w3.org/2003/05/soap-envelope', 's': 'http://www.w3.org/2003/05/soap-envelope',
'a': 'http://www.w3.org/2005/08/addressing', 'a': 'http://www.w3.org/2005/08/addressing',
'u': ('http://docs.oasis-open.org/wss/2004/01/oasis-200401-' 'u': (
'wss-wssecurity-utility-1.0.xsd') 'http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-utility-1.0.xsd'
),
} }
ADFS_TOKEN_NAMESPACES = { ADFS_TOKEN_NAMESPACES = {
's': 'http://www.w3.org/2003/05/soap-envelope', 's': 'http://www.w3.org/2003/05/soap-envelope',
't': 'http://docs.oasis-open.org/ws-sx/ws-trust/200512' 't': 'http://docs.oasis-open.org/ws-sx/ws-trust/200512',
} }
ADFS_ASSERTION_XPATH = ('/s:Envelope/s:Body' ADFS_ASSERTION_XPATH = (
'/s:Envelope/s:Body'
'/t:RequestSecurityTokenResponseCollection' '/t:RequestSecurityTokenResponseCollection'
'/t:RequestSecurityTokenResponse') '/t:RequestSecurityTokenResponse'
)
def __init__(self, auth_url, identity_provider, identity_provider_url, def __init__(
service_provider_endpoint, username, password, self,
protocol, service_provider_entity_id=None, **kwargs): auth_url,
identity_provider,
identity_provider_url,
service_provider_endpoint,
username,
password,
protocol,
service_provider_entity_id=None,
**kwargs,
):
"""Constructor for ``ADFSPassword``. """Constructor for ``ADFSPassword``.
:param auth_url: URL of the Identity Service :param auth_url: URL of the Identity Service
@ -78,10 +91,15 @@ class Password(base.BaseSAMLPlugin):
:type password: string :type password: string
""" """
super(Password, self).__init__( super().__init__(
auth_url=auth_url, identity_provider=identity_provider, auth_url=auth_url,
identity_provider=identity_provider,
identity_provider_url=identity_provider_url, identity_provider_url=identity_provider_url,
username=username, password=password, protocol=protocol, **kwargs) username=username,
password=password,
protocol=protocol,
**kwargs,
)
self.service_provider_endpoint = service_provider_endpoint self.service_provider_endpoint = service_provider_endpoint
self.service_provider_entity_id = service_provider_entity_id self.service_provider_entity_id = service_provider_entity_id
@ -123,9 +141,11 @@ class Password(base.BaseSAMLPlugin):
""" """
date_created = datetime.datetime.now(datetime.timezone.utc).replace( date_created = datetime.datetime.now(datetime.timezone.utc).replace(
tzinfo=None) tzinfo=None
)
date_expires = date_created + datetime.timedelta( date_expires = date_created + datetime.timedelta(
seconds=self.DEFAULT_ADFS_TOKEN_EXPIRATION) seconds=self.DEFAULT_ADFS_TOKEN_EXPIRATION
)
return [_time.strftime(fmt) for _time in (date_created, date_expires)] return [_time.strftime(fmt) for _time in (date_created, date_expires)]
def _prepare_adfs_request(self): def _prepare_adfs_request(self):
@ -135,132 +155,188 @@ class Password(base.BaseSAMLPlugin):
""" """
WSS_SECURITY_NAMESPACE = { WSS_SECURITY_NAMESPACE = {
'o': ('http://docs.oasis-open.org/wss/2004/01/oasis-200401-' 'o': (
'wss-wssecurity-secext-1.0.xsd') 'http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-secext-1.0.xsd'
)
} }
TRUST_NAMESPACE = { TRUST_NAMESPACE = {
'trust': 'http://docs.oasis-open.org/ws-sx/ws-trust/200512' 'trust': 'http://docs.oasis-open.org/ws-sx/ws-trust/200512'
} }
WSP_NAMESPACE = { WSP_NAMESPACE = {'wsp': 'http://schemas.xmlsoap.org/ws/2004/09/policy'}
'wsp': 'http://schemas.xmlsoap.org/ws/2004/09/policy'
}
WSA_NAMESPACE = { WSA_NAMESPACE = {'wsa': 'http://www.w3.org/2005/08/addressing'}
'wsa': 'http://www.w3.org/2005/08/addressing'
}
root = etree.Element( root = etree.Element(
'{http://www.w3.org/2003/05/soap-envelope}Envelope', '{http://www.w3.org/2003/05/soap-envelope}Envelope',
nsmap=self.NAMESPACES) nsmap=self.NAMESPACES,
)
header = etree.SubElement( header = etree.SubElement(
root, '{http://www.w3.org/2003/05/soap-envelope}Header') root, '{http://www.w3.org/2003/05/soap-envelope}Header'
)
action = etree.SubElement( action = etree.SubElement(
header, "{http://www.w3.org/2005/08/addressing}Action") header, "{http://www.w3.org/2005/08/addressing}Action"
)
action.set( action.set(
"{http://www.w3.org/2003/05/soap-envelope}mustUnderstand", "1") "{http://www.w3.org/2003/05/soap-envelope}mustUnderstand", "1"
action.text = ('http://docs.oasis-open.org/ws-sx/ws-trust/200512' )
'/RST/Issue') action.text = (
'http://docs.oasis-open.org/ws-sx/ws-trust/200512/RST/Issue'
)
messageID = etree.SubElement( messageID = etree.SubElement(
header, '{http://www.w3.org/2005/08/addressing}MessageID') header, '{http://www.w3.org/2005/08/addressing}MessageID'
)
messageID.text = 'urn:uuid:' + uuid.uuid4().hex messageID.text = 'urn:uuid:' + uuid.uuid4().hex
replyID = etree.SubElement( replyID = etree.SubElement(
header, '{http://www.w3.org/2005/08/addressing}ReplyTo') header, '{http://www.w3.org/2005/08/addressing}ReplyTo'
)
address = etree.SubElement( address = etree.SubElement(
replyID, '{http://www.w3.org/2005/08/addressing}Address') replyID, '{http://www.w3.org/2005/08/addressing}Address'
)
address.text = 'http://www.w3.org/2005/08/addressing/anonymous' address.text = 'http://www.w3.org/2005/08/addressing/anonymous'
to = etree.SubElement( to = etree.SubElement(
header, '{http://www.w3.org/2005/08/addressing}To') header, '{http://www.w3.org/2005/08/addressing}To'
)
to.set("{http://www.w3.org/2003/05/soap-envelope}mustUnderstand", "1") to.set("{http://www.w3.org/2003/05/soap-envelope}mustUnderstand", "1")
security = etree.SubElement( security = etree.SubElement(
header, '{http://docs.oasis-open.org/wss/2004/01/oasis-200401-' header,
'{http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-secext-1.0.xsd}Security', 'wss-wssecurity-secext-1.0.xsd}Security',
nsmap=WSS_SECURITY_NAMESPACE) nsmap=WSS_SECURITY_NAMESPACE,
)
security.set( security.set(
"{http://www.w3.org/2003/05/soap-envelope}mustUnderstand", "1") "{http://www.w3.org/2003/05/soap-envelope}mustUnderstand", "1"
)
timestamp = etree.SubElement( timestamp = etree.SubElement(
security, ('{http://docs.oasis-open.org/wss/2004/01/oasis-200401-' security,
'wss-wssecurity-utility-1.0.xsd}Timestamp')) (
'{http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-utility-1.0.xsd}Timestamp'
),
)
timestamp.set( timestamp.set(
('{http://docs.oasis-open.org/wss/2004/01/oasis-200401-' (
'wss-wssecurity-utility-1.0.xsd}Id'), '_0') '{http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-utility-1.0.xsd}Id'
),
'_0',
)
created = etree.SubElement( created = etree.SubElement(
timestamp, ('{http://docs.oasis-open.org/wss/2004/01/oasis-200401-' timestamp,
'wss-wssecurity-utility-1.0.xsd}Created')) (
'{http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-utility-1.0.xsd}Created'
),
)
expires = etree.SubElement( expires = etree.SubElement(
timestamp, ('{http://docs.oasis-open.org/wss/2004/01/oasis-200401-' timestamp,
'wss-wssecurity-utility-1.0.xsd}Expires')) (
'{http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-utility-1.0.xsd}Expires'
),
)
created.text, expires.text = self._token_dates() created.text, expires.text = self._token_dates()
usernametoken = etree.SubElement( usernametoken = etree.SubElement(
security, '{http://docs.oasis-open.org/wss/2004/01/oasis-200401-' security,
'wss-wssecurity-secext-1.0.xsd}UsernameToken') '{http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-secext-1.0.xsd}UsernameToken',
)
usernametoken.set( usernametoken.set(
('{http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-' (
'wssecurity-utility-1.0.xsd}u'), "uuid-%s-1" % uuid.uuid4().hex) '{http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-'
'wssecurity-utility-1.0.xsd}u'
),
f"uuid-{uuid.uuid4().hex}-1",
)
username = etree.SubElement( username = etree.SubElement(
usernametoken, ('{http://docs.oasis-open.org/wss/2004/01/oasis-' usernametoken,
'200401-wss-wssecurity-secext-1.0.xsd}Username')) (
'{http://docs.oasis-open.org/wss/2004/01/oasis-'
'200401-wss-wssecurity-secext-1.0.xsd}Username'
),
)
password = etree.SubElement( password = etree.SubElement(
usernametoken, ('{http://docs.oasis-open.org/wss/2004/01/oasis-' usernametoken,
'200401-wss-wssecurity-secext-1.0.xsd}Password'), (
Type=('http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-' '{http://docs.oasis-open.org/wss/2004/01/oasis-'
'username-token-profile-1.0#PasswordText')) '200401-wss-wssecurity-secext-1.0.xsd}Password'
),
Type=(
'http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-'
'username-token-profile-1.0#PasswordText'
),
)
body = etree.SubElement( body = etree.SubElement(
root, "{http://www.w3.org/2003/05/soap-envelope}Body") root, "{http://www.w3.org/2003/05/soap-envelope}Body"
)
request_security_token = etree.SubElement( request_security_token = etree.SubElement(
body, ('{http://docs.oasis-open.org/ws-sx/ws-trust/200512}' body,
'RequestSecurityToken'), nsmap=TRUST_NAMESPACE) (
'{http://docs.oasis-open.org/ws-sx/ws-trust/200512}'
'RequestSecurityToken'
),
nsmap=TRUST_NAMESPACE,
)
applies_to = etree.SubElement( applies_to = etree.SubElement(
request_security_token, request_security_token,
'{http://schemas.xmlsoap.org/ws/2004/09/policy}AppliesTo', '{http://schemas.xmlsoap.org/ws/2004/09/policy}AppliesTo',
nsmap=WSP_NAMESPACE) nsmap=WSP_NAMESPACE,
)
endpoint_reference = etree.SubElement( endpoint_reference = etree.SubElement(
applies_to, applies_to,
'{http://www.w3.org/2005/08/addressing}EndpointReference', '{http://www.w3.org/2005/08/addressing}EndpointReference',
nsmap=WSA_NAMESPACE) nsmap=WSA_NAMESPACE,
)
wsa_address = etree.SubElement( wsa_address = etree.SubElement(
endpoint_reference, endpoint_reference, '{http://www.w3.org/2005/08/addressing}Address'
'{http://www.w3.org/2005/08/addressing}Address') )
keytype = etree.SubElement( keytype = etree.SubElement(
request_security_token, request_security_token,
'{http://docs.oasis-open.org/ws-sx/ws-trust/200512}KeyType') '{http://docs.oasis-open.org/ws-sx/ws-trust/200512}KeyType',
keytype.text = ('http://docs.oasis-open.org/ws-sx/' )
'ws-trust/200512/Bearer') keytype.text = (
'http://docs.oasis-open.org/ws-sx/ws-trust/200512/Bearer'
)
request_type = etree.SubElement( request_type = etree.SubElement(
request_security_token, request_security_token,
'{http://docs.oasis-open.org/ws-sx/ws-trust/200512}RequestType') '{http://docs.oasis-open.org/ws-sx/ws-trust/200512}RequestType',
request_type.text = ('http://docs.oasis-open.org/ws-sx/' )
'ws-trust/200512/Issue') request_type.text = (
'http://docs.oasis-open.org/ws-sx/ws-trust/200512/Issue'
)
token_type = etree.SubElement( token_type = etree.SubElement(
request_security_token, request_security_token,
'{http://docs.oasis-open.org/ws-sx/ws-trust/200512}TokenType') '{http://docs.oasis-open.org/ws-sx/ws-trust/200512}TokenType',
)
token_type.text = 'urn:oasis:names:tc:SAML:1.0:assertion' token_type.text = 'urn:oasis:names:tc:SAML:1.0:assertion'
# After constructing the request, let's plug in some values # After constructing the request, let's plug in some values
username.text = self.username username.text = self.username
password.text = self.password password.text = self.password
to.text = self.identity_provider_url to.text = self.identity_provider_url
wsa_address.text = (self.service_provider_entity_id or wsa_address.text = (
self.service_provider_endpoint) self.service_provider_entity_id or self.service_provider_endpoint
)
self.prepared_request = root self.prepared_request = root
@ -289,12 +365,14 @@ class Password(base.BaseSAMLPlugin):
recognized. recognized.
""" """
def _get_failure(e): def _get_failure(e):
xpath = '/s:Envelope/s:Body/s:Fault/s:Code/s:Subcode/s:Value' xpath = '/s:Envelope/s:Body/s:Fault/s:Code/s:Subcode/s:Value'
content = e.response.content content = e.response.content
try: try:
obj = self.str_to_xml(content).xpath( obj = self.str_to_xml(content).xpath(
xpath, namespaces=self.NAMESPACES) xpath, namespaces=self.NAMESPACES
)
obj = self._first(obj) obj = self._first(obj)
return obj.text return obj.text
# NOTE(marek-denis): etree.Element.xpath() doesn't raise an # NOTE(marek-denis): etree.Element.xpath() doesn't raise an
@ -309,13 +387,18 @@ class Password(base.BaseSAMLPlugin):
request_security_token = self.xml_to_str(self.prepared_request) request_security_token = self.xml_to_str(self.prepared_request)
try: try:
response = session.post( response = session.post(
url=self.identity_provider_url, headers=self.HEADER_SOAP, url=self.identity_provider_url,
data=request_security_token, authenticated=False) headers=self.HEADER_SOAP,
data=request_security_token,
authenticated=False,
)
except exceptions.InternalServerError as e: except exceptions.InternalServerError as e:
reason = _get_failure(e) reason = _get_failure(e)
raise exceptions.AuthorizationFailure(reason) raise exceptions.AuthorizationFailure(reason)
msg = ('Error parsing XML returned from ' msg = (
'the ADFS Identity Provider, reason: %s') 'Error parsing XML returned from '
'the ADFS Identity Provider, reason: %s'
)
self.adfs_token = self.str_to_xml(response.content, msg) self.adfs_token = self.str_to_xml(response.content, msg)
def _prepare_sp_request(self): def _prepare_sp_request(self):
@ -329,7 +412,8 @@ class Password(base.BaseSAMLPlugin):
""" """
assertion = self.adfs_token.xpath( assertion = self.adfs_token.xpath(
self.ADFS_ASSERTION_XPATH, namespaces=self.ADFS_TOKEN_NAMESPACES) self.ADFS_ASSERTION_XPATH, namespaces=self.ADFS_TOKEN_NAMESPACES
)
assertion = self._first(assertion) assertion = self._first(assertion)
assertion = self.xml_to_str(assertion) assertion = self.xml_to_str(assertion)
# TODO(marek-denis): Ideally no string replacement should occur. # TODO(marek-denis): Ideally no string replacement should occur.
@ -338,7 +422,8 @@ class Password(base.BaseSAMLPlugin):
# from scratch and reuse values from the adfs security token. # from scratch and reuse values from the adfs security token.
assertion = assertion.replace( assertion = assertion.replace(
b'http://docs.oasis-open.org/ws-sx/ws-trust/200512', b'http://docs.oasis-open.org/ws-sx/ws-trust/200512',
b'http://schemas.xmlsoap.org/ws/2005/02/trust') b'http://schemas.xmlsoap.org/ws/2005/02/trust',
)
encoded_assertion = urllib.parse.quote(assertion) encoded_assertion = urllib.parse.quote(assertion)
self.encoded_assertion = 'wa=wsignin1.0&wresult=' + encoded_assertion self.encoded_assertion = 'wa=wsignin1.0&wresult=' + encoded_assertion
@ -358,8 +443,12 @@ class Password(base.BaseSAMLPlugin):
""" """
session.post( session.post(
url=self.service_provider_endpoint, data=self.encoded_assertion, url=self.service_provider_endpoint,
headers=self.HEADER_X_FORM, redirect=False, authenticated=False) data=self.encoded_assertion,
headers=self.HEADER_X_FORM,
redirect=False,
authenticated=False,
)
def _access_service_provider(self, session): def _access_service_provider(self, session):
"""Access protected endpoint and fetch unscoped token. """Access protected endpoint and fetch unscoped token.
@ -382,9 +471,11 @@ class Password(base.BaseSAMLPlugin):
if self._cookies(session) is False: if self._cookies(session) is False:
raise exceptions.AuthorizationFailure( raise exceptions.AuthorizationFailure(
"Session object doesn't contain a cookie, therefore you are " "Session object doesn't contain a cookie, therefore you are "
"not allowed to enter the Identity Provider's protected area.") "not allowed to enter the Identity Provider's protected area."
self.authenticated_response = session.get(self.federated_token_url, )
authenticated=False) self.authenticated_response = session.get(
self.federated_token_url, authenticated=False
)
def get_unscoped_auth_ref(self, session, *kwargs): def get_unscoped_auth_ref(self, session, *kwargs):
"""Retrieve unscoped token after authentcation with ADFS server. """Retrieve unscoped token after authentcation with ADFS server.

View File

@ -23,21 +23,27 @@ class _Saml2TokenAuthMethod(v3.AuthMethod):
_method_parameters = [] _method_parameters = []
def get_auth_data(self, session, auth, headers, **kwargs): def get_auth_data(self, session, auth, headers, **kwargs):
raise exceptions.MethodNotImplemented('This method should never ' raise exceptions.MethodNotImplemented(
'be called') 'This method should never be called'
)
class BaseSAMLPlugin(v3.FederationBaseAuth): class BaseSAMLPlugin(v3.FederationBaseAuth):
HTTP_MOVED_TEMPORARILY = 302 HTTP_MOVED_TEMPORARILY = 302
HTTP_SEE_OTHER = 303 HTTP_SEE_OTHER = 303
_auth_method_class = _Saml2TokenAuthMethod _auth_method_class = _Saml2TokenAuthMethod
def __init__(self, auth_url, def __init__(
identity_provider, identity_provider_url, self,
username, password, protocol, auth_url,
**kwargs): identity_provider,
identity_provider_url,
username,
password,
protocol,
**kwargs,
):
"""Class constructor accepting following parameters. """Class constructor accepting following parameters.
:param auth_url: URL of the Identity Service :param auth_url: URL of the Identity Service
@ -68,10 +74,12 @@ class BaseSAMLPlugin(v3.FederationBaseAuth):
:type protocol: string :type protocol: string
""" """
super(BaseSAMLPlugin, self).__init__( super().__init__(
auth_url=auth_url, identity_provider=identity_provider, auth_url=auth_url,
identity_provider=identity_provider,
protocol=protocol, protocol=protocol,
**kwargs) **kwargs,
)
self.identity_provider_url = identity_provider_url self.identity_provider_url = identity_provider_url
self.username = username self.username = username
self.password = password self.password = password

View File

@ -28,7 +28,7 @@ _PAOS_NAMESPACE = 'urn:liberty:paos:2003-08'
_ECP_NAMESPACE = 'urn:oasis:names:tc:SAML:2.0:profiles:SSO:ecp' _ECP_NAMESPACE = 'urn:oasis:names:tc:SAML:2.0:profiles:SSO:ecp'
_PAOS_HEADER = 'application/vnd.paos+xml' _PAOS_HEADER = 'application/vnd.paos+xml'
_PAOS_VER = 'ver="%s";"%s"' % (_PAOS_NAMESPACE, _ECP_NAMESPACE) _PAOS_VER = f'ver="{_PAOS_NAMESPACE}";"{_ECP_NAMESPACE}"'
_XML_NAMESPACES = { _XML_NAMESPACES = {
'ecp': _ECP_NAMESPACE, 'ecp': _ECP_NAMESPACE,
@ -72,14 +72,14 @@ def _response_xml(response, name):
try: try:
return etree.XML(response.content) return etree.XML(response.content)
except etree.XMLSyntaxError as e: except etree.XMLSyntaxError as e:
msg = 'SAML2: Error parsing XML returned from %s: %s' % (name, e) msg = f'SAML2: Error parsing XML returned from {name}: {e}'
raise InvalidResponse(msg) raise InvalidResponse(msg)
def _str_from_xml(xml, path): def _str_from_xml(xml, path):
li = xml.xpath(path, namespaces=_XML_NAMESPACES) li = xml.xpath(path, namespaces=_XML_NAMESPACES)
if len(li) != 1: if len(li) != 1:
raise IndexError('%s should provide a single element list' % path) raise IndexError(f'{path} should provide a single element list')
return li[0] return li[0]
@ -115,7 +115,7 @@ class _SamlAuth(requests.auth.AuthBase):
""" """
def __init__(self, identity_provider_url, requests_auth): def __init__(self, identity_provider_url, requests_auth):
super(_SamlAuth, self).__init__() super().__init__()
self.identity_provider_url = identity_provider_url self.identity_provider_url = identity_provider_url
self.requests_auth = requests_auth self.requests_auth = requests_auth
@ -132,8 +132,10 @@ class _SamlAuth(requests.auth.AuthBase):
return request return request
def _handle_response(self, response, **kwargs): def _handle_response(self, response, **kwargs):
if (response.status_code == 200 and if (
response.headers.get('Content-Type') == _PAOS_HEADER): response.status_code == 200
and response.headers.get('Content-Type') == _PAOS_HEADER
):
response = self._ecp_retry(response, **kwargs) response = self._ecp_retry(response, **kwargs)
return response return response
@ -151,33 +153,40 @@ class _SamlAuth(requests.auth.AuthBase):
authn_request.remove(authn_request[0]) authn_request.remove(authn_request[0])
idp_response = send('POST', idp_response = send(
'POST',
self.identity_provider_url, self.identity_provider_url,
headers={'Content-type': 'text/xml'}, headers={'Content-type': 'text/xml'},
data=etree.tostring(authn_request), data=etree.tostring(authn_request),
auth=self.requests_auth) auth=self.requests_auth,
)
history.append(idp_response) history.append(idp_response)
authn_response = _response_xml(idp_response, 'Identity Provider') authn_response = _response_xml(idp_response, 'Identity Provider')
idp_consumer_url = _str_from_xml(authn_response, idp_consumer_url = _str_from_xml(
_XPATH_IDP_CONSUMER_URL) authn_response, _XPATH_IDP_CONSUMER_URL
)
if sp_consumer_url != idp_consumer_url: if sp_consumer_url != idp_consumer_url:
# send fault message to the SP, discard the response # send fault message to the SP, discard the response
send('POST', send(
'POST',
sp_consumer_url, sp_consumer_url,
data=_SOAP_FAULT, data=_SOAP_FAULT,
headers={'Content-Type': _PAOS_HEADER}) headers={'Content-Type': _PAOS_HEADER},
)
# prepare error message and raise an exception. # prepare error message and raise an exception.
msg = ('Consumer URLs from Service Provider %(service_provider)s ' msg = (
'Consumer URLs from Service Provider %(service_provider)s '
'%(sp_consumer_url)s and Identity Provider ' '%(sp_consumer_url)s and Identity Provider '
'%(identity_provider)s %(idp_consumer_url)s are not equal') '%(identity_provider)s %(idp_consumer_url)s are not equal'
)
msg = msg % { msg = msg % {
'service_provider': sp_response.request.url, 'service_provider': sp_response.request.url,
'sp_consumer_url': sp_consumer_url, 'sp_consumer_url': sp_consumer_url,
'identity_provider': self.identity_provider_url, 'identity_provider': self.identity_provider_url,
'idp_consumer_url': idp_consumer_url 'idp_consumer_url': idp_consumer_url,
} }
raise ConsumerMismatch(msg) raise ConsumerMismatch(msg)
@ -186,19 +195,22 @@ class _SamlAuth(requests.auth.AuthBase):
# idp_consumer_url is the URL on the SP that handles the ECP body # idp_consumer_url is the URL on the SP that handles the ECP body
# returned and creates an authenticated session. # returned and creates an authenticated session.
final_resp = send('POST', final_resp = send(
'POST',
idp_consumer_url, idp_consumer_url,
headers={'Content-Type': _PAOS_HEADER}, headers={'Content-Type': _PAOS_HEADER},
cookies=idp_response.cookies, cookies=idp_response.cookies,
data=etree.tostring(authn_response)) data=etree.tostring(authn_response),
)
history.append(final_resp) history.append(final_resp)
# the SP should then redirect us back to the original URL to retry the # the SP should then redirect us back to the original URL to retry the
# original request. # original request.
if final_resp.status_code in (requests.codes.found, if final_resp.status_code in (
requests.codes.other): requests.codes.found,
requests.codes.other,
):
# Consume content and release the original connection # Consume content and release the original connection
# to allow our new request to reuse the same one. # to allow our new request to reuse the same one.
sp_response.content sp_response.content
@ -216,13 +228,15 @@ class _SamlAuth(requests.auth.AuthBase):
class _FederatedSaml(v3.FederationBaseAuth): class _FederatedSaml(v3.FederationBaseAuth):
def __init__(
def __init__(self, auth_url, identity_provider, protocol, self,
identity_provider_url, **kwargs): auth_url,
super(_FederatedSaml, self).__init__(auth_url,
identity_provider, identity_provider,
protocol, protocol,
**kwargs) identity_provider_url,
**kwargs,
):
super().__init__(auth_url, identity_provider, protocol, **kwargs)
self.identity_provider_url = identity_provider_url self.identity_provider_url = identity_provider_url
@abc.abstractmethod @abc.abstractmethod
@ -234,9 +248,11 @@ class _FederatedSaml(v3.FederationBaseAuth):
auth = _SamlAuth(self.identity_provider_url, method) auth = _SamlAuth(self.identity_provider_url, method)
try: try:
resp = session.get(self.federated_token_url, resp = session.get(
self.federated_token_url,
requests_auth=auth, requests_auth=auth,
authenticated=False) authenticated=False,
)
except SamlException as e: except SamlException as e:
raise exceptions.AuthorizationFailure(str(e)) raise exceptions.AuthorizationFailure(str(e))
@ -287,13 +303,23 @@ class Password(_FederatedSaml):
""" """
def __init__(self, auth_url, identity_provider, protocol, def __init__(
identity_provider_url, username, password, **kwargs): self,
super(Password, self).__init__(auth_url, auth_url,
identity_provider, identity_provider,
protocol, protocol,
identity_provider_url, identity_provider_url,
**kwargs) username,
password,
**kwargs,
):
super().__init__(
auth_url,
identity_provider,
protocol,
identity_provider_url,
**kwargs,
)
self.username = username self.username = username
self.password = password self.password = password

View File

@ -43,7 +43,8 @@ def _mutual_auth(value):
def _requests_auth(mutual_authentication): def _requests_auth(mutual_authentication):
return requests_kerberos.HTTPKerberosAuth( return requests_kerberos.HTTPKerberosAuth(
mutual_authentication=_mutual_auth(mutual_authentication)) mutual_authentication=_mutual_auth(mutual_authentication)
)
def _dependency_check(): def _dependency_check():
@ -57,12 +58,11 @@ packages. These can be installed with::
class KerberosMethod(v3.AuthMethod): class KerberosMethod(v3.AuthMethod):
_method_parameters = ['mutual_auth'] _method_parameters = ['mutual_auth']
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
_dependency_check() _dependency_check()
super(KerberosMethod, self).__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def get_auth_data(self, session, auth, headers, request_kwargs, **kwargs): def get_auth_data(self, session, auth, headers, request_kwargs, **kwargs):
# NOTE(jamielennox): request_kwargs is passed as a kwarg however it is # NOTE(jamielennox): request_kwargs is passed as a kwarg however it is
@ -82,16 +82,18 @@ class MappedKerberos(federation.FederationBaseAuth):
use the standard keystone auth process to scope that to any given project. use the standard keystone auth process to scope that to any given project.
""" """
def __init__(self, auth_url, identity_provider, protocol, def __init__(
mutual_auth=None, **kwargs): self, auth_url, identity_provider, protocol, mutual_auth=None, **kwargs
):
_dependency_check() _dependency_check()
self.mutual_auth = mutual_auth self.mutual_auth = mutual_auth
super(MappedKerberos, self).__init__(auth_url, identity_provider, super().__init__(auth_url, identity_provider, protocol, **kwargs)
protocol, **kwargs)
def get_unscoped_auth_ref(self, session, **kwargs): def get_unscoped_auth_ref(self, session, **kwargs):
resp = session.get(self.federated_token_url, resp = session.get(
self.federated_token_url,
requests_auth=_requests_auth(self.mutual_auth), requests_auth=_requests_auth(self.mutual_auth),
authenticated=False) authenticated=False,
)
return access.create(body=resp.json(), resp=resp) return access.create(body=resp.json(), resp=resp)

View File

@ -16,7 +16,6 @@ from keystoneauth1 import loading
class Kerberos(loading.BaseV3Loader): class Kerberos(loading.BaseV3Loader):
@property @property
def plugin_class(self): def plugin_class(self):
return kerberos.Kerberos return kerberos.Kerberos
@ -26,14 +25,18 @@ class Kerberos(loading.BaseV3Loader):
return kerberos.requests_kerberos is not None return kerberos.requests_kerberos is not None
def get_options(self): def get_options(self):
options = super(Kerberos, self).get_options() options = super().get_options()
options.extend([ options.extend(
loading.Opt('mutual-auth', [
loading.Opt(
'mutual-auth',
required=False, required=False,
default='optional', default='optional',
help='Configures Kerberos Mutual Authentication'), help='Configures Kerberos Mutual Authentication',
]) )
]
)
return options return options
@ -41,16 +44,17 @@ class Kerberos(loading.BaseV3Loader):
if kwargs.get('mutual_auth'): if kwargs.get('mutual_auth'):
value = kwargs.get('mutual_auth') value = kwargs.get('mutual_auth')
if not (value.lower() in ['required', 'optional', 'disabled']): if not (value.lower() in ['required', 'optional', 'disabled']):
m = ('You need to provide a valid value for kerberos mutual ' m = (
'You need to provide a valid value for kerberos mutual '
'authentication. It can be one of the following: ' 'authentication. It can be one of the following: '
'(required, optional, disabled)') '(required, optional, disabled)'
)
raise exceptions.OptionError(m) raise exceptions.OptionError(m)
return super(Kerberos, self).load_from_options(**kwargs) return super().load_from_options(**kwargs)
class MappedKerberos(loading.BaseFederationLoader): class MappedKerberos(loading.BaseFederationLoader):
@property @property
def plugin_class(self): def plugin_class(self):
return kerberos.MappedKerberos return kerberos.MappedKerberos
@ -60,14 +64,18 @@ class MappedKerberos(loading.BaseFederationLoader):
return kerberos.requests_kerberos is not None return kerberos.requests_kerberos is not None
def get_options(self): def get_options(self):
options = super(MappedKerberos, self).get_options() options = super().get_options()
options.extend([ options.extend(
loading.Opt('mutual-auth', [
loading.Opt(
'mutual-auth',
required=False, required=False,
default='optional', default='optional',
help='Configures Kerberos Mutual Authentication'), help='Configures Kerberos Mutual Authentication',
]) )
]
)
return options return options
@ -75,9 +83,11 @@ class MappedKerberos(loading.BaseFederationLoader):
if kwargs.get('mutual_auth'): if kwargs.get('mutual_auth'):
value = kwargs.get('mutual_auth') value = kwargs.get('mutual_auth')
if not (value.lower() in ['required', 'optional', 'disabled']): if not (value.lower() in ['required', 'optional', 'disabled']):
m = ('You need to provide a valid value for kerberos mutual ' m = (
'You need to provide a valid value for kerberos mutual '
'authentication. It can be one of the following: ' 'authentication. It can be one of the following: '
'(required, optional, disabled)') '(required, optional, disabled)'
)
raise exceptions.OptionError(m) raise exceptions.OptionError(m)
return super(MappedKerberos, self).load_from_options(**kwargs) return super().load_from_options(**kwargs)

View File

@ -17,7 +17,6 @@ from keystoneauth1 import loading
# NOTE(jamielennox): This is not a BaseV3Loader because we don't want to # NOTE(jamielennox): This is not a BaseV3Loader because we don't want to
# include the scoping options like project-id in the option list # include the scoping options like project-id in the option list
class V3OAuth1(loading.BaseIdentityLoader): class V3OAuth1(loading.BaseIdentityLoader):
@property @property
def plugin_class(self): def plugin_class(self):
return v3.OAuth1 return v3.OAuth1
@ -27,21 +26,25 @@ class V3OAuth1(loading.BaseIdentityLoader):
return v3.oauth1 is not None return v3.oauth1 is not None
def get_options(self): def get_options(self):
options = super(V3OAuth1, self).get_options() options = super().get_options()
options.extend([ options.extend(
loading.Opt('consumer-key', [
loading.Opt(
'consumer-key', required=True, help='OAuth Consumer ID/Key'
),
loading.Opt(
'consumer-secret',
required=True, required=True,
help='OAuth Consumer ID/Key'), help='OAuth Consumer Secret',
loading.Opt('consumer-secret', ),
required=True, loading.Opt(
help='OAuth Consumer Secret'), 'access-key', required=True, help='OAuth Access Key'
loading.Opt('access-key', ),
required=True, loading.Opt(
help='OAuth Access Key'), 'access-secret', required=True, help='OAuth Access Secret'
loading.Opt('access-secret', ),
required=True, ]
help='OAuth Access Secret'), )
])
return options return options

View File

@ -44,37 +44,46 @@ class OAuth1Method(v3.AuthMethod):
:param string access_secret: Access token secret. :param string access_secret: Access token secret.
""" """
_method_parameters = ['consumer_key', 'consumer_secret', _method_parameters = [
'access_key', 'access_secret'] 'consumer_key',
'consumer_secret',
'access_key',
'access_secret',
]
def get_auth_data(self, session, auth, headers, **kwargs): def get_auth_data(self, session, auth, headers, **kwargs):
# Add the oauth specific content into the headers # Add the oauth specific content into the headers
oauth_client = oauth1.Client(self.consumer_key, oauth_client = oauth1.Client(
self.consumer_key,
client_secret=self.consumer_secret, client_secret=self.consumer_secret,
resource_owner_key=self.access_key, resource_owner_key=self.access_key,
resource_owner_secret=self.access_secret, resource_owner_secret=self.access_secret,
signature_method=oauth1.SIGNATURE_HMAC) signature_method=oauth1.SIGNATURE_HMAC,
)
o_url, o_headers, o_body = oauth_client.sign(auth.token_url, o_url, o_headers, o_body = oauth_client.sign(
http_method='POST') auth.token_url, http_method='POST'
)
headers.update(o_headers) headers.update(o_headers)
return 'oauth1', {} return 'oauth1', {}
def get_cache_id_elements(self): def get_cache_id_elements(self):
return dict(('oauth1_%s' % p, getattr(self, p)) return {
for p in self._method_parameters) f'oauth1_{p}': getattr(self, p) for p in self._method_parameters
}
class OAuth1(v3.AuthConstructor): class OAuth1(v3.AuthConstructor):
_auth_method_class = OAuth1Method _auth_method_class = OAuth1Method
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(OAuth1, self).__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if self.has_scope_parameters: if self.has_scope_parameters:
LOG.warning('Scoping parameters such as a project were provided ' LOG.warning(
'Scoping parameters such as a project were provided '
'to the OAuth1 plugin. Because OAuth1 access is ' 'to the OAuth1 plugin. Because OAuth1 access is '
'always scoped to a project these will be ignored by ' 'always scoped to a project these will be ignored by '
'the identity server') 'the identity server'
)

View File

@ -33,7 +33,8 @@ V2Token = v2.Token
V3Token = v3.Token V3Token = v3.Token
V3FederationToken = v3.V3FederationToken V3FederationToken = v3.V3FederationToken
__all__ = ('DiscoveryList', __all__ = (
'DiscoveryList',
'FixtureValidationError', 'FixtureValidationError',
'LoadingFixture', 'LoadingFixture',
'TestPlugin', 'TestPlugin',
@ -43,4 +44,4 @@ __all__ = ('DiscoveryList',
'V3Token', 'V3Token',
'V3FederationToken', 'V3FederationToken',
'VersionDiscovery', 'VersionDiscovery',
) )

View File

@ -12,11 +12,7 @@
from keystoneauth1 import _utils as utils from keystoneauth1 import _utils as utils
__all__ = ('DiscoveryList', __all__ = ('DiscoveryList', 'V2Discovery', 'V3Discovery', 'VersionDiscovery')
'V2Discovery',
'V3Discovery',
'VersionDiscovery',
)
_DEFAULT_DAYS_AGO = 30 _DEFAULT_DAYS_AGO = 30
@ -32,7 +28,7 @@ class DiscoveryBase(dict):
""" """
def __init__(self, id, status=None, updated=None): def __init__(self, id, status=None, updated=None):
super(DiscoveryBase, self).__init__() super().__init__()
self.id = id self.id = id
self.status = status or 'stable' self.status = status or 'stable'
@ -103,7 +99,7 @@ class VersionDiscovery(DiscoveryBase):
""" """
def __init__(self, href, id, **kwargs): def __init__(self, href, id, **kwargs):
super(VersionDiscovery, self).__init__(id, **kwargs) super().__init__(id, **kwargs)
self.add_link(href) self.add_link(href)
@ -122,7 +118,7 @@ class MicroversionDiscovery(DiscoveryBase):
""" """
def __init__(self, href, id, min_version='', max_version='', **kwargs): def __init__(self, href, id, min_version='', max_version='', **kwargs):
super(MicroversionDiscovery, self).__init__(id, **kwargs) super().__init__(id, **kwargs)
self.add_link(href) self.add_link(href)
@ -160,7 +156,7 @@ class NovaMicroversionDiscovery(DiscoveryBase):
""" """
def __init__(self, href, id, min_version=None, version=None, **kwargs): def __init__(self, href, id, min_version=None, version=None, **kwargs):
super(NovaMicroversionDiscovery, self).__init__(id, **kwargs) super().__init__(id, **kwargs)
self.add_link(href) self.add_link(href)
@ -204,7 +200,7 @@ class V2Discovery(DiscoveryBase):
_DESC_URL = 'https://developer.openstack.org/api-ref/identity/v2/' _DESC_URL = 'https://developer.openstack.org/api-ref/identity/v2/'
def __init__(self, href, id=None, html=True, pdf=True, **kwargs): def __init__(self, href, id=None, html=True, pdf=True, **kwargs):
super(V2Discovery, self).__init__(id or 'v2.0', **kwargs) super().__init__(id or 'v2.0', **kwargs)
self.add_link(href) self.add_link(href)
@ -219,9 +215,11 @@ class V2Discovery(DiscoveryBase):
The standard structure includes a link to a HTML document with the The standard structure includes a link to a HTML document with the
API specification. Add it to this entry. API specification. Add it to this entry.
""" """
self.add_link(href=self._DESC_URL + 'content', self.add_link(
href=self._DESC_URL + 'content',
rel='describedby', rel='describedby',
type='text/html') type='text/html',
)
def add_pdf_description(self): def add_pdf_description(self):
"""Add the PDF described by links. """Add the PDF described by links.
@ -229,9 +227,11 @@ class V2Discovery(DiscoveryBase):
The standard structure includes a link to a PDF document with the The standard structure includes a link to a PDF document with the
API specification. Add it to this entry. API specification. Add it to this entry.
""" """
self.add_link(href=self._DESC_URL + 'identity-dev-guide-2.0.pdf', self.add_link(
href=self._DESC_URL + 'identity-dev-guide-2.0.pdf',
rel='describedby', rel='describedby',
type='application/pdf') type='application/pdf',
)
class V3Discovery(DiscoveryBase): class V3Discovery(DiscoveryBase):
@ -249,7 +249,7 @@ class V3Discovery(DiscoveryBase):
""" """
def __init__(self, href, id=None, json=True, xml=True, **kwargs): def __init__(self, href, id=None, json=True, xml=True, **kwargs):
super(V3Discovery, self).__init__(id or 'v3.0', **kwargs) super().__init__(id or 'v3.0', **kwargs)
self.add_link(href) self.add_link(href)
@ -264,8 +264,10 @@ class V3Discovery(DiscoveryBase):
The standard structure includes a list of media-types that the endpoint The standard structure includes a list of media-types that the endpoint
supports. Add JSON to the list. supports. Add JSON to the list.
""" """
self.add_media_type(base='application/json', self.add_media_type(
type='application/vnd.openstack.identity-v3+json') base='application/json',
type='application/vnd.openstack.identity-v3+json',
)
def add_xml_media_type(self): def add_xml_media_type(self):
"""Add the XML media-type links. """Add the XML media-type links.
@ -273,8 +275,10 @@ class V3Discovery(DiscoveryBase):
The standard structure includes a list of media-types that the endpoint The standard structure includes a list of media-types that the endpoint
supports. Add XML to the list. supports. Add XML to the list.
""" """
self.add_media_type(base='application/xml', self.add_media_type(
type='application/vnd.openstack.identity-v3+xml') base='application/xml',
type='application/vnd.openstack.identity-v3+xml',
)
class DiscoveryList(dict): class DiscoveryList(dict):
@ -298,22 +302,47 @@ class DiscoveryList(dict):
TEST_URL = 'http://keystone.host:5000/' TEST_URL = 'http://keystone.host:5000/'
def __init__(self, href=None, v2=True, v3=True, v2_id=None, v3_id=None, def __init__(
v2_status=None, v2_updated=None, v2_html=True, v2_pdf=True, self,
v3_status=None, v3_updated=None, v3_json=True, v3_xml=True): href=None,
super(DiscoveryList, self).__init__(versions={'values': []}) v2=True,
v3=True,
v2_id=None,
v3_id=None,
v2_status=None,
v2_updated=None,
v2_html=True,
v2_pdf=True,
v3_status=None,
v3_updated=None,
v3_json=True,
v3_xml=True,
):
super().__init__(versions={'values': []})
href = href or self.TEST_URL href = href or self.TEST_URL
if v2: if v2:
v2_href = href.rstrip('/') + '/v2.0' v2_href = href.rstrip('/') + '/v2.0'
self.add_v2(v2_href, id=v2_id, status=v2_status, self.add_v2(
updated=v2_updated, html=v2_html, pdf=v2_pdf) v2_href,
id=v2_id,
status=v2_status,
updated=v2_updated,
html=v2_html,
pdf=v2_pdf,
)
if v3: if v3:
v3_href = href.rstrip('/') + '/v3' v3_href = href.rstrip('/') + '/v3'
self.add_v3(v3_href, id=v3_id, status=v3_status, self.add_v3(
updated=v3_updated, json=v3_json, xml=v3_xml) v3_href,
id=v3_id,
status=v3_status,
updated=v3_updated,
json=v3_json,
xml=v3_xml,
)
@property @property
def versions(self): def versions(self):

View File

@ -25,11 +25,16 @@ from keystoneauth1 import session
class BetamaxFixture(fixtures.Fixture): class BetamaxFixture(fixtures.Fixture):
def __init__(
def __init__(self, cassette_name, cassette_library_dir=None, self,
serializer=None, record=False, cassette_name,
cassette_library_dir=None,
serializer=None,
record=False,
pre_record_hook=hooks.pre_record_hook, pre_record_hook=hooks.pre_record_hook,
serializer_name=None, request_matchers=None): serializer_name=None,
request_matchers=None,
):
"""Configure Betamax for the test suite. """Configure Betamax for the test suite.
:param str cassette_name: :param str cassette_name:
@ -93,10 +98,12 @@ class BetamaxFixture(fixtures.Fixture):
return self._serializer_name return self._serializer_name
def setUp(self): def setUp(self):
super(BetamaxFixture, self).setUp() super().setUp()
self.mockpatch = mock.patch.object( self.mockpatch = mock.patch.object(
session, '_construct_session', session,
partial(_construct_session_with_betamax, self)) '_construct_session',
partial(_construct_session_with_betamax, self),
)
self.mockpatch.start() self.mockpatch.start()
# Unpatch during cleanup # Unpatch during cleanup
self.addCleanup(self.mockpatch.stop) self.addCleanup(self.mockpatch.stop)
@ -116,7 +123,8 @@ def _construct_session_with_betamax(fixture, session_obj=None):
with betamax.Betamax.configure() as config: with betamax.Betamax.configure() as config:
config.before_record(callback=fixture.pre_record_hook) config.before_record(callback=fixture.pre_record_hook)
fixture.recorder = betamax.Betamax( fixture.recorder = betamax.Betamax(
session_obj, cassette_library_dir=fixture.cassette_library_dir) session_obj, cassette_library_dir=fixture.cassette_library_dir
)
record = 'none' record = 'none'
serializer = None serializer = None
@ -126,10 +134,12 @@ def _construct_session_with_betamax(fixture, session_obj=None):
serializer = fixture.serializer_name serializer = fixture.serializer_name
fixture.recorder.use_cassette(fixture.cassette_name, fixture.recorder.use_cassette(
fixture.cassette_name,
serialize_with=serializer, serialize_with=serializer,
record=record, record=record,
**fixture.use_cassette_kwargs) **fixture.use_cassette_kwargs,
)
fixture.recorder.start() fixture.recorder.start()
fixture.addCleanup(fixture.recorder.stop) fixture.addCleanup(fixture.recorder.stop)

View File

@ -18,10 +18,7 @@ from keystoneauth1 import discover
from keystoneauth1 import loading from keystoneauth1 import loading
from keystoneauth1 import plugin from keystoneauth1 import plugin
__all__ = ( __all__ = ('LoadingFixture', 'TestPlugin')
'LoadingFixture',
'TestPlugin',
)
DEFAULT_TEST_ENDPOINT = 'https://openstack.example.com/%(service_type)s' DEFAULT_TEST_ENDPOINT = 'https://openstack.example.com/%(service_type)s'
@ -62,12 +59,10 @@ class TestPlugin(plugin.BaseAuthPlugin):
auth_type = 'test_plugin' auth_type = 'test_plugin'
def __init__(self, def __init__(
token=None, self, token=None, endpoint=None, user_id=None, project_id=None
endpoint=None, ):
user_id=None, super().__init__()
project_id=None):
super(TestPlugin, self).__init__()
self.token = token or uuid.uuid4().hex self.token = token or uuid.uuid4().hex
self.endpoint = endpoint or DEFAULT_TEST_ENDPOINT self.endpoint = endpoint or DEFAULT_TEST_ENDPOINT
@ -98,9 +93,8 @@ class TestPlugin(plugin.BaseAuthPlugin):
class _TestPluginLoader(loading.BaseLoader): class _TestPluginLoader(loading.BaseLoader):
def __init__(self, plugin): def __init__(self, plugin):
super(_TestPluginLoader, self).__init__() super().__init__()
self._plugin = plugin self._plugin = plugin
def create_plugin(self, **kwargs): def create_plugin(self, **kwargs):
@ -129,12 +123,10 @@ class LoadingFixture(fixtures.Fixture):
MOCK_POINT = 'keystoneauth1.loading.base.get_plugin_loader' MOCK_POINT = 'keystoneauth1.loading.base.get_plugin_loader'
def __init__(self, def __init__(
token=None, self, token=None, endpoint=None, user_id=None, project_id=None
endpoint=None, ):
user_id=None, super().__init__()
project_id=None):
super(LoadingFixture, self).__init__()
# these are created and saved here so that a test could use them # these are created and saved here so that a test could use them
self.token = token or uuid.uuid4().hex self.token = token or uuid.uuid4().hex
@ -143,16 +135,19 @@ class LoadingFixture(fixtures.Fixture):
self.project_id = project_id or uuid.uuid4().hex self.project_id = project_id or uuid.uuid4().hex
def setUp(self): def setUp(self):
super(LoadingFixture, self).setUp() super().setUp()
self.useFixture(fixtures.MonkeyPatch(self.MOCK_POINT, self.useFixture(
self.get_plugin_loader)) fixtures.MonkeyPatch(self.MOCK_POINT, self.get_plugin_loader)
)
def create_plugin(self): def create_plugin(self):
return TestPlugin(token=self.token, return TestPlugin(
token=self.token,
endpoint=self.endpoint, endpoint=self.endpoint,
user_id=self.user_id, user_id=self.user_id,
project_id=self.project_id) project_id=self.project_id,
)
def get_plugin_loader(self, auth_type): def get_plugin_loader(self, auth_type):
plugin = self.create_plugin() plugin = self.create_plugin()
@ -171,6 +166,6 @@ class LoadingFixture(fixtures.Fixture):
endpoint = _format_endpoint(self.endpoint, **kwargs) endpoint = _format_endpoint(self.endpoint, **kwargs)
if path: if path:
endpoint = "%s/%s" % (endpoint.rstrip('/'), path.lstrip('/')) endpoint = "{}/{}".format(endpoint.rstrip('/'), path.lstrip('/'))
return endpoint return endpoint

View File

@ -20,7 +20,7 @@ import yaml
def _should_use_block(value): def _should_use_block(value):
for c in u"\u000a\u000d\u001c\u001d\u001e\u0085\u2028\u2029": for c in "\u000a\u000d\u001c\u001d\u001e\u0085\u2028\u2029":
if c in value: if c in value:
return True return True
return False return False
@ -40,7 +40,7 @@ def _represent_scalar(self, tag, value, style=None):
def _unicode_representer(dumper, uni): def _unicode_representer(dumper, uni):
node = yaml.ScalarNode(tag=u'tag:yaml.org,2002:str', value=uni) node = yaml.ScalarNode(tag='tag:yaml.org,2002:str', value=uni)
return node return node
@ -49,9 +49,12 @@ def _indent_json(val):
return '' return ''
return json.dumps( return json.dumps(
json.loads(val), indent=2, json.loads(val),
separators=(',', ': '), sort_keys=False, indent=2,
default=str) separators=(',', ': '),
sort_keys=False,
default=str,
)
def _is_json_body(interaction): def _is_json_body(interaction):
@ -60,13 +63,11 @@ def _is_json_body(interaction):
class YamlJsonSerializer(betamax.serializers.base.BaseSerializer): class YamlJsonSerializer(betamax.serializers.base.BaseSerializer):
name = "yamljson" name = "yamljson"
@staticmethod @staticmethod
def generate_cassette_name(cassette_library_dir, cassette_name): def generate_cassette_name(cassette_library_dir, cassette_name):
return os.path.join( return os.path.join(cassette_library_dir, f"{cassette_name}.yaml")
cassette_library_dir, "{name}.yaml".format(name=cassette_name))
def serialize(self, cassette_data): def serialize(self, cassette_data):
# Reserialize internal json with indentation # Reserialize internal json with indentation
@ -74,7 +75,8 @@ class YamlJsonSerializer(betamax.serializers.base.BaseSerializer):
for key in ('request', 'response'): for key in ('request', 'response'):
if _is_json_body(interaction[key]): if _is_json_body(interaction[key]):
interaction[key]['body']['string'] = _indent_json( interaction[key]['body']['string'] = _indent_json(
interaction[key]['body']['string']) interaction[key]['body']['string']
)
class MyDumper(yaml.Dumper): class MyDumper(yaml.Dumper):
"""Specialized Dumper which does nice blocks and unicode.""" """Specialized Dumper which does nice blocks and unicode."""
@ -84,7 +86,8 @@ class YamlJsonSerializer(betamax.serializers.base.BaseSerializer):
MyDumper.add_representer(str, _unicode_representer) MyDumper.add_representer(str, _unicode_representer)
return yaml.dump( return yaml.dump(
cassette_data, Dumper=MyDumper, default_flow_style=False) cassette_data, Dumper=MyDumper, default_flow_style=False
)
def deserialize(self, cassette_data): def deserialize(self, cassette_data):
try: try:

View File

@ -18,15 +18,23 @@ from keystoneauth1.fixture import exception
class _Service(dict): class _Service(dict):
def add_endpoint(
def add_endpoint(self, public, admin=None, internal=None, self,
tenant_id=None, region=None, id=None): public,
data = {'tenantId': tenant_id or uuid.uuid4().hex, admin=None,
internal=None,
tenant_id=None,
region=None,
id=None,
):
data = {
'tenantId': tenant_id or uuid.uuid4().hex,
'publicURL': public, 'publicURL': public,
'adminURL': admin or public, 'adminURL': admin or public,
'internalURL': internal or public, 'internalURL': internal or public,
'region': region, 'region': region,
'id': id or uuid.uuid4().hex} 'id': id or uuid.uuid4().hex,
}
self.setdefault('endpoints', []).append(data) self.setdefault('endpoints', []).append(data)
return data return data
@ -41,11 +49,21 @@ class Token(dict):
that matter to them and not copy and paste sample. that matter to them and not copy and paste sample.
""" """
def __init__(self, token_id=None, expires=None, issued=None, def __init__(
tenant_id=None, tenant_name=None, user_id=None, self,
user_name=None, trust_id=None, trustee_user_id=None, token_id=None,
audit_id=None, audit_chain_id=None): expires=None,
super(Token, self).__init__() issued=None,
tenant_id=None,
tenant_name=None,
user_id=None,
user_name=None,
trust_id=None,
trustee_user_id=None,
audit_id=None,
audit_chain_id=None,
):
super().__init__()
self.token_id = token_id or uuid.uuid4().hex self.token_id = token_id or uuid.uuid4().hex
self.user_id = user_id or uuid.uuid4().hex self.user_id = user_id or uuid.uuid4().hex
@ -75,8 +93,9 @@ class Token(dict):
if trust_id or trustee_user_id: if trust_id or trustee_user_id:
# the trustee_user_id will generally be the same as the user_id as # the trustee_user_id will generally be the same as the user_id as
# the token is being issued to the trustee # the token is being issued to the trustee
self.set_trust(id=trust_id, self.set_trust(
trustee_user_id=trustee_user_id or user_id) id=trust_id, trustee_user_id=trustee_user_id or user_id
)
if audit_chain_id: if audit_chain_id:
self.audit_chain_id = audit_chain_id self.audit_chain_id = audit_chain_id
@ -237,8 +256,10 @@ class Token(dict):
def remove_service(self, type): def remove_service(self, type):
self.root['serviceCatalog'] = [ self.root['serviceCatalog'] = [
f for f in self.root.setdefault('serviceCatalog', []) f
if f['type'] != type] for f in self.root.setdefault('serviceCatalog', [])
if f['type'] != type
]
def set_scope(self, id=None, name=None): def set_scope(self, id=None, name=None):
self.tenant_id = id or uuid.uuid4().hex self.tenant_id = id or uuid.uuid4().hex

View File

@ -25,16 +25,19 @@ class _Service(dict):
""" """
def add_endpoint(self, interface, url, region=None, id=None): def add_endpoint(self, interface, url, region=None, id=None):
data = {'id': id or uuid.uuid4().hex, data = {
'id': id or uuid.uuid4().hex,
'interface': interface, 'interface': interface,
'url': url, 'url': url,
'region': region, 'region': region,
'region_id': region} 'region_id': region,
}
self.setdefault('endpoints', []).append(data) self.setdefault('endpoints', []).append(data)
return data return data
def add_standard_endpoints(self, public=None, admin=None, internal=None, def add_standard_endpoints(
region=None): self, public=None, admin=None, internal=None, region=None
):
ret = [] ret = []
if public: if public:
@ -56,18 +59,36 @@ class Token(dict):
that matter to them and not copy and paste sample. that matter to them and not copy and paste sample.
""" """
def __init__(self, expires=None, issued=None, user_id=None, user_name=None, def __init__(
user_domain_id=None, user_domain_name=None, methods=None, self,
project_id=None, project_name=None, project_domain_id=None, expires=None,
project_domain_name=None, domain_id=None, domain_name=None, issued=None,
trust_id=None, trust_impersonation=None, trustee_user_id=None, user_id=None,
trustor_user_id=None, application_credential_id=None, user_name=None,
user_domain_id=None,
user_domain_name=None,
methods=None,
project_id=None,
project_name=None,
project_domain_id=None,
project_domain_name=None,
domain_id=None,
domain_name=None,
trust_id=None,
trust_impersonation=None,
trustee_user_id=None,
trustor_user_id=None,
application_credential_id=None,
application_credential_access_rules=None, application_credential_access_rules=None,
oauth_access_token_id=None, oauth_consumer_id=None, oauth_access_token_id=None,
audit_id=None, audit_chain_id=None, oauth_consumer_id=None,
is_admin_project=None, project_is_domain=None, audit_id=None,
oauth2_thumbprint=None): audit_chain_id=None,
super(Token, self).__init__() is_admin_project=None,
project_is_domain=None,
oauth2_thumbprint=None,
):
super().__init__()
self.user_id = user_id or uuid.uuid4().hex self.user_id = user_id or uuid.uuid4().hex
self.user_name = user_name or uuid.uuid4().hex self.user_name = user_name or uuid.uuid4().hex
@ -97,32 +118,47 @@ class Token(dict):
# expires should be able to be passed as a string so ignore # expires should be able to be passed as a string so ignore
self.expires_str = expires self.expires_str = expires
if (project_id or project_name or if (
project_domain_id or project_domain_name): project_id
self.set_project_scope(id=project_id, or project_name
or project_domain_id
or project_domain_name
):
self.set_project_scope(
id=project_id,
name=project_name, name=project_name,
domain_id=project_domain_id, domain_id=project_domain_id,
domain_name=project_domain_name, domain_name=project_domain_name,
is_domain=project_is_domain) is_domain=project_is_domain,
)
if domain_id or domain_name: if domain_id or domain_name:
self.set_domain_scope(id=domain_id, name=domain_name) self.set_domain_scope(id=domain_id, name=domain_name)
if (trust_id or (trust_impersonation is not None) or if (
trustee_user_id or trustor_user_id): trust_id
self.set_trust_scope(id=trust_id, or (trust_impersonation is not None)
or trustee_user_id
or trustor_user_id
):
self.set_trust_scope(
id=trust_id,
impersonation=trust_impersonation, impersonation=trust_impersonation,
trustee_user_id=trustee_user_id, trustee_user_id=trustee_user_id,
trustor_user_id=trustor_user_id) trustor_user_id=trustor_user_id,
)
if application_credential_id: if application_credential_id:
self.set_application_credential( self.set_application_credential(
application_credential_id, application_credential_id,
access_rules=application_credential_access_rules) access_rules=application_credential_access_rules,
)
if oauth_access_token_id or oauth_consumer_id: if oauth_access_token_id or oauth_consumer_id:
self.set_oauth(access_token_id=oauth_access_token_id, self.set_oauth(
consumer_id=oauth_consumer_id) access_token_id=oauth_access_token_id,
consumer_id=oauth_consumer_id,
)
if audit_chain_id: if audit_chain_id:
self.audit_chain_id = audit_chain_id self.audit_chain_id = audit_chain_id
@ -326,7 +362,8 @@ class Token(dict):
@application_credential_id.setter @application_credential_id.setter
def application_credential_id(self, value): def application_credential_id(self, value):
application_credential = self.root.setdefault( application_credential = self.root.setdefault(
'application_credential', {}) 'application_credential', {}
)
application_credential.setdefault('id', value) application_credential.setdefault('id', value)
@property @property
@ -336,7 +373,8 @@ class Token(dict):
@application_credential_access_rules.setter @application_credential_access_rules.setter
def application_credential_access_rules(self, value): def application_credential_access_rules(self, value):
application_credential = self.root.setdefault( application_credential = self.root.setdefault(
'application_credential', {}) 'application_credential', {}
)
application_credential.setdefault('access_rules', value) application_credential.setdefault('access_rules', value)
@property @property
@ -438,8 +476,7 @@ class Token(dict):
def add_role(self, name=None, id=None): def add_role(self, name=None, id=None):
roles = self.root.setdefault('roles', []) roles = self.root.setdefault('roles', [])
data = {'id': id or uuid.uuid4().hex, data = {'id': id or uuid.uuid4().hex, 'name': name or uuid.uuid4().hex}
'name': name or uuid.uuid4().hex}
roles.append(data) roles.append(data)
return data return data
@ -453,11 +490,17 @@ class Token(dict):
def remove_service(self, type): def remove_service(self, type):
self.root.setdefault('catalog', []) self.root.setdefault('catalog', [])
self.root['catalog'] = [ self.root['catalog'] = [
f for f in self.root.setdefault('catalog', []) f for f in self.root.setdefault('catalog', []) if f['type'] != type
if f['type'] != type] ]
def set_project_scope(self, id=None, name=None, domain_id=None, def set_project_scope(
domain_name=None, is_domain=None): self,
id=None,
name=None,
domain_id=None,
domain_name=None,
is_domain=None,
):
self.project_id = id or uuid.uuid4().hex self.project_id = id or uuid.uuid4().hex
self.project_name = name or uuid.uuid4().hex self.project_name = name or uuid.uuid4().hex
self.project_domain_id = domain_id or uuid.uuid4().hex self.project_domain_id = domain_id or uuid.uuid4().hex
@ -477,8 +520,13 @@ class Token(dict):
# entire system. # entire system.
self.system = {'all': True} self.system = {'all': True}
def set_trust_scope(self, id=None, impersonation=False, def set_trust_scope(
trustee_user_id=None, trustor_user_id=None): self,
id=None,
impersonation=False,
trustee_user_id=None,
trustor_user_id=None,
):
self.trust_id = id or uuid.uuid4().hex self.trust_id = id or uuid.uuid4().hex
self.trust_impersonation = impersonation self.trust_impersonation = impersonation
self.trustee_user_id = trustee_user_id or uuid.uuid4().hex self.trustee_user_id = trustee_user_id or uuid.uuid4().hex
@ -488,8 +536,9 @@ class Token(dict):
self.oauth_access_token_id = access_token_id or uuid.uuid4().hex self.oauth_access_token_id = access_token_id or uuid.uuid4().hex
self.oauth_consumer_id = consumer_id or uuid.uuid4().hex self.oauth_consumer_id = consumer_id or uuid.uuid4().hex
def set_application_credential(self, application_credential_id, def set_application_credential(
access_rules=None): self, application_credential_id, access_rules=None
):
self.application_credential_id = application_credential_id self.application_credential_id = application_credential_id
if access_rules is not None: if access_rules is not None:
self.application_credential_access_rules = access_rules self.application_credential_access_rules = access_rules
@ -517,20 +566,22 @@ class V3FederationToken(Token):
FEDERATED_DOMAIN_ID = 'Federated' FEDERATED_DOMAIN_ID = 'Federated'
def __init__(self, methods=None, identity_provider=None, protocol=None, def __init__(
groups=None): self, methods=None, identity_provider=None, protocol=None, groups=None
):
methods = methods or ['saml2'] methods = methods or ['saml2']
super(V3FederationToken, self).__init__(methods=methods) super().__init__(methods=methods)
self._user_domain = {'id': V3FederationToken.FEDERATED_DOMAIN_ID} self._user_domain = {'id': V3FederationToken.FEDERATED_DOMAIN_ID}
self.add_federation_info_to_user(identity_provider, protocol, groups) self.add_federation_info_to_user(identity_provider, protocol, groups)
def add_federation_info_to_user(self, identity_provider=None, def add_federation_info_to_user(
protocol=None, groups=None): self, identity_provider=None, protocol=None, groups=None
):
data = { data = {
"OS-FEDERATION": { "OS-FEDERATION": {
"identity_provider": identity_provider or uuid.uuid4().hex, "identity_provider": identity_provider or uuid.uuid4().hex,
"protocol": protocol or uuid.uuid4().hex, "protocol": protocol or uuid.uuid4().hex,
"groups": groups or [{"id": uuid.uuid4().hex}] "groups": groups or [{"id": uuid.uuid4().hex}],
} }
} }
self._user.update(data) self._user.update(data)

View File

@ -18,7 +18,6 @@ errors so that core devs don't have to.
""" """
import re import re
from hacking import core from hacking import core
@ -27,10 +26,11 @@ from hacking import core
@core.flake8ext @core.flake8ext
def check_oslo_namespace_imports(logical_line, blank_before, filename): def check_oslo_namespace_imports(logical_line, blank_before, filename):
oslo_namespace_imports = re.compile( oslo_namespace_imports = re.compile(
r"(((from)|(import))\s+oslo\.)|(from\s+oslo\s+import\s+)") r"(((from)|(import))\s+oslo\.)|(from\s+oslo\s+import\s+)"
)
if re.match(oslo_namespace_imports, logical_line): if re.match(oslo_namespace_imports, logical_line):
msg = ("K333: '%s' must be used instead of '%s'.") % ( msg = ("K333: '{}' must be used instead of '{}'.").format(
logical_line.replace('oslo.', 'oslo_'), logical_line.replace('oslo.', 'oslo_'), logical_line
logical_line) )
yield (0, msg) yield (0, msg)

View File

@ -25,15 +25,14 @@ class HTTPBasicAuth(plugin.FixedEndpointPlugin):
""" """
def __init__(self, endpoint=None, username=None, password=None): def __init__(self, endpoint=None, username=None, password=None):
super(HTTPBasicAuth, self).__init__(endpoint) super().__init__(endpoint)
self.username = username self.username = username
self.password = password self.password = password
def get_token(self, session, **kwargs): def get_token(self, session, **kwargs):
if self.username is None or self.password is None: if self.username is None or self.password is None:
return None return None
token = bytes('%s:%s' % (self.username, self.password), token = bytes(f'{self.username}:{self.password}', encoding='utf-8')
encoding='utf-8')
encoded = base64.b64encode(token) encoded = base64.b64encode(token)
return str(encoded, encoding='utf-8') return str(encoded, encoding='utf-8')
@ -41,5 +40,5 @@ class HTTPBasicAuth(plugin.FixedEndpointPlugin):
token = self.get_token(session) token = self.get_token(session)
if not token: if not token:
return None return None
auth = 'Basic %s' % token auth = f'Basic {token}'
return {AUTH_HEADER_NAME: auth} return {AUTH_HEADER_NAME: auth}

View File

@ -70,7 +70,8 @@ V3OAuth2ClientCredential = v3.OAuth2ClientCredential
V3OAuth2mTlsClientCredential = v3.OAuth2mTlsClientCredential V3OAuth2mTlsClientCredential = v3.OAuth2mTlsClientCredential
"""See :class:`keystoneauth1.identity.v3.OAuth2mTlsClientCredential`""" """See :class:`keystoneauth1.identity.v3.OAuth2mTlsClientCredential`"""
__all__ = ('BaseIdentityPlugin', __all__ = (
'BaseIdentityPlugin',
'Password', 'Password',
'Token', 'Token',
'V2Password', 'V2Password',
@ -86,4 +87,5 @@ __all__ = ('BaseIdentityPlugin',
'V3ApplicationCredential', 'V3ApplicationCredential',
'V3MultiFactor', 'V3MultiFactor',
'V3OAuth2ClientCredential', 'V3OAuth2ClientCredential',
'V3OAuth2mTlsClientCredential') 'V3OAuth2mTlsClientCredential',
)

View File

@ -31,8 +31,7 @@ class AccessInfoPlugin(base.BaseIdentityPlugin):
""" """
def __init__(self, auth_ref, auth_url=None): def __init__(self, auth_ref, auth_url=None):
super(AccessInfoPlugin, self).__init__(auth_url=auth_url, super().__init__(auth_url=auth_url, reauthenticate=False)
reauthenticate=False)
self.auth_ref = auth_ref self.auth_ref = auth_ref
def get_auth_ref(self, session, **kwargs): def get_auth_ref(self, session, **kwargs):

View File

@ -27,14 +27,12 @@ LOG = utils.get_logger(__name__)
class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta): class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
# we count a token as valid (not needing refreshing) if it is valid for at # we count a token as valid (not needing refreshing) if it is valid for at
# least this many seconds before the token expiry time # least this many seconds before the token expiry time
MIN_TOKEN_LIFE_SECONDS = 120 MIN_TOKEN_LIFE_SECONDS = 120
def __init__(self, auth_url=None, reauthenticate=True): def __init__(self, auth_url=None, reauthenticate=True):
super().__init__()
super(BaseIdentityPlugin, self).__init__()
self.auth_url = auth_url self.auth_url = auth_url
self.auth_ref = None self.auth_ref = None
@ -152,11 +150,22 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
return False return False
def get_endpoint_data(self, session, service_type=None, interface=None, def get_endpoint_data(
region_name=None, service_name=None, allow=None, self,
allow_version_hack=True, discover_versions=True, session,
skip_discovery=False, min_version=None, service_type=None,
max_version=None, endpoint_override=None, **kwargs): interface=None,
region_name=None,
service_name=None,
allow=None,
allow_version_hack=True,
discover_versions=True,
skip_discovery=False,
min_version=None,
max_version=None,
endpoint_override=None,
**kwargs,
):
"""Return a valid endpoint data for a service. """Return a valid endpoint data for a service.
If a valid token is not present then a new one will be fetched using If a valid token is not present then a new one will be fetched using
@ -223,7 +232,8 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
allow = allow or {} allow = allow or {}
min_version, max_version = discover._normalize_version_args( min_version, max_version = discover._normalize_version_args(
None, min_version, max_version, service_type=service_type) None, min_version, max_version, service_type=service_type
)
# NOTE(jamielennox): if you specifically ask for requests to be sent to # NOTE(jamielennox): if you specifically ask for requests to be sent to
# the auth url then we can ignore many of the checks. Typically if you # the auth url then we can ignore many of the checks. Typically if you
@ -233,7 +243,8 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
if interface is plugin.AUTH_INTERFACE: if interface is plugin.AUTH_INTERFACE:
endpoint_data = discover.EndpointData( endpoint_data = discover.EndpointData(
service_url=self.auth_url, service_url=self.auth_url,
service_type=service_type or 'identity') service_type=service_type or 'identity',
)
project_id = None project_id = None
elif endpoint_override: elif endpoint_override:
# TODO(mordred) Make a code path that will look for a # TODO(mordred) Make a code path that will look for a
@ -246,7 +257,8 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
catalog_url=endpoint_override, catalog_url=endpoint_override,
interface=interface, interface=interface,
region_name=region_name, region_name=region_name,
service_name=service_name) service_name=service_name,
)
# Setting an endpoint_override then calling get_endpoint_data means # Setting an endpoint_override then calling get_endpoint_data means
# you absolutely want the discovery info for the URL in question. # you absolutely want the discovery info for the URL in question.
# There are no code flows where this will happen for any other # There are no code flows where this will happen for any other
@ -255,9 +267,11 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
project_id = self.get_project_id(session) project_id = self.get_project_id(session)
else: else:
if not service_type: if not service_type:
LOG.warning('Plugin cannot return an endpoint without ' LOG.warning(
'Plugin cannot return an endpoint without '
'knowing the service type that is required. Add ' 'knowing the service type that is required. Add '
'service_type to endpoint filtering data.') 'service_type to endpoint filtering data.'
)
return None return None
# It's possible for things higher in the stack, because of # It's possible for things higher in the stack, because of
@ -273,7 +287,8 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
service_type=service_type, service_type=service_type,
interface=interface, interface=interface,
region_name=region_name, region_name=region_name,
service_name=service_name) service_name=service_name,
)
if not endpoint_data: if not endpoint_data:
return None return None
@ -288,10 +303,14 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
max_version=max_version, max_version=max_version,
cache=self._discovery_cache, cache=self._discovery_cache,
discover_versions=discover_versions, discover_versions=discover_versions,
allow_version_hack=allow_version_hack, allow=allow) allow_version_hack=allow_version_hack,
except (exceptions.DiscoveryFailure, allow=allow,
)
except (
exceptions.DiscoveryFailure,
exceptions.HttpError, exceptions.HttpError,
exceptions.ConnectionError): exceptions.ConnectionError,
):
# If a version was requested, we didn't find it, return # If a version was requested, we didn't find it, return
# None. # None.
if max_version or min_version: if max_version or min_version:
@ -300,12 +319,21 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
# should be fine # should be fine
return endpoint_data return endpoint_data
def get_endpoint(self, session, service_type=None, interface=None, def get_endpoint(
region_name=None, service_name=None, version=None, self,
allow=None, allow_version_hack=True, session,
service_type=None,
interface=None,
region_name=None,
service_name=None,
version=None,
allow=None,
allow_version_hack=True,
skip_discovery=False, skip_discovery=False,
min_version=None, max_version=None, min_version=None,
**kwargs): max_version=None,
**kwargs,
):
"""Return a valid endpoint for a service. """Return a valid endpoint for a service.
If a valid token is not present then a new one will be fetched using If a valid token is not present then a new one will be fetched using
@ -363,26 +391,45 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
# Explode `version` into min_version and max_version - everything below # Explode `version` into min_version and max_version - everything below
# here uses the latter rather than the former. # here uses the latter rather than the former.
min_version, max_version = discover._normalize_version_args( min_version, max_version = discover._normalize_version_args(
version, min_version, max_version, service_type=service_type) version, min_version, max_version, service_type=service_type
)
# Set discover_versions to False since we're only going to return # Set discover_versions to False since we're only going to return
# a URL. Fetching the microversion data would be needlessly # a URL. Fetching the microversion data would be needlessly
# expensive in the common case. However, discover_versions=False # expensive in the common case. However, discover_versions=False
# will still run discovery if the version requested is not the # will still run discovery if the version requested is not the
# version in the catalog. # version in the catalog.
endpoint_data = self.get_endpoint_data( endpoint_data = self.get_endpoint_data(
session, service_type=service_type, interface=interface, session,
region_name=region_name, service_name=service_name, service_type=service_type,
allow=allow, min_version=min_version, max_version=max_version, interface=interface,
discover_versions=False, skip_discovery=skip_discovery, region_name=region_name,
allow_version_hack=allow_version_hack, **kwargs) service_name=service_name,
allow=allow,
min_version=min_version,
max_version=max_version,
discover_versions=False,
skip_discovery=skip_discovery,
allow_version_hack=allow_version_hack,
**kwargs,
)
return endpoint_data.url if endpoint_data else None return endpoint_data.url if endpoint_data else None
def get_api_major_version(self, session, service_type=None, interface=None, def get_api_major_version(
region_name=None, service_name=None, self,
version=None, allow=None, session,
allow_version_hack=True, skip_discovery=False, service_type=None,
discover_versions=False, min_version=None, interface=None,
max_version=None, **kwargs): region_name=None,
service_name=None,
version=None,
allow=None,
allow_version_hack=True,
skip_discovery=False,
discover_versions=False,
min_version=None,
max_version=None,
**kwargs,
):
"""Return the major API version for a service. """Return the major API version for a service.
If a valid token is not present then a new one will be fetched using If a valid token is not present then a new one will be fetched using
@ -456,8 +503,8 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
``block-storage`` service and they do:: ``block-storage`` service and they do::
client = adapter.Adapter( client = adapter.Adapter(
session, service_type='block-storage', min_version=2, session, service_type='block-storage', min_version=2, max_version=3
max_version=3) )
volume_version = client.get_api_major_version() volume_version = client.get_api_major_version()
The version actually be returned with no api calls other than getting The version actually be returned with no api calls other than getting
@ -485,15 +532,23 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
# Explode `version` into min_version and max_version - everything below # Explode `version` into min_version and max_version - everything below
# here uses the latter rather than the former. # here uses the latter rather than the former.
min_version, max_version = discover._normalize_version_args( min_version, max_version = discover._normalize_version_args(
version, min_version, max_version, service_type=service_type) version, min_version, max_version, service_type=service_type
)
# Using functools.partial here just to reduce copy-pasta of params # Using functools.partial here just to reduce copy-pasta of params
get_endpoint_data = functools.partial( get_endpoint_data = functools.partial(
self.get_endpoint_data, self.get_endpoint_data,
session, service_type=service_type, interface=interface, session,
region_name=region_name, service_name=service_name, service_type=service_type,
allow=allow, min_version=min_version, max_version=max_version, interface=interface,
region_name=region_name,
service_name=service_name,
allow=allow,
min_version=min_version,
max_version=max_version,
skip_discovery=skip_discovery, skip_discovery=skip_discovery,
allow_version_hack=allow_version_hack, **kwargs) allow_version_hack=allow_version_hack,
**kwargs,
)
data = get_endpoint_data(discover_versions=discover_versions) data = get_endpoint_data(discover_versions=discover_versions)
if (not data or not data.api_version) and not discover_versions: if (not data or not data.api_version) and not discover_versions:
# It's possible that no version was requested and the endpoint # It's possible that no version was requested and the endpoint
@ -505,9 +560,14 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
return None return None
return data.api_version return data.api_version
def get_all_version_data(self, session, interface='public', def get_all_version_data(
region_name=None, service_type=None, self,
**kwargs): session,
interface='public',
region_name=None,
service_type=None,
**kwargs,
):
"""Get version data for all services in the catalog. """Get version data for all services in the catalog.
:param session: A session object that can be used for communication. :param session: A session object that can be used for communication.
@ -539,12 +599,12 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
for endpoint_service_type, services in endpoints_data.items(): for endpoint_service_type, services in endpoints_data.items():
if service_types.is_known(endpoint_service_type): if service_types.is_known(endpoint_service_type):
endpoint_service_type = service_types.get_service_type( endpoint_service_type = service_types.get_service_type(
endpoint_service_type) endpoint_service_type
)
for service in services: for service in services:
versions = service.get_all_version_string_data( versions = service.get_all_version_string_data(
session=session, session=session, project_id=self.get_project_id(session)
project_id=self.get_project_id(session),
) )
if service.region_name not in version_data: if service.region_name not in version_data:
@ -566,15 +626,15 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
def get_sp_auth_url(self, session, sp_id, **kwargs): def get_sp_auth_url(self, session, sp_id, **kwargs):
try: try:
return self.get_access( return self.get_access(session).service_providers.get_auth_url(
session).service_providers.get_auth_url(sp_id) sp_id
)
except exceptions.ServiceProviderNotFound: except exceptions.ServiceProviderNotFound:
return None return None
def get_sp_url(self, session, sp_id, **kwargs): def get_sp_url(self, session, sp_id, **kwargs):
try: try:
return self.get_access( return self.get_access(session).service_providers.get_sp_url(sp_id)
session).service_providers.get_sp_url(sp_id)
except exceptions.ServiceProviderNotFound: except exceptions.ServiceProviderNotFound:
return None return None
@ -602,9 +662,12 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
:returns: A discovery object with the results of looking up that URL. :returns: A discovery object with the results of looking up that URL.
""" """
return discover.get_discovery(session=session, url=url, return discover.get_discovery(
session=session,
url=url,
cache=self._discovery_cache, cache=self._discovery_cache,
authenticated=authenticated) authenticated=authenticated,
)
def get_cache_id_elements(self): def get_cache_id_elements(self):
"""Get the elements for this auth plugin that make it unique. """Get the elements for this auth plugin that make it unique.
@ -667,8 +730,10 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
:rtype: str or None if no auth present. :rtype: str or None if no auth present.
""" """
if self.auth_ref: if self.auth_ref:
data = {'auth_token': self.auth_ref.auth_token, data = {
'body': self.auth_ref._data} 'auth_token': self.auth_ref.auth_token,
'body': self.auth_ref._data,
}
return json.dumps(data) return json.dumps(data)
@ -680,7 +745,8 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
""" """
if data: if data:
auth_data = json.loads(data) auth_data = json.loads(data)
self.auth_ref = access.create(body=auth_data['body'], self.auth_ref = access.create(
auth_token=auth_data['auth_token']) body=auth_data['body'], auth_token=auth_data['auth_token']
)
else: else:
self.auth_ref = None self.auth_ref = None

View File

@ -15,7 +15,4 @@ from keystoneauth1.identity.generic.password import Password # noqa
from keystoneauth1.identity.generic.token import Token # noqa from keystoneauth1.identity.generic.token import Token # noqa
__all__ = ('BaseGenericPlugin', __all__ = ('BaseGenericPlugin', 'Password', 'Token')
'Password',
'Token',
)

View File

@ -29,7 +29,9 @@ class BaseGenericPlugin(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
URL and then proxy all calls from the base plugin to the versioned one. URL and then proxy all calls from the base plugin to the versioned one.
""" """
def __init__(self, auth_url, def __init__(
self,
auth_url,
tenant_id=None, tenant_id=None,
tenant_name=None, tenant_name=None,
project_id=None, project_id=None,
@ -42,9 +44,9 @@ class BaseGenericPlugin(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
trust_id=None, trust_id=None,
default_domain_id=None, default_domain_id=None,
default_domain_name=None, default_domain_name=None,
reauthenticate=True): reauthenticate=True,
super(BaseGenericPlugin, self).__init__(auth_url=auth_url, ):
reauthenticate=reauthenticate) super().__init__(auth_url=auth_url, reauthenticate=reauthenticate)
self._project_id = project_id or tenant_id self._project_id = project_id or tenant_id
self._project_name = project_name or tenant_name self._project_name = project_name or tenant_name
@ -86,21 +88,30 @@ class BaseGenericPlugin(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
:returns: True if a domain parameter is set, false otherwise. :returns: True if a domain parameter is set, false otherwise.
""" """
return any([self._domain_id, self._domain_name, return any(
self._project_domain_id, self._project_domain_name]) [
self._domain_id,
self._domain_name,
self._project_domain_id,
self._project_domain_name,
]
)
@property @property
def _v2_params(self): def _v2_params(self):
"""Return the parameters that are common to v2 plugins.""" """Return the parameters that are common to v2 plugins."""
return {'trust_id': self._trust_id, return {
'trust_id': self._trust_id,
'tenant_id': self._project_id, 'tenant_id': self._project_id,
'tenant_name': self._project_name, 'tenant_name': self._project_name,
'reauthenticate': self.reauthenticate} 'reauthenticate': self.reauthenticate,
}
@property @property
def _v3_params(self): def _v3_params(self):
"""Return the parameters that are common to v3 plugins.""" """Return the parameters that are common to v3 plugins."""
return {'trust_id': self._trust_id, return {
'trust_id': self._trust_id,
'system_scope': self._system_scope, 'system_scope': self._system_scope,
'project_id': self._project_id, 'project_id': self._project_id,
'project_name': self._project_name, 'project_name': self._project_name,
@ -108,7 +119,8 @@ class BaseGenericPlugin(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
'project_domain_name': self.project_domain_name, 'project_domain_name': self.project_domain_name,
'domain_id': self._domain_id, 'domain_id': self._domain_id,
'domain_name': self._domain_name, 'domain_name': self._domain_name,
'reauthenticate': self.reauthenticate} 'reauthenticate': self.reauthenticate,
}
@property @property
def project_domain_id(self): def project_domain_id(self):
@ -130,16 +142,20 @@ class BaseGenericPlugin(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
plugin = None plugin = None
try: try:
disc = self.get_discovery(session, disc = self.get_discovery(
self.auth_url, session, self.auth_url, authenticated=False
authenticated=False) )
except (exceptions.DiscoveryFailure, except (
exceptions.DiscoveryFailure,
exceptions.HttpError, exceptions.HttpError,
exceptions.SSLError, exceptions.SSLError,
exceptions.ConnectionError) as e: exceptions.ConnectionError,
LOG.warning('Failed to discover available identity versions when ' ) as e:
LOG.warning(
'Failed to discover available identity versions when '
'contacting %s. Attempting to parse version from URL.', 'contacting %s. Attempting to parse version from URL.',
self.auth_url) self.auth_url,
)
url_parts = urllib.parse.urlparse(self.auth_url) url_parts = urllib.parse.urlparse(self.auth_url)
path = url_parts.path.lower() path = url_parts.path.lower()
@ -147,7 +163,8 @@ class BaseGenericPlugin(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
if path.startswith('/v2.0'): if path.startswith('/v2.0'):
if self._has_domain_scope: if self._has_domain_scope:
raise exceptions.DiscoveryFailure( raise exceptions.DiscoveryFailure(
'Cannot use v2 authentication with domain scope') 'Cannot use v2 authentication with domain scope'
)
plugin = self.create_plugin(session, (2, 0), self.auth_url) plugin = self.create_plugin(session, (2, 0), self.auth_url)
elif path.startswith('/v3'): elif path.startswith('/v3'):
plugin = self.create_plugin(session, (3, 0), self.auth_url) plugin = self.create_plugin(session, (3, 0), self.auth_url)
@ -155,7 +172,8 @@ class BaseGenericPlugin(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
raise exceptions.DiscoveryFailure( raise exceptions.DiscoveryFailure(
'Could not find versioned identity endpoints when ' 'Could not find versioned identity endpoints when '
'attempting to authenticate. Please check that your ' 'attempting to authenticate. Please check that your '
'auth_url is correct. %s' % e) f'auth_url is correct. {e}'
)
else: else:
# NOTE(jamielennox): version_data is always in oldest to newest # NOTE(jamielennox): version_data is always in oldest to newest
@ -172,23 +190,28 @@ class BaseGenericPlugin(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
for data in disc_data: for data in disc_data:
version = data['version'] version = data['version']
if (discover.version_match((2,), version) and if (
self._has_domain_scope): discover.version_match((2,), version)
and self._has_domain_scope
):
# NOTE(jamielennox): if there are domain parameters there # NOTE(jamielennox): if there are domain parameters there
# is no point even trying against v2 APIs. # is no point even trying against v2 APIs.
v2_with_domain_scope = True v2_with_domain_scope = True
continue continue
plugin = self.create_plugin(session, plugin = self.create_plugin(
session,
version, version,
data['url'], data['url'],
raw_status=data['raw_status']) raw_status=data['raw_status'],
)
if plugin: if plugin:
break break
if not plugin and v2_with_domain_scope: if not plugin and v2_with_domain_scope:
raise exceptions.DiscoveryFailure( raise exceptions.DiscoveryFailure(
'Cannot use v2 authentication with domain scope') 'Cannot use v2 authentication with domain scope'
)
if plugin: if plugin:
return plugin return plugin
@ -196,7 +219,8 @@ class BaseGenericPlugin(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
# so there were no URLs that i could use for auth of any version. # so there were no URLs that i could use for auth of any version.
raise exceptions.DiscoveryFailure( raise exceptions.DiscoveryFailure(
'Could not find versioned identity endpoints when attempting ' 'Could not find versioned identity endpoints when attempting '
'to authenticate. Please check that your auth_url is correct.') 'to authenticate. Please check that your auth_url is correct.'
)
def get_auth_ref(self, session, **kwargs): def get_auth_ref(self, session, **kwargs):
if not self._plugin: if not self._plugin:
@ -212,11 +236,13 @@ class BaseGenericPlugin(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
if not _implemented: if not _implemented:
raise NotImplementedError() raise NotImplementedError()
return {'auth_url': self.auth_url, return {
'auth_url': self.auth_url,
'project_id': self._project_id, 'project_id': self._project_id,
'project_name': self._project_name, 'project_name': self._project_name,
'project_domain_id': self.project_domain_id, 'project_domain_id': self.project_domain_id,
'project_domain_name': self.project_domain_name, 'project_domain_name': self.project_domain_name,
'domain_id': self._domain_id, 'domain_id': self._domain_id,
'domain_name': self._domain_name, 'domain_name': self._domain_name,
'trust_id': self._trust_id} 'trust_id': self._trust_id,
}

View File

@ -27,9 +27,17 @@ class Password(base.BaseGenericPlugin):
""" """
def __init__(self, auth_url, username=None, user_id=None, password=None, def __init__(
user_domain_id=None, user_domain_name=None, **kwargs): self,
super(Password, self).__init__(auth_url=auth_url, **kwargs) auth_url,
username=None,
user_id=None,
password=None,
user_domain_id=None,
user_domain_name=None,
**kwargs,
):
super().__init__(auth_url=auth_url, **kwargs)
self._username = username self._username = username
self._user_id = user_id self._user_id = user_id
@ -42,23 +50,27 @@ class Password(base.BaseGenericPlugin):
if self._user_domain_id or self._user_domain_name: if self._user_domain_id or self._user_domain_name:
return None return None
return v2.Password(auth_url=url, return v2.Password(
auth_url=url,
user_id=self._user_id, user_id=self._user_id,
username=self._username, username=self._username,
password=self._password, password=self._password,
**self._v2_params) **self._v2_params,
)
elif discover.version_match((3,), version): elif discover.version_match((3,), version):
u_domain_id = self._user_domain_id or self._default_domain_id u_domain_id = self._user_domain_id or self._default_domain_id
u_domain_name = self._user_domain_name or self._default_domain_name u_domain_name = self._user_domain_name or self._default_domain_name
return v3.Password(auth_url=url, return v3.Password(
auth_url=url,
user_id=self._user_id, user_id=self._user_id,
username=self._username, username=self._username,
user_domain_id=u_domain_id, user_domain_id=u_domain_id,
user_domain_name=u_domain_name, user_domain_name=u_domain_name,
password=self._password, password=self._password,
**self._v3_params) **self._v3_params,
)
@property @property
def user_domain_id(self): def user_domain_id(self):
@ -77,8 +89,7 @@ class Password(base.BaseGenericPlugin):
self._user_domain_name = value self._user_domain_name = value
def get_cache_id_elements(self): def get_cache_id_elements(self):
elements = super(Password, self).get_cache_id_elements( elements = super().get_cache_id_elements(_implemented=True)
_implemented=True)
elements['username'] = self._username elements['username'] = self._username
elements['user_id'] = self._user_id elements['user_id'] = self._user_id
elements['password'] = self._password elements['password'] = self._password

View File

@ -23,7 +23,7 @@ class Token(base.BaseGenericPlugin):
""" """
def __init__(self, auth_url, token=None, **kwargs): def __init__(self, auth_url, token=None, **kwargs):
super(Token, self).__init__(auth_url, **kwargs) super().__init__(auth_url, **kwargs)
self._token = token self._token = token
def create_plugin(self, session, version, url, raw_status=None): def create_plugin(self, session, version, url, raw_status=None):
@ -34,6 +34,6 @@ class Token(base.BaseGenericPlugin):
return v3.Token(url, self._token, **self._v3_params) return v3.Token(url, self._token, **self._v3_params)
def get_cache_id_elements(self): def get_cache_id_elements(self):
elements = super(Token, self).get_cache_id_elements(_implemented=True) elements = super().get_cache_id_elements(_implemented=True)
elements['token'] = self._token elements['token'] = self._token
return elements return elements

View File

@ -31,13 +31,15 @@ class Auth(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
is going to expire. (optional) default True is going to expire. (optional) default True
""" """
def __init__(self, auth_url, def __init__(
self,
auth_url,
trust_id=None, trust_id=None,
tenant_id=None, tenant_id=None,
tenant_name=None, tenant_name=None,
reauthenticate=True): reauthenticate=True,
super(Auth, self).__init__(auth_url=auth_url, ):
reauthenticate=reauthenticate) super().__init__(auth_url=auth_url, reauthenticate=reauthenticate)
self.trust_id = trust_id self.trust_id = trust_id
self.tenant_id = tenant_id self.tenant_id = tenant_id
@ -56,8 +58,9 @@ class Auth(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
params['auth']['trust_id'] = self.trust_id params['auth']['trust_id'] = self.trust_id
_logger.debug('Making authentication request to %s', url) _logger.debug('Making authentication request to %s', url)
resp = session.post(url, json=params, headers=headers, resp = session.post(
authenticated=False, log=False) url, json=params, headers=headers, authenticated=False, log=False
)
try: try:
resp_data = resp.json() resp_data = resp.json()
@ -106,9 +109,15 @@ class Password(Auth):
:raises TypeError: if a user_id or username is not provided. :raises TypeError: if a user_id or username is not provided.
""" """
def __init__(self, auth_url, username=_NOT_PASSED, password=None, def __init__(
user_id=_NOT_PASSED, **kwargs): self,
super(Password, self).__init__(auth_url, **kwargs) auth_url,
username=_NOT_PASSED,
password=None,
user_id=_NOT_PASSED,
**kwargs,
):
super().__init__(auth_url, **kwargs)
if username is _NOT_PASSED and user_id is _NOT_PASSED: if username is _NOT_PASSED and user_id is _NOT_PASSED:
msg = 'You need to specify either a username or user_id' msg = 'You need to specify either a username or user_id'
@ -134,13 +143,15 @@ class Password(Auth):
return {'passwordCredentials': auth} return {'passwordCredentials': auth}
def get_cache_id_elements(self): def get_cache_id_elements(self):
return {'username': self.username, return {
'username': self.username,
'user_id': self.user_id, 'user_id': self.user_id,
'password': self.password, 'password': self.password,
'auth_url': self.auth_url, 'auth_url': self.auth_url,
'tenant_id': self.tenant_id, 'tenant_id': self.tenant_id,
'tenant_name': self.tenant_name, 'tenant_name': self.tenant_name,
'trust_id': self.trust_id} 'trust_id': self.trust_id,
}
class Token(Auth): class Token(Auth):
@ -156,7 +167,7 @@ class Token(Auth):
""" """
def __init__(self, auth_url, token, **kwargs): def __init__(self, auth_url, token, **kwargs):
super(Token, self).__init__(auth_url, **kwargs) super().__init__(auth_url, **kwargs)
self.token = token self.token = token
def get_auth_data(self, headers=None): def get_auth_data(self, headers=None):
@ -165,8 +176,10 @@ class Token(Auth):
return {'token': {'id': self.token}} return {'token': {'id': self.token}}
def get_cache_id_elements(self): def get_cache_id_elements(self):
return {'token': self.token, return {
'token': self.token,
'auth_url': self.auth_url, 'auth_url': self.auth_url,
'tenant_id': self.tenant_id, 'tenant_id': self.tenant_id,
'tenant_name': self.tenant_name, 'tenant_name': self.tenant_name,
'trust_id': self.trust_id} 'trust_id': self.trust_id,
}

View File

@ -27,40 +27,29 @@ from keystoneauth1.identity.v3.oauth2_client_credential import * # noqa
from keystoneauth1.identity.v3.oauth2_mtls_client_credential import * # noqa from keystoneauth1.identity.v3.oauth2_mtls_client_credential import * # noqa
__all__ = ('ApplicationCredential', __all__ = (
'ApplicationCredential',
'ApplicationCredentialMethod', 'ApplicationCredentialMethod',
'Auth', 'Auth',
'AuthConstructor', 'AuthConstructor',
'AuthMethod', 'AuthMethod',
'BaseAuth', 'BaseAuth',
'FederationBaseAuth', 'FederationBaseAuth',
'Keystone2Keystone', 'Keystone2Keystone',
'Password', 'Password',
'PasswordMethod', 'PasswordMethod',
'Token', 'Token',
'TokenMethod', 'TokenMethod',
'OidcAccessToken', 'OidcAccessToken',
'OidcAuthorizationCode', 'OidcAuthorizationCode',
'OidcClientCredentials', 'OidcClientCredentials',
'OidcPassword', 'OidcPassword',
'TOTPMethod', 'TOTPMethod',
'TOTP', 'TOTP',
'TokenlessAuth', 'TokenlessAuth',
'ReceiptMethod', 'ReceiptMethod',
'MultiFactor', 'MultiFactor',
'OAuth2ClientCredential', 'OAuth2ClientCredential',
'OAuth2ClientCredentialMethod', 'OAuth2ClientCredentialMethod',
'OAuth2mTlsClientCredential', 'OAuth2mTlsClientCredential',
) )

View File

@ -37,13 +37,15 @@ class ApplicationCredentialMethod(base.AuthMethod):
provided. provided.
""" """
_method_parameters = ['application_credential_secret', _method_parameters = [
'application_credential_secret',
'application_credential_id', 'application_credential_id',
'application_credential_name', 'application_credential_name',
'user_id', 'user_id',
'username', 'username',
'user_domain_id', 'user_domain_id',
'user_domain_name'] 'user_domain_name',
]
def get_auth_data(self, session, auth, headers, **kwargs): def get_auth_data(self, session, auth, headers, **kwargs):
auth_data = {'secret': self.application_credential_secret} auth_data = {'secret': self.application_credential_secret}
@ -62,13 +64,16 @@ class ApplicationCredentialMethod(base.AuthMethod):
auth_data['user']['domain'] = {'id': self.user_domain_id} auth_data['user']['domain'] = {'id': self.user_domain_id}
elif self.user_domain_name: elif self.user_domain_name:
auth_data['user']['domain'] = { auth_data['user']['domain'] = {
'name': self.user_domain_name} 'name': self.user_domain_name
}
return 'application_credential', auth_data return 'application_credential', auth_data
def get_cache_id_elements(self): def get_cache_id_elements(self):
return dict(('application_credential_%s' % p, getattr(self, p)) return {
for p in self._method_parameters) f'application_credential_{p}': getattr(self, p)
for p in self._method_parameters
}
class ApplicationCredential(base.AuthConstructor): class ApplicationCredential(base.AuthConstructor):

View File

@ -41,7 +41,9 @@ class BaseAuth(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
token. (optional) default True. token. (optional) default True.
""" """
def __init__(self, auth_url, def __init__(
self,
auth_url,
trust_id=None, trust_id=None,
system_scope=None, system_scope=None,
domain_id=None, domain_id=None,
@ -51,9 +53,9 @@ class BaseAuth(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
project_domain_id=None, project_domain_id=None,
project_domain_name=None, project_domain_name=None,
reauthenticate=True, reauthenticate=True,
include_catalog=True): include_catalog=True,
super(BaseAuth, self).__init__(auth_url=auth_url, ):
reauthenticate=reauthenticate) super().__init__(auth_url=auth_url, reauthenticate=reauthenticate)
self.trust_id = trust_id self.trust_id = trust_id
self.system_scope = system_scope self.system_scope = system_scope
self.domain_id = domain_id self.domain_id = domain_id
@ -67,7 +69,7 @@ class BaseAuth(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
@property @property
def token_url(self): def token_url(self):
"""The full URL where we will send authentication data.""" """The full URL where we will send authentication data."""
return '%s/auth/tokens' % self.auth_url.rstrip('/') return '{}/auth/tokens'.format(self.auth_url.rstrip('/'))
@abc.abstractmethod @abc.abstractmethod
def get_auth_ref(self, session, **kwargs): def get_auth_ref(self, session, **kwargs):
@ -76,9 +78,14 @@ class BaseAuth(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
@property @property
def has_scope_parameters(self): def has_scope_parameters(self):
"""Return true if parameters can be used to create a scoped token.""" """Return true if parameters can be used to create a scoped token."""
return (self.domain_id or self.domain_name or return (
self.project_id or self.project_name or self.domain_id
self.trust_id or self.system_scope) or self.domain_name
or self.project_id
or self.project_name
or self.trust_id
or self.system_scope
)
class Auth(BaseAuth): class Auth(BaseAuth):
@ -104,7 +111,7 @@ class Auth(BaseAuth):
def __init__(self, auth_url, auth_methods, **kwargs): def __init__(self, auth_url, auth_methods, **kwargs):
self.unscoped = kwargs.pop('unscoped', False) self.unscoped = kwargs.pop('unscoped', False)
super(Auth, self).__init__(auth_url=auth_url, **kwargs) super().__init__(auth_url=auth_url, **kwargs)
self.auth_methods = auth_methods self.auth_methods = auth_methods
def add_method(self, method): def add_method(self, method):
@ -119,7 +126,8 @@ class Auth(BaseAuth):
for method in self.auth_methods: for method in self.auth_methods:
name, auth_data = method.get_auth_data( name, auth_data = method.get_auth_data(
session, self, headers, request_kwargs=rkwargs) session, self, headers, request_kwargs=rkwargs
)
# NOTE(adriant): Methods like ReceiptMethod don't # NOTE(adriant): Methods like ReceiptMethod don't
# want anything added to the request data, so they # want anything added to the request data, so they
# explicitly return None, which we check for. # explicitly return None, which we check for.
@ -129,19 +137,23 @@ class Auth(BaseAuth):
if not ident: if not ident:
raise exceptions.AuthorizationFailure( raise exceptions.AuthorizationFailure(
'Authentication method required (e.g. password)') 'Authentication method required (e.g. password)'
)
mutual_exclusion = [bool(self.domain_id or self.domain_name), mutual_exclusion = [
bool(self.domain_id or self.domain_name),
bool(self.project_id or self.project_name), bool(self.project_id or self.project_name),
bool(self.trust_id), bool(self.trust_id),
bool(self.system_scope), bool(self.system_scope),
bool(self.unscoped)] bool(self.unscoped),
]
if sum(mutual_exclusion) > 1: if sum(mutual_exclusion) > 1:
raise exceptions.AuthorizationFailure( raise exceptions.AuthorizationFailure(
message='Authentication cannot be scoped to multiple' message='Authentication cannot be scoped to multiple'
' targets. Pick one of: project, domain, ' ' targets. Pick one of: project, domain, '
'trust, system or unscoped') 'trust, system or unscoped'
)
if self.domain_id: if self.domain_id:
body['auth']['scope'] = {'domain': {'id': self.domain_id}} body['auth']['scope'] = {'domain': {'id': self.domain_id}}
@ -174,7 +186,7 @@ class Auth(BaseAuth):
token_url = self.token_url token_url = self.token_url
if not self.auth_url.rstrip('/').endswith('v3'): if not self.auth_url.rstrip('/').endswith('v3'):
token_url = '%s/v3/auth/tokens' % self.auth_url.rstrip('/') token_url = '{}/v3/auth/tokens'.format(self.auth_url.rstrip('/'))
# NOTE(jamielennox): we add nocatalog here rather than in token_url # NOTE(jamielennox): we add nocatalog here rather than in token_url
# directly as some federation plugins require the base token_url # directly as some federation plugins require the base token_url
@ -182,8 +194,14 @@ class Auth(BaseAuth):
token_url += '?nocatalog' token_url += '?nocatalog'
_logger.debug('Making authentication request to %s', token_url) _logger.debug('Making authentication request to %s', token_url)
resp = session.post(token_url, json=body, headers=headers, resp = session.post(
authenticated=False, log=False, **rkwargs) token_url,
json=body,
headers=headers,
authenticated=False,
log=False,
**rkwargs,
)
try: try:
_logger.debug(json.dumps(resp.json())) _logger.debug(json.dumps(resp.json()))
@ -194,21 +212,24 @@ class Auth(BaseAuth):
if 'token' not in resp_data: if 'token' not in resp_data:
raise exceptions.InvalidResponse(response=resp) raise exceptions.InvalidResponse(response=resp)
return access.AccessInfoV3(auth_token=resp.headers['X-Subject-Token'], return access.AccessInfoV3(
body=resp_data) auth_token=resp.headers['X-Subject-Token'], body=resp_data
)
def get_cache_id_elements(self): def get_cache_id_elements(self):
if not self.auth_methods: if not self.auth_methods:
return None return None
params = {'auth_url': self.auth_url, params = {
'auth_url': self.auth_url,
'domain_id': self.domain_id, 'domain_id': self.domain_id,
'domain_name': self.domain_name, 'domain_name': self.domain_name,
'project_id': self.project_id, 'project_id': self.project_id,
'project_name': self.project_name, 'project_name': self.project_name,
'project_domain_id': self.project_domain_id, 'project_domain_id': self.project_domain_id,
'project_domain_name': self.project_domain_name, 'project_domain_name': self.project_domain_name,
'trust_id': self.trust_id} 'trust_id': self.trust_id,
}
for method in self.auth_methods: for method in self.auth_methods:
try: try:
@ -240,14 +261,13 @@ class AuthMethod(metaclass=abc.ABCMeta):
setattr(self, param, kwargs.pop(param, None)) setattr(self, param, kwargs.pop(param, None))
if kwargs: if kwargs:
msg = "Unexpected Attributes: %s" % ", ".join(kwargs.keys()) msg = "Unexpected Attributes: {}".format(", ".join(kwargs.keys()))
raise AttributeError(msg) raise AttributeError(msg)
@classmethod @classmethod
def _extract_kwargs(cls, kwargs): def _extract_kwargs(cls, kwargs):
"""Remove parameters related to this method from other kwargs.""" """Remove parameters related to this method from other kwargs."""
return dict([(p, kwargs.pop(p, None)) return {p: kwargs.pop(p, None) for p in cls._method_parameters}
for p in cls._method_parameters])
@abc.abstractmethod @abc.abstractmethod
def get_auth_data(self, session, auth, headers, **kwargs): def get_auth_data(self, session, auth, headers, **kwargs):
@ -296,4 +316,4 @@ class AuthConstructor(Auth, metaclass=abc.ABCMeta):
def __init__(self, auth_url, *args, **kwargs): def __init__(self, auth_url, *args, **kwargs):
method_kwargs = self._auth_method_class._extract_kwargs(kwargs) method_kwargs = self._auth_method_class._extract_kwargs(kwargs)
method = self._auth_method_class(*args, **method_kwargs) method = self._auth_method_class(*args, **method_kwargs)
super(AuthConstructor, self).__init__(auth_url, [method], **kwargs) super().__init__(auth_url, [method], **kwargs)

View File

@ -35,13 +35,15 @@ class _Rescoped(base.BaseAuth, metaclass=abc.ABCMeta):
rescoping_plugin = token.Token rescoping_plugin = token.Token
def _get_scoping_data(self): def _get_scoping_data(self):
return {'trust_id': self.trust_id, return {
'trust_id': self.trust_id,
'domain_id': self.domain_id, 'domain_id': self.domain_id,
'domain_name': self.domain_name, 'domain_name': self.domain_name,
'project_id': self.project_id, 'project_id': self.project_id,
'project_name': self.project_name, 'project_name': self.project_name,
'project_domain_id': self.project_domain_id, 'project_domain_id': self.project_domain_id,
'project_domain_name': self.project_domain_name} 'project_domain_name': self.project_domain_name,
}
def get_auth_ref(self, session, **kwargs): def get_auth_ref(self, session, **kwargs):
"""Authenticate retrieve token information. """Authenticate retrieve token information.
@ -63,9 +65,9 @@ class _Rescoped(base.BaseAuth, metaclass=abc.ABCMeta):
scoping = self._get_scoping_data() scoping = self._get_scoping_data()
if any(scoping.values()): if any(scoping.values()):
token_plugin = self.rescoping_plugin(self.auth_url, token_plugin = self.rescoping_plugin(
token=auth_ref.auth_token, self.auth_url, token=auth_ref.auth_token, **scoping
**scoping) )
auth_ref = token_plugin.get_auth_ref(session) auth_ref = token_plugin.get_auth_ref(session)
@ -93,7 +95,7 @@ class FederationBaseAuth(_Rescoped):
""" """
def __init__(self, auth_url, identity_provider, protocol, **kwargs): def __init__(self, auth_url, identity_provider, protocol, **kwargs):
super(FederationBaseAuth, self).__init__(auth_url=auth_url, **kwargs) super().__init__(auth_url=auth_url, **kwargs)
self.identity_provider = identity_provider self.identity_provider = identity_provider
self.protocol = protocol self.protocol = protocol
@ -106,10 +108,12 @@ class FederationBaseAuth(_Rescoped):
values = { values = {
'host': host, 'host': host,
'identity_provider': self.identity_provider, 'identity_provider': self.identity_provider,
'protocol': self.protocol 'protocol': self.protocol,
} }
url = ("%(host)s/OS-FEDERATION/identity_providers/" url = (
"%(identity_provider)s/protocols/%(protocol)s/auth") "%(host)s/OS-FEDERATION/identity_providers/"
"%(identity_provider)s/protocols/%(protocol)s/auth"
)
url = url % values url = url % values
return url return url

View File

@ -43,7 +43,7 @@ class Keystone2Keystone(federation._Rescoped):
HTTP_SEE_OTHER = 303 HTTP_SEE_OTHER = 303
def __init__(self, base_plugin, service_provider, **kwargs): def __init__(self, base_plugin, service_provider, **kwargs):
super(Keystone2Keystone, self).__init__(auth_url=None, **kwargs) super().__init__(auth_url=None, **kwargs)
self._local_cloud_plugin = base_plugin self._local_cloud_plugin = base_plugin
self._sp_id = service_provider self._sp_id = service_provider
@ -81,36 +81,38 @@ class Keystone2Keystone(federation._Rescoped):
'methods': ['token'], 'methods': ['token'],
'token': { 'token': {
'id': self._local_cloud_plugin.get_token(session) 'id': self._local_cloud_plugin.get_token(session)
}
}, },
'scope': { },
'service_provider': { 'scope': {'service_provider': {'id': self._sp_id}},
'id': self._sp_id
}
}
} }
} }
endpoint_filter = {'version': (3, 0), endpoint_filter = {
'interface': plugin.AUTH_INTERFACE} 'version': (3, 0),
'interface': plugin.AUTH_INTERFACE,
}
headers = {'Accept': 'application/json'} headers = {'Accept': 'application/json'}
resp = session.post(self.REQUEST_ECP_URL, resp = session.post(
self.REQUEST_ECP_URL,
json=body, json=body,
auth=self._local_cloud_plugin, auth=self._local_cloud_plugin,
endpoint_filter=endpoint_filter, endpoint_filter=endpoint_filter,
headers=headers, headers=headers,
authenticated=False, authenticated=False,
raise_exc=False) raise_exc=False,
)
# NOTE(marek-denis): I am not sure whether disabling exceptions in the # NOTE(marek-denis): I am not sure whether disabling exceptions in the
# Session object and testing if resp.ok is sufficient. An alternative # Session object and testing if resp.ok is sufficient. An alternative
# would be catching locally all exceptions and reraising with custom # would be catching locally all exceptions and reraising with custom
# warning. # warning.
if not resp.ok: if not resp.ok:
msg = ("Error while requesting ECP wrapped assertion: response " msg = (
"exit code: %(status_code)d, reason: %(err)s") "Error while requesting ECP wrapped assertion: response "
"exit code: %(status_code)d, reason: %(err)s"
)
msg = msg % {'status_code': resp.status_code, 'err': resp.reason} msg = msg % {'status_code': resp.status_code, 'err': resp.reason}
raise exceptions.AuthorizationFailure(msg) raise exceptions.AuthorizationFailure(msg)
@ -119,8 +121,9 @@ class Keystone2Keystone(federation._Rescoped):
return str(resp.text) return str(resp.text)
def _send_service_provider_ecp_authn_response(self, session, sp_url, def _send_service_provider_ecp_authn_response(
sp_auth_url): self, session, sp_url, sp_auth_url
):
"""Present ECP wrapped SAML assertion to the keystone SP. """Present ECP wrapped SAML assertion to the keystone SP.
The assertion is issued by the keystone IdP and it is targeted to the The assertion is issued by the keystone IdP and it is targeted to the
@ -145,27 +148,33 @@ class Keystone2Keystone(federation._Rescoped):
headers={'Content-Type': 'application/vnd.paos+xml'}, headers={'Content-Type': 'application/vnd.paos+xml'},
data=self._get_ecp_assertion(session), data=self._get_ecp_assertion(session),
authenticated=False, authenticated=False,
redirect=False) redirect=False,
)
# Don't follow HTTP specs - after the HTTP 302/303 response don't # Don't follow HTTP specs - after the HTTP 302/303 response don't
# repeat the call directed to the Location URL. In this case, this is # repeat the call directed to the Location URL. In this case, this is
# an indication that SAML2 session is now active and protected resource # an indication that SAML2 session is now active and protected resource
# can be accessed. # can be accessed.
if response.status_code in (self.HTTP_MOVED_TEMPORARILY, if response.status_code in (
self.HTTP_SEE_OTHER): self.HTTP_MOVED_TEMPORARILY,
self.HTTP_SEE_OTHER,
):
response = session.get( response = session.get(
sp_auth_url, sp_auth_url,
headers={'Content-Type': 'application/vnd.paos+xml'}, headers={'Content-Type': 'application/vnd.paos+xml'},
authenticated=False) authenticated=False,
)
return response return response
def get_unscoped_auth_ref(self, session, **kwargs): def get_unscoped_auth_ref(self, session, **kwargs):
sp_auth_url = self._local_cloud_plugin.get_sp_auth_url( sp_auth_url = self._local_cloud_plugin.get_sp_auth_url(
session, self._sp_id) session, self._sp_id
)
sp_url = self._local_cloud_plugin.get_sp_url(session, self._sp_id) sp_url = self._local_cloud_plugin.get_sp_url(session, self._sp_id)
self.auth_url = self._remote_auth_url(sp_auth_url) self.auth_url = self._remote_auth_url(sp_auth_url)
response = self._send_service_provider_ecp_authn_response( response = self._send_service_provider_ecp_authn_response(
session, sp_url, sp_auth_url) session, sp_url, sp_auth_url
)
return access.create(resp=response) return access.create(resp=response)

View File

@ -14,7 +14,7 @@ from keystoneauth1.identity.v3 import base
from keystoneauth1 import loading from keystoneauth1 import loading
__all__ = ('MultiFactor', ) __all__ = ('MultiFactor',)
class MultiFactor(base.Auth): class MultiFactor(base.Auth):
@ -42,7 +42,8 @@ class MultiFactor(base.Auth):
for method in auth_methods: for method in auth_methods:
# Using the loaders we pull the related auth method class # Using the loaders we pull the related auth method class
method_class = loading.get_plugin_loader( method_class = loading.get_plugin_loader(
method).plugin_class._auth_method_class method
).plugin_class._auth_method_class
# We build some new kwargs for the method from required parameters # We build some new kwargs for the method from required parameters
method_kwargs = {} method_kwargs = {}
for key in method_class._method_parameters: for key in method_class._method_parameters:
@ -56,4 +57,4 @@ class MultiFactor(base.Auth):
# to the super class and throw errors # to the super class and throw errors
for key in method_keys: for key in method_keys:
kwargs.pop(key, None) kwargs.pop(key, None)
super(MultiFactor, self).__init__(auth_url, method_instances, **kwargs) super().__init__(auth_url, method_instances, **kwargs)

View File

@ -31,7 +31,7 @@ class OAuth2ClientCredentialMethod(base.AuthMethod):
_method_parameters = [ _method_parameters = [
'oauth2_endpoint', 'oauth2_endpoint',
'oauth2_client_id', 'oauth2_client_id',
'oauth2_client_secret' 'oauth2_client_secret',
] ]
def get_auth_data(self, session, auth, headers, **kwargs): def get_auth_data(self, session, auth, headers, **kwargs):
@ -48,7 +48,7 @@ class OAuth2ClientCredentialMethod(base.AuthMethod):
""" """
auth_data = { auth_data = {
'id': self.oauth2_client_id, 'id': self.oauth2_client_id,
'secret': self.oauth2_client_secret 'secret': self.oauth2_client_secret,
} }
return 'application_credential', auth_data return 'application_credential', auth_data
@ -66,8 +66,10 @@ class OAuth2ClientCredentialMethod(base.AuthMethod):
should be prefixed with the plugin identifier. For example the password should be prefixed with the plugin identifier. For example the password
plugin returns its username value as 'password_username'. plugin returns its username value as 'password_username'.
""" """
return dict(('oauth2_client_credential_%s' % p, getattr(self, p)) return {
for p in self._method_parameters) f'oauth2_client_credential_{p}': getattr(self, p)
for p in self._method_parameters
}
class OAuth2ClientCredential(base.AuthConstructor): class OAuth2ClientCredential(base.AuthConstructor):
@ -82,7 +84,7 @@ class OAuth2ClientCredential(base.AuthConstructor):
_auth_method_class = OAuth2ClientCredentialMethod _auth_method_class = OAuth2ClientCredentialMethod
def __init__(self, auth_url, *args, **kwargs): def __init__(self, auth_url, *args, **kwargs):
super(OAuth2ClientCredential, self).__init__(auth_url, *args, **kwargs) super().__init__(auth_url, *args, **kwargs)
self._oauth2_endpoint = kwargs['oauth2_endpoint'] self._oauth2_endpoint = kwargs['oauth2_endpoint']
self._oauth2_client_id = kwargs['oauth2_client_id'] self._oauth2_client_id = kwargs['oauth2_client_id']
self._oauth2_client_secret = kwargs['oauth2_client_secret'] self._oauth2_client_secret = kwargs['oauth2_client_secret']
@ -99,19 +101,21 @@ class OAuth2ClientCredential(base.AuthConstructor):
:rtype: dict :rtype: dict
""" """
# get headers for X-Auth-Token # get headers for X-Auth-Token
headers = super(OAuth2ClientCredential, self).get_headers( headers = super().get_headers(session, **kwargs)
session, **kwargs)
# Get OAuth2.0 access token and add the field 'Authorization' # Get OAuth2.0 access token and add the field 'Authorization'
data = {"grant_type": "client_credentials"} data = {"grant_type": "client_credentials"}
auth = requests.auth.HTTPBasicAuth(self._oauth2_client_id, auth = requests.auth.HTTPBasicAuth(
self._oauth2_client_secret) self._oauth2_client_id, self._oauth2_client_secret
resp = session.request(self._oauth2_endpoint, )
resp = session.request(
self._oauth2_endpoint,
"POST", "POST",
authenticated=False, authenticated=False,
raise_exc=False, raise_exc=False,
data=data, data=data,
requests_auth=auth) requests_auth=auth,
)
if resp.status_code == 200: if resp.status_code == 200:
oauth2 = resp.json() oauth2 = resp.json()
oauth2_token = oauth2["access_token"] oauth2_token = oauth2["access_token"]

View File

@ -27,10 +27,10 @@ class OAuth2mTlsClientCredential(base.BaseAuth, metaclass=abc.ABCMeta):
:param string oauth2_client_id: OAuth2.0 client credential id. :param string oauth2_client_id: OAuth2.0 client credential id.
""" """
def __init__(self, auth_url, oauth2_endpoint, oauth2_client_id, def __init__(
*args, **kwargs): self, auth_url, oauth2_endpoint, oauth2_client_id, *args, **kwargs
super(OAuth2mTlsClientCredential, self).__init__( ):
auth_url, *args, **kwargs) super().__init__(auth_url, *args, **kwargs)
self.auth_url = auth_url self.auth_url = auth_url
self.oauth2_endpoint = oauth2_endpoint self.oauth2_endpoint = oauth2_endpoint
self.oauth2_client_id = oauth2_client_id self.oauth2_client_id = oauth2_client_id
@ -64,12 +64,16 @@ class OAuth2mTlsClientCredential(base.BaseAuth, metaclass=abc.ABCMeta):
""" """
# Get OAuth2.0 access token and add the field 'Authorization' when # Get OAuth2.0 access token and add the field 'Authorization' when
# using the HTTPS protocol. # using the HTTPS protocol.
data = {'grant_type': 'client_credentials', data = {
'client_id': self.oauth2_client_id} 'grant_type': 'client_credentials',
resp = session.post(url=self.oauth2_endpoint, 'client_id': self.oauth2_client_id,
}
resp = session.post(
url=self.oauth2_endpoint,
authenticated=False, authenticated=False,
raise_exc=False, raise_exc=False,
data=data) data=data,
)
if resp.status_code == 200: if resp.status_code == 200:
oauth2 = resp.json() oauth2 = resp.json()
self.oauth2_access_token = oauth2.get('access_token') self.oauth2_access_token = oauth2.get('access_token')
@ -78,17 +82,18 @@ class OAuth2mTlsClientCredential(base.BaseAuth, metaclass=abc.ABCMeta):
msg = error.get('error_description') msg = error.get('error_description')
raise exceptions.ClientException(msg) raise exceptions.ClientException(msg)
headers = {'Accept': 'application/json', headers = {
'Accept': 'application/json',
'X-Auth-Token': self.oauth2_access_token, 'X-Auth-Token': self.oauth2_access_token,
'X-Subject-Token': self.oauth2_access_token} 'X-Subject-Token': self.oauth2_access_token,
}
token_url = '%s/auth/tokens' % self.auth_url.rstrip('/') token_url = '{}/auth/tokens'.format(self.auth_url.rstrip('/'))
if not self.auth_url.rstrip('/').endswith('v3'): if not self.auth_url.rstrip('/').endswith('v3'):
token_url = '%s/v3/auth/tokens' % self.auth_url.rstrip('/') token_url = '{}/v3/auth/tokens'.format(self.auth_url.rstrip('/'))
resp = session.get(url=token_url, resp = session.get(
authenticated=False, url=token_url, authenticated=False, headers=headers, log=False
headers=headers, )
log=False)
try: try:
resp_data = resp.json() resp_data = resp.json()
except ValueError: except ValueError:
@ -96,8 +101,9 @@ class OAuth2mTlsClientCredential(base.BaseAuth, metaclass=abc.ABCMeta):
if 'token' not in resp_data: if 'token' not in resp_data:
raise exceptions.InvalidResponse(response=resp) raise exceptions.InvalidResponse(response=resp)
return access.AccessInfoV3(auth_token=self.oauth2_access_token, return access.AccessInfoV3(
body=resp_data) auth_token=self.oauth2_access_token, body=resp_data
)
def get_headers(self, session, **kwargs): def get_headers(self, session, **kwargs):
"""Fetch authentication headers for message. """Fetch authentication headers for message.
@ -111,8 +117,7 @@ class OAuth2mTlsClientCredential(base.BaseAuth, metaclass=abc.ABCMeta):
:rtype: dict :rtype: dict
""" """
# get headers for X-Auth-Token # get headers for X-Auth-Token
headers = super(OAuth2mTlsClientCredential, self).get_headers( headers = super().get_headers(session, **kwargs)
session, **kwargs)
# add OAuth2.0 access token to the headers # add OAuth2.0 access token to the headers
if headers: if headers:

View File

@ -27,10 +27,12 @@ from keystoneauth1.identity.v3 import federation
_logger = utils.get_logger(__name__) _logger = utils.get_logger(__name__)
__all__ = ('OidcAuthorizationCode', __all__ = (
'OidcAuthorizationCode',
'OidcClientCredentials', 'OidcClientCredentials',
'OidcPassword', 'OidcPassword',
'OidcAccessToken') 'OidcAccessToken',
)
SENSITIVE_KEYS = ("password", "code", "token", "secret") SENSITIVE_KEYS = ("password", "code", "token", "secret")
@ -44,14 +46,20 @@ class _OidcBase(federation.FederationBaseAuth, metaclass=abc.ABCMeta):
grant_type = None grant_type = None
def __init__(self, auth_url, identity_provider, protocol, def __init__(
client_id, client_secret, self,
auth_url,
identity_provider,
protocol,
client_id,
client_secret,
access_token_type, access_token_type,
scope="openid profile", scope="openid profile",
access_token_endpoint=None, access_token_endpoint=None,
discovery_endpoint=None, discovery_endpoint=None,
grant_type=None, grant_type=None,
**kwargs): **kwargs,
):
"""The OpenID Connect plugin expects the following. """The OpenID Connect plugin expects the following.
:param auth_url: URL of the Identity Service :param auth_url: URL of the Identity Service
@ -96,8 +104,7 @@ class _OidcBase(federation.FederationBaseAuth, metaclass=abc.ABCMeta):
states that "openid" must be always specified. states that "openid" must be always specified.
:type scope: string :type scope: string
""" """
super(_OidcBase, self).__init__(auth_url, identity_provider, protocol, super().__init__(auth_url, identity_provider, protocol, **kwargs)
**kwargs)
self.client_id = client_id self.client_id = client_id
self.client_secret = client_secret self.client_secret = client_secret
@ -111,7 +118,8 @@ class _OidcBase(federation.FederationBaseAuth, metaclass=abc.ABCMeta):
if grant_type is not None: if grant_type is not None:
if grant_type != self.grant_type: if grant_type != self.grant_type:
raise exceptions.OidcGrantTypeMissmatch() raise exceptions.OidcGrantTypeMissmatch()
warnings.warn("Passing grant_type as an argument has been " warnings.warn(
"Passing grant_type as an argument has been "
"deprecated as it is now defined in the plugin " "deprecated as it is now defined in the plugin "
"itself. You should stop passing this argument " "itself. You should stop passing this argument "
"to the plugin, as it will be ignored, since you " "to the plugin, as it will be ignored, since you "
@ -119,7 +127,8 @@ class _OidcBase(federation.FederationBaseAuth, metaclass=abc.ABCMeta):
"This argument will be dropped from the plugin in " "This argument will be dropped from the plugin in "
"July 2017 or with the next major release of " "July 2017 or with the next major release of "
"keystoneauth (3.0.0)", "keystoneauth (3.0.0)",
DeprecationWarning) DeprecationWarning,
)
def _get_discovery_document(self, session): def _get_discovery_document(self, session):
"""Get the contents of the OpenID Connect Discovery Document. """Get the contents of the OpenID Connect Discovery Document.
@ -137,14 +146,18 @@ class _OidcBase(federation.FederationBaseAuth, metaclass=abc.ABCMeta):
otherwise it will return an empty dict. otherwise it will return an empty dict.
:rtype: dict :rtype: dict
""" """
if (self.discovery_endpoint is not None if (
and not self._discovery_document): self.discovery_endpoint is not None
and not self._discovery_document
):
try: try:
resp = session.get(self.discovery_endpoint, resp = session.get(
authenticated=False) self.discovery_endpoint, authenticated=False
)
except exceptions.HttpError: except exceptions.HttpError:
_logger.error("Cannot fetch discovery document %(discovery)s" % _logger.error(
{"discovery": self.discovery_endpoint}) f"Cannot fetch discovery document {self.discovery_endpoint}"
)
raise raise
try: try:
@ -211,20 +224,25 @@ class _OidcBase(federation.FederationBaseAuth, metaclass=abc.ABCMeta):
sanitized_payload = self._sanitize(payload) sanitized_payload = self._sanitize(payload)
_logger.debug( _logger.debug(
"Making OpenID-Connect authentication request to %s with " "Making OpenID-Connect authentication request to %s with "
"data %s", access_token_endpoint, sanitized_payload "data %s",
access_token_endpoint,
sanitized_payload,
) )
op_response = session.post(access_token_endpoint, op_response = session.post(
access_token_endpoint,
requests_auth=client_auth, requests_auth=client_auth,
data=payload, data=payload,
log=False, log=False,
authenticated=False) authenticated=False,
)
response = op_response.json() response = op_response.json()
if _logger.isEnabledFor(logging.DEBUG): if _logger.isEnabledFor(logging.DEBUG):
sanitized_response = self._sanitize(response) sanitized_response = self._sanitize(response)
_logger.debug( _logger.debug(
"OpenID-Connect authentication response from %s is %s", "OpenID-Connect authentication response from %s is %s",
access_token_endpoint, sanitized_response access_token_endpoint,
sanitized_response,
) )
return response[self.access_token_type] return response[self.access_token_type]
@ -247,9 +265,9 @@ class _OidcBase(federation.FederationBaseAuth, metaclass=abc.ABCMeta):
""" """
# use access token against protected URL # use access token against protected URL
headers = {'Authorization': 'Bearer ' + access_token} headers = {'Authorization': 'Bearer ' + access_token}
auth_response = session.post(self.federated_token_url, auth_response = session.post(
headers=headers, self.federated_token_url, headers=headers, authenticated=False
authenticated=False) )
return auth_response return auth_response
def get_unscoped_auth_ref(self, session): def get_unscoped_auth_ref(self, session):
@ -278,8 +296,11 @@ class _OidcBase(federation.FederationBaseAuth, metaclass=abc.ABCMeta):
# First of all, check if the grant type is supported # First of all, check if the grant type is supported
discovery = self._get_discovery_document(session) discovery = self._get_discovery_document(session)
grant_types = discovery.get("grant_types_supported") grant_types = discovery.get("grant_types_supported")
if (grant_types and self.grant_type is not None if (
and self.grant_type not in grant_types): grant_types
and self.grant_type is not None
and self.grant_type not in grant_types
):
raise exceptions.OidcPluginNotSupported() raise exceptions.OidcPluginNotSupported()
# Get the payload # Get the payload
@ -317,13 +338,21 @@ class OidcPassword(_OidcBase):
grant_type = "password" grant_type = "password"
def __init__(self, auth_url, identity_provider, protocol, # nosec def __init__(
client_id, client_secret, self,
auth_url,
identity_provider,
protocol, # nosec
client_id,
client_secret,
access_token_endpoint=None, access_token_endpoint=None,
discovery_endpoint=None, discovery_endpoint=None,
access_token_type='access_token', access_token_type='access_token',
username=None, password=None, idp_otp_key=None, username=None,
**kwargs): password=None,
idp_otp_key=None,
**kwargs,
):
"""The OpenID Password plugin expects the following. """The OpenID Password plugin expects the following.
:param username: Username used to authenticate :param username: Username used to authenticate
@ -332,7 +361,7 @@ class OidcPassword(_OidcBase):
:param password: Password used to authenticate :param password: Password used to authenticate
:type password: string :type password: string
""" """
super(OidcPassword, self).__init__( super().__init__(
auth_url=auth_url, auth_url=auth_url,
identity_provider=identity_provider, identity_provider=identity_provider,
protocol=protocol, protocol=protocol,
@ -341,7 +370,8 @@ class OidcPassword(_OidcBase):
access_token_endpoint=access_token_endpoint, access_token_endpoint=access_token_endpoint,
discovery_endpoint=discovery_endpoint, discovery_endpoint=discovery_endpoint,
access_token_type=access_token_type, access_token_type=access_token_type,
**kwargs) **kwargs,
)
self.username = username self.username = username
self.password = password self.password = password
self.idp_otp_key = idp_otp_key self.idp_otp_key = idp_otp_key
@ -355,10 +385,12 @@ class OidcPassword(_OidcBase):
:returns: a python dictionary containing the payload to be exchanged :returns: a python dictionary containing the payload to be exchanged
:rtype: dict :rtype: dict
""" """
payload = {'username': self.username, payload = {
'username': self.username,
'password': self.password, 'password': self.password,
'scope': self.scope, 'scope': self.scope,
'client_id': self.client_id} 'client_id': self.client_id,
}
self.manage_otp_from_session_or_request_to_the_user(payload, session) self.manage_otp_from_session_or_request_to_the_user(payload, session)
@ -391,7 +423,8 @@ class OidcPassword(_OidcBase):
payload[self.idp_otp_key] = otp_from_session payload[self.idp_otp_key] = otp_from_session
else: else:
payload[self.idp_otp_key] = input( payload[self.idp_otp_key] = input(
"Please, enter the generated OTP code: ") "Please, enter the generated OTP code: "
)
setattr(session, 'otp', payload[self.idp_otp_key]) setattr(session, 'otp', payload[self.idp_otp_key])
@ -400,12 +433,18 @@ class OidcClientCredentials(_OidcBase):
grant_type = 'client_credentials' grant_type = 'client_credentials'
def __init__(self, auth_url, identity_provider, protocol, # nosec def __init__(
client_id, client_secret, self,
auth_url,
identity_provider,
protocol, # nosec
client_id,
client_secret,
access_token_endpoint=None, access_token_endpoint=None,
discovery_endpoint=None, discovery_endpoint=None,
access_token_type='access_token', access_token_type='access_token',
**kwargs): **kwargs,
):
"""The OpenID Client Credentials expects the following. """The OpenID Client Credentials expects the following.
:param client_id: Client ID used to authenticate :param client_id: Client ID used to authenticate
@ -414,7 +453,7 @@ class OidcClientCredentials(_OidcBase):
:param client_secret: Client Secret used to authenticate :param client_secret: Client Secret used to authenticate
:type password: string :type password: string
""" """
super(OidcClientCredentials, self).__init__( super().__init__(
auth_url=auth_url, auth_url=auth_url,
identity_provider=identity_provider, identity_provider=identity_provider,
protocol=protocol, protocol=protocol,
@ -423,7 +462,8 @@ class OidcClientCredentials(_OidcBase):
access_token_endpoint=access_token_endpoint, access_token_endpoint=access_token_endpoint,
discovery_endpoint=discovery_endpoint, discovery_endpoint=discovery_endpoint,
access_token_type=access_token_type, access_token_type=access_token_type,
**kwargs) **kwargs,
)
def get_payload(self, session): def get_payload(self, session):
"""Get an authorization grant for the client credentials grant type. """Get an authorization grant for the client credentials grant type.
@ -443,12 +483,20 @@ class OidcAuthorizationCode(_OidcBase):
grant_type = 'authorization_code' grant_type = 'authorization_code'
def __init__(self, auth_url, identity_provider, protocol, # nosec def __init__(
client_id, client_secret, self,
auth_url,
identity_provider,
protocol, # nosec
client_id,
client_secret,
access_token_endpoint=None, access_token_endpoint=None,
discovery_endpoint=None, discovery_endpoint=None,
access_token_type='access_token', access_token_type='access_token',
redirect_uri=None, code=None, **kwargs): redirect_uri=None,
code=None,
**kwargs,
):
"""The OpenID Authorization Code plugin expects the following. """The OpenID Authorization Code plugin expects the following.
:param redirect_uri: OpenID Connect Client Redirect URL :param redirect_uri: OpenID Connect Client Redirect URL
@ -458,7 +506,7 @@ class OidcAuthorizationCode(_OidcBase):
:type code: string :type code: string
""" """
super(OidcAuthorizationCode, self).__init__( super().__init__(
auth_url=auth_url, auth_url=auth_url,
identity_provider=identity_provider, identity_provider=identity_provider,
protocol=protocol, protocol=protocol,
@ -467,7 +515,8 @@ class OidcAuthorizationCode(_OidcBase):
access_token_endpoint=access_token_endpoint, access_token_endpoint=access_token_endpoint,
discovery_endpoint=discovery_endpoint, discovery_endpoint=discovery_endpoint,
access_token_type=access_token_type, access_token_type=access_token_type,
**kwargs) **kwargs,
)
self.redirect_uri = redirect_uri self.redirect_uri = redirect_uri
self.code = code self.code = code
@ -488,8 +537,9 @@ class OidcAuthorizationCode(_OidcBase):
class OidcAccessToken(_OidcBase): class OidcAccessToken(_OidcBase):
"""Implementation for OpenID Connect access token reuse.""" """Implementation for OpenID Connect access token reuse."""
def __init__(self, auth_url, identity_provider, protocol, def __init__(
access_token, **kwargs): self, auth_url, identity_provider, protocol, access_token, **kwargs
):
"""The OpenID Connect plugin based on the Access Token. """The OpenID Connect plugin based on the Access Token.
It expects the following: It expects the following:
@ -507,13 +557,16 @@ class OidcAccessToken(_OidcBase):
:param access_token: OpenID Connect Access token :param access_token: OpenID Connect Access token
:type access_token: string :type access_token: string
""" """
super(OidcAccessToken, self).__init__(auth_url, identity_provider, super().__init__(
auth_url,
identity_provider,
protocol, protocol,
client_id=None, client_id=None,
client_secret=None, client_secret=None,
access_token_endpoint=None, access_token_endpoint=None,
access_token_type=None, access_token_type=None,
**kwargs) **kwargs,
)
self.access_token = access_token self.access_token = access_token
def get_payload(self, session): def get_payload(self, session):
@ -546,13 +599,20 @@ class OidcDeviceAuthorization(_OidcBase):
grant_type = "urn:ietf:params:oauth:grant-type:device_code" grant_type = "urn:ietf:params:oauth:grant-type:device_code"
HEADER_X_FORM = {"Content-Type": "application/x-www-form-urlencoded"} HEADER_X_FORM = {"Content-Type": "application/x-www-form-urlencoded"}
def __init__(self, auth_url, identity_provider, protocol, # nosec def __init__(
client_id, client_secret=None, self,
auth_url,
identity_provider,
protocol, # nosec
client_id,
client_secret=None,
access_token_endpoint=None, access_token_endpoint=None,
device_authorization_endpoint=None, device_authorization_endpoint=None,
discovery_endpoint=None, discovery_endpoint=None,
code_challenge=None, code_challenge_method=None, code_challenge=None,
**kwargs): code_challenge_method=None,
**kwargs,
):
"""The OAuth 2.0 Device Authorization plugin expects the following. """The OAuth 2.0 Device Authorization plugin expects the following.
:param device_authorization_endpoint: OAuth 2.0 Device Authorization :param device_authorization_endpoint: OAuth 2.0 Device Authorization
@ -571,7 +631,7 @@ class OidcDeviceAuthorization(_OidcBase):
self.device_authorization_endpoint = device_authorization_endpoint self.device_authorization_endpoint = device_authorization_endpoint
self.code_challenge_method = code_challenge_method self.code_challenge_method = code_challenge_method
super(OidcDeviceAuthorization, self).__init__( super().__init__(
auth_url=auth_url, auth_url=auth_url,
identity_provider=identity_provider, identity_provider=identity_provider,
protocol=protocol, protocol=protocol,
@ -580,7 +640,8 @@ class OidcDeviceAuthorization(_OidcBase):
access_token_endpoint=access_token_endpoint, access_token_endpoint=access_token_endpoint,
discovery_endpoint=discovery_endpoint, discovery_endpoint=discovery_endpoint,
access_token_type=self.access_token_type, access_token_type=self.access_token_type,
**kwargs) **kwargs,
)
def _get_device_authorization_endpoint(self, session): def _get_device_authorization_endpoint(self, session):
"""Get the endpoint for the OAuth 2.0 Device Authorization flow. """Get the endpoint for the OAuth 2.0 Device Authorization flow.
@ -639,8 +700,9 @@ class OidcDeviceAuthorization(_OidcBase):
:returns: a python dictionary containing the payload to be exchanged :returns: a python dictionary containing the payload to be exchanged
:rtype: dict :rtype: dict
""" """
device_authz_endpoint = \ device_authz_endpoint = self._get_device_authorization_endpoint(
self._get_device_authorization_endpoint(session) session
)
if self.client_secret: if self.client_secret:
client_auth = (self.client_id, self.client_secret) client_auth = (self.client_id, self.client_secret)
@ -651,8 +713,9 @@ class OidcDeviceAuthorization(_OidcBase):
if self.code_challenge_method: if self.code_challenge_method:
self.code_challenge = self._generate_pkce_challenge() self.code_challenge = self._generate_pkce_challenge()
payload.setdefault('code_challenge_method', payload.setdefault(
self.code_challenge_method) 'code_challenge_method', self.code_challenge_method
)
payload.setdefault('code_challenge', self.code_challenge) payload.setdefault('code_challenge', self.code_challenge)
encoded_payload = urlparse.urlencode(payload) encoded_payload = urlparse.urlencode(payload)
@ -660,19 +723,24 @@ class OidcDeviceAuthorization(_OidcBase):
sanitized_payload = self._sanitize(payload) sanitized_payload = self._sanitize(payload)
_logger.debug( _logger.debug(
"Making OpenID-Connect authentication request to %s with " "Making OpenID-Connect authentication request to %s with "
"data %s", device_authz_endpoint, sanitized_payload "data %s",
device_authz_endpoint,
sanitized_payload,
) )
op_response = session.post(device_authz_endpoint, op_response = session.post(
device_authz_endpoint,
requests_auth=client_auth, requests_auth=client_auth,
headers=self.HEADER_X_FORM, headers=self.HEADER_X_FORM,
data=encoded_payload, data=encoded_payload,
log=False, log=False,
authenticated=False) authenticated=False,
)
if _logger.isEnabledFor(logging.DEBUG): if _logger.isEnabledFor(logging.DEBUG):
sanitized_response = self._sanitize(op_response.json()) sanitized_response = self._sanitize(op_response.json())
_logger.debug( _logger.debug(
"OpenID-Connect authentication response from %s is %s", "OpenID-Connect authentication response from %s is %s",
device_authz_endpoint, sanitized_response device_authz_endpoint,
sanitized_response,
) )
self.expires_in = int(op_response.json()["expires_in"]) self.expires_in = int(op_response.json()["expires_in"])
@ -681,8 +749,9 @@ class OidcDeviceAuthorization(_OidcBase):
self.interval = int(op_response.json()["interval"]) self.interval = int(op_response.json()["interval"])
self.user_code = op_response.json()["user_code"] self.user_code = op_response.json()["user_code"]
self.verification_uri = op_response.json()["verification_uri"] self.verification_uri = op_response.json()["verification_uri"]
self.verification_uri_complete = \ self.verification_uri_complete = op_response.json()[
op_response.json()["verification_uri_complete"] "verification_uri_complete"
]
payload = {'device_code': self.device_code} payload = {'device_code': self.device_code}
if self.code_challenge_method: if self.code_challenge_method:
@ -701,8 +770,10 @@ class OidcDeviceAuthorization(_OidcBase):
'device_code': self.device_code} 'device_code': self.device_code}
:type payload: dict :type payload: dict
""" """
_logger.warning(f"To authenticate please go to: " _logger.warning(
f"{self.verification_uri_complete}") f"To authenticate please go to: "
f"{self.verification_uri_complete}"
)
if self.client_secret: if self.client_secret:
client_auth = (self.client_id, self.client_secret) client_auth = (self.client_id, self.client_secret)
@ -720,19 +791,23 @@ class OidcDeviceAuthorization(_OidcBase):
_logger.debug( _logger.debug(
"Making OpenID-Connect authentication request to %s " "Making OpenID-Connect authentication request to %s "
"with data %s", "with data %s",
access_token_endpoint, sanitized_payload access_token_endpoint,
sanitized_payload,
) )
op_response = session.post(access_token_endpoint, op_response = session.post(
access_token_endpoint,
requests_auth=client_auth, requests_auth=client_auth,
data=encoded_payload, data=encoded_payload,
headers=self.HEADER_X_FORM, headers=self.HEADER_X_FORM,
log=False, log=False,
authenticated=False) authenticated=False,
)
if _logger.isEnabledFor(logging.DEBUG): if _logger.isEnabledFor(logging.DEBUG):
sanitized_response = self._sanitize(op_response.json()) sanitized_response = self._sanitize(op_response.json())
_logger.debug( _logger.debug(
"OpenID-Connect authentication response from %s is %s", "OpenID-Connect authentication response from %s is %s",
access_token_endpoint, sanitized_response access_token_endpoint,
sanitized_response,
) )
except exceptions.http.BadRequest as exc: except exceptions.http.BadRequest as exc:
error = exc.response.json().get("error") error = exc.response.json().get("error")

View File

@ -26,11 +26,13 @@ class PasswordMethod(base.AuthMethod):
:param string user_domain_name: User's domain name for authentication. :param string user_domain_name: User's domain name for authentication.
""" """
_method_parameters = ['user_id', _method_parameters = [
'user_id',
'username', 'username',
'user_domain_id', 'user_domain_id',
'user_domain_name', 'user_domain_name',
'password'] 'password',
]
def get_auth_data(self, session, auth, headers, **kwargs): def get_auth_data(self, session, auth, headers, **kwargs):
user = {'password': self.password} user = {'password': self.password}
@ -48,8 +50,9 @@ class PasswordMethod(base.AuthMethod):
return 'password', {'user': user} return 'password', {'user': user}
def get_cache_id_elements(self): def get_cache_id_elements(self):
return dict(('password_%s' % p, getattr(self, p)) return {
for p in self._method_parameters) f'password_{p}': getattr(self, p) for p in self._method_parameters
}
class Password(base.AuthConstructor): class Password(base.AuthConstructor):

View File

@ -13,7 +13,7 @@
from keystoneauth1.identity.v3 import base from keystoneauth1.identity.v3 import base
__all__ = ('ReceiptMethod', ) __all__ = ('ReceiptMethod',)
class ReceiptMethod(base.AuthMethod): class ReceiptMethod(base.AuthMethod):

View File

@ -51,4 +51,4 @@ class Token(base.AuthConstructor):
_auth_method_class = TokenMethod _auth_method_class = TokenMethod
def __init__(self, auth_url, token, **kwargs): def __init__(self, auth_url, token, **kwargs):
super(Token, self).__init__(auth_url, token=token, **kwargs) super().__init__(auth_url, token=token, **kwargs)

View File

@ -27,13 +27,16 @@ class TokenlessAuth(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
the provided HTTPS certificate along with the scope information. the provided HTTPS certificate along with the scope information.
""" """
def __init__(self, auth_url, def __init__(
self,
auth_url,
domain_id=None, domain_id=None,
domain_name=None, domain_name=None,
project_id=None, project_id=None,
project_name=None, project_name=None,
project_domain_id=None, project_domain_id=None,
project_domain_name=None): project_domain_name=None,
):
"""A init method for TokenlessAuth. """A init method for TokenlessAuth.
:param string auth_url: Identity service endpoint for authentication. :param string auth_url: Identity service endpoint for authentication.
@ -75,23 +78,23 @@ class TokenlessAuth(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
elif self.project_name: elif self.project_name:
scope_headers['X-Project-Name'] = self.project_name scope_headers['X-Project-Name'] = self.project_name
if self.project_domain_id: if self.project_domain_id:
scope_headers['X-Project-Domain-Id'] = ( scope_headers['X-Project-Domain-Id'] = self.project_domain_id
self.project_domain_id)
elif self.project_domain_name: elif self.project_domain_name:
scope_headers['X-Project-Domain-Name'] = ( scope_headers['X-Project-Domain-Name'] = (
self.project_domain_name) self.project_domain_name
)
else: else:
LOG.warning( LOG.warning(
'Neither Project Domain ID nor Project Domain Name was ' 'Neither Project Domain ID nor Project Domain Name was '
'provided.') 'provided.'
)
return None return None
elif self.domain_id: elif self.domain_id:
scope_headers['X-Domain-Id'] = self.domain_id scope_headers['X-Domain-Id'] = self.domain_id
elif self.domain_name: elif self.domain_name:
scope_headers['X-Domain-Name'] = self.domain_name scope_headers['X-Domain-Name'] = self.domain_name
else: else:
LOG.warning( LOG.warning('Neither Project nor Domain scope was provided.')
'Neither Project nor Domain scope was provided.')
return None return None
return scope_headers return scope_headers
@ -106,8 +109,10 @@ class TokenlessAuth(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
:return: A valid endpoint URL or None if not available. :return: A valid endpoint URL or None if not available.
:rtype: string or None :rtype: string or None
""" """
if (service_type is plugin.AUTH_INTERFACE or if (
service_type.lower() == 'identity'): service_type is plugin.AUTH_INTERFACE
or service_type.lower() == 'identity'
):
return self.auth_url return self.auth_url
return None return None

View File

@ -28,11 +28,13 @@ class TOTPMethod(base.AuthMethod):
:param string user_domain_name: User's domain name for authentication. :param string user_domain_name: User's domain name for authentication.
""" """
_method_parameters = ['user_id', _method_parameters = [
'user_id',
'username', 'username',
'user_domain_id', 'user_domain_id',
'user_domain_name', 'user_domain_name',
'passcode'] 'passcode',
]
def get_auth_data(self, session, auth, headers, **kwargs): def get_auth_data(self, session, auth, headers, **kwargs):
user = {'passcode': self.passcode} user = {'passcode': self.passcode}
@ -54,8 +56,7 @@ class TOTPMethod(base.AuthMethod):
# the key in caching. # the key in caching.
params = copy.copy(self._method_parameters) params = copy.copy(self._method_parameters)
params.remove('passcode') params.remove('passcode')
return dict(('totp_%s' % p, getattr(self, p)) return {f'totp_{p}': getattr(self, p) for p in self._method_parameters}
for p in self._method_parameters)
class TOTP(base.AuthConstructor): class TOTP(base.AuthConstructor):

View File

@ -37,7 +37,8 @@ get_session_conf_options = session.get_conf_options
register_adapter_argparse_arguments = adapter.register_argparse_arguments register_adapter_argparse_arguments = adapter.register_argparse_arguments
register_service_adapter_argparse_arguments = ( register_service_adapter_argparse_arguments = (
adapter.register_service_argparse_arguments) adapter.register_service_argparse_arguments
)
register_adapter_conf_options = adapter.register_conf_options register_adapter_conf_options = adapter.register_conf_options
load_adapter_from_conf_options = adapter.load_from_conf_options load_adapter_from_conf_options = adapter.load_from_conf_options
get_adapter_conf_options = adapter.get_conf_options get_adapter_conf_options = adapter.get_conf_options
@ -50,38 +51,32 @@ __all__ = (
'get_available_plugin_loaders', 'get_available_plugin_loaders',
'get_plugin_loader', 'get_plugin_loader',
'PLUGIN_NAMESPACE', 'PLUGIN_NAMESPACE',
# loading.identity # loading.identity
'BaseIdentityLoader', 'BaseIdentityLoader',
'BaseV2Loader', 'BaseV2Loader',
'BaseV3Loader', 'BaseV3Loader',
'BaseFederationLoader', 'BaseFederationLoader',
'BaseGenericLoader', 'BaseGenericLoader',
# auth cli # auth cli
'register_auth_argparse_arguments', 'register_auth_argparse_arguments',
'load_auth_from_argparse_arguments', 'load_auth_from_argparse_arguments',
# auth conf # auth conf
'get_auth_common_conf_options', 'get_auth_common_conf_options',
'get_auth_plugin_conf_options', 'get_auth_plugin_conf_options',
'register_auth_conf_options', 'register_auth_conf_options',
'load_auth_from_conf_options', 'load_auth_from_conf_options',
# session # session
'register_session_argparse_arguments', 'register_session_argparse_arguments',
'load_session_from_argparse_arguments', 'load_session_from_argparse_arguments',
'register_session_conf_options', 'register_session_conf_options',
'load_session_from_conf_options', 'load_session_from_conf_options',
'get_session_conf_options', 'get_session_conf_options',
# adapter # adapter
'register_adapter_argparse_arguments', 'register_adapter_argparse_arguments',
'register_service_adapter_argparse_arguments', 'register_service_adapter_argparse_arguments',
'register_adapter_conf_options', 'register_adapter_conf_options',
'load_adapter_from_conf_options', 'load_adapter_from_conf_options',
'get_adapter_conf_options', 'get_adapter_conf_options',
# loading.opts # loading.opts
'Opt', 'Opt',
) )

View File

@ -32,15 +32,21 @@ class AdminToken(loading.BaseLoader):
return token_endpoint.Token return token_endpoint.Token
def get_options(self): def get_options(self):
options = super(AdminToken, self).get_options() options = super().get_options()
options.extend([ options.extend(
loading.Opt('endpoint', [
loading.Opt(
'endpoint',
deprecated=[loading.Opt('url')], deprecated=[loading.Opt('url')],
help='The endpoint that will always be used'), help='The endpoint that will always be used',
loading.Opt('token', ),
loading.Opt(
'token',
secret=True, secret=True,
help='The token that will always be used'), help='The token that will always be used',
]) ),
]
)
return options return options

View File

@ -31,18 +31,25 @@ class HTTPBasicAuth(loading.BaseLoader):
return http_basic.HTTPBasicAuth return http_basic.HTTPBasicAuth
def get_options(self): def get_options(self):
options = super(HTTPBasicAuth, self).get_options() options = super().get_options()
options.extend([ options.extend(
loading.Opt('username', [
loading.Opt(
'username',
help='Username', help='Username',
deprecated=[loading.Opt('user-name')]), deprecated=[loading.Opt('user-name')],
loading.Opt('password', ),
loading.Opt(
'password',
secret=True, secret=True,
prompt='Password: ', prompt='Password: ',
help="User's password"), help="User's password",
loading.Opt('endpoint', ),
help='The endpoint that will always be used'), loading.Opt(
]) 'endpoint', help='The endpoint that will always be used'
),
]
)
return options return options

View File

@ -33,12 +33,15 @@ class Token(loading.BaseGenericLoader):
return identity.Token return identity.Token
def get_options(self): def get_options(self):
options = super(Token, self).get_options() options = super().get_options()
options.extend([ options.extend(
loading.Opt('token', secret=True, [
help='Token to authenticate with'), loading.Opt(
]) 'token', secret=True, help='Token to authenticate with'
)
]
)
return options return options
@ -59,17 +62,23 @@ class Password(loading.BaseGenericLoader):
return identity.Password return identity.Password
def get_options(cls): def get_options(cls):
options = super(Password, cls).get_options() options = super().get_options()
options.extend([ options.extend(
[
loading.Opt('user-id', help='User id'), loading.Opt('user-id', help='User id'),
loading.Opt('username', loading.Opt(
'username',
help='Username', help='Username',
deprecated=[loading.Opt('user-name')]), deprecated=[loading.Opt('user-name')],
),
loading.Opt('user-domain-id', help="User's domain id"), loading.Opt('user-domain-id', help="User's domain id"),
loading.Opt('user-domain-name', help="User's domain name"), loading.Opt('user-domain-name', help="User's domain name"),
loading.Opt('password', loading.Opt(
'password',
secret=True, secret=True,
prompt='Password: ', prompt='Password: ',
help="User's password"), help="User's password",
]) ),
]
)
return options return options

View File

@ -15,39 +15,41 @@ from keystoneauth1 import loading
class Token(loading.BaseV2Loader): class Token(loading.BaseV2Loader):
@property @property
def plugin_class(self): def plugin_class(self):
return identity.V2Token return identity.V2Token
def get_options(self): def get_options(self):
options = super(Token, self).get_options() options = super().get_options()
options.extend([ options.extend([loading.Opt('token', secret=True, help='Token')])
loading.Opt('token', secret=True, help='Token'),
])
return options return options
class Password(loading.BaseV2Loader): class Password(loading.BaseV2Loader):
@property @property
def plugin_class(self): def plugin_class(self):
return identity.V2Password return identity.V2Password
def get_options(self): def get_options(self):
options = super(Password, self).get_options() options = super().get_options()
options.extend([ options.extend(
loading.Opt('username', [
loading.Opt(
'username',
deprecated=[loading.Opt('user-name')], deprecated=[loading.Opt('user-name')],
help='Username to login with'), help='Username to login with',
),
loading.Opt('user-id', help='User ID to login with'), loading.Opt('user-id', help='User ID to login with'),
loading.Opt('password', loading.Opt(
'password',
secret=True, secret=True,
prompt='Password: ', prompt='Password: ',
help='Password to use'), help='Password to use',
]) ),
]
)
return options return options

View File

@ -16,334 +16,409 @@ from keystoneauth1 import loading
def _add_common_identity_options(options): def _add_common_identity_options(options):
options.extend([ options.extend(
[
loading.Opt('user-id', help='User ID'), loading.Opt('user-id', help='User ID'),
loading.Opt('username', loading.Opt(
'username',
help='Username', help='Username',
deprecated=[loading.Opt('user-name')]), deprecated=[loading.Opt('user-name')],
),
loading.Opt('user-domain-id', help="User's domain id"), loading.Opt('user-domain-id', help="User's domain id"),
loading.Opt('user-domain-name', help="User's domain name"), loading.Opt('user-domain-name', help="User's domain name"),
]) ]
)
def _assert_identity_options(options): def _assert_identity_options(options):
if (options.get('username') and if options.get('username') and not (
not (options.get('user_domain_name') or options.get('user_domain_name') or options.get('user_domain_id')
options.get('user_domain_id'))): ):
m = "You have provided a username. In the V3 identity API a " \ m = (
"username is only unique within a domain so you must " \ "You have provided a username. In the V3 identity API a "
"username is only unique within a domain so you must "
"also provide either a user_domain_id or user_domain_name." "also provide either a user_domain_id or user_domain_name."
)
raise exceptions.OptionError(m) raise exceptions.OptionError(m)
class Password(loading.BaseV3Loader): class Password(loading.BaseV3Loader):
@property @property
def plugin_class(self): def plugin_class(self):
return identity.V3Password return identity.V3Password
def get_options(self): def get_options(self):
options = super(Password, self).get_options() options = super().get_options()
_add_common_identity_options(options) _add_common_identity_options(options)
options.extend([ options.extend(
loading.Opt('password', [
loading.Opt(
'password',
secret=True, secret=True,
prompt='Password: ', prompt='Password: ',
help="User's password"), help="User's password",
]) )
]
)
return options return options
def load_from_options(self, **kwargs): def load_from_options(self, **kwargs):
_assert_identity_options(kwargs) _assert_identity_options(kwargs)
return super(Password, self).load_from_options(**kwargs) return super().load_from_options(**kwargs)
class Token(loading.BaseV3Loader): class Token(loading.BaseV3Loader):
@property @property
def plugin_class(self): def plugin_class(self):
return identity.V3Token return identity.V3Token
def get_options(self): def get_options(self):
options = super(Token, self).get_options() options = super().get_options()
options.extend([ options.extend(
loading.Opt('token', [
secret=True, loading.Opt(
help='Token to authenticate with'), 'token', secret=True, help='Token to authenticate with'
]) )
]
)
return options return options
class _OpenIDConnectBase(loading.BaseFederationLoader): class _OpenIDConnectBase(loading.BaseFederationLoader):
def load_from_options(self, **kwargs): def load_from_options(self, **kwargs):
if not (kwargs.get('access_token_endpoint') or if not (
kwargs.get('discovery_endpoint')): kwargs.get('access_token_endpoint')
m = ("You have to specify either an 'access-token-endpoint' or " or kwargs.get('discovery_endpoint')
"a 'discovery-endpoint'.") ):
m = (
"You have to specify either an 'access-token-endpoint' or "
"a 'discovery-endpoint'."
)
raise exceptions.OptionError(m) raise exceptions.OptionError(m)
return super(_OpenIDConnectBase, self).load_from_options(**kwargs) return super().load_from_options(**kwargs)
def get_options(self): def get_options(self):
options = super(_OpenIDConnectBase, self).get_options() options = super().get_options()
options.extend([ options.extend(
[
loading.Opt('client-id', help='OAuth 2.0 Client ID'), loading.Opt('client-id', help='OAuth 2.0 Client ID'),
loading.Opt('client-secret', secret=True, loading.Opt(
help='OAuth 2.0 Client Secret'), 'client-secret',
loading.Opt('openid-scope', default="openid profile", secret=True,
help='OAuth 2.0 Client Secret',
),
loading.Opt(
'openid-scope',
default="openid profile",
dest="scope", dest="scope",
help='OpenID Connect scope that is requested from ' help='OpenID Connect scope that is requested from '
'authorization server. Note that the OpenID ' 'authorization server. Note that the OpenID '
'Connect specification states that "openid" ' 'Connect specification states that "openid" '
'must be always specified.'), 'must be always specified.',
loading.Opt('access-token-endpoint', ),
loading.Opt(
'access-token-endpoint',
help='OpenID Connect Provider Token Endpoint. Note ' help='OpenID Connect Provider Token Endpoint. Note '
'that if a discovery document is being passed this ' 'that if a discovery document is being passed this '
'option will override the endpoint provided by the ' 'option will override the endpoint provided by the '
'server in the discovery document.'), 'server in the discovery document.',
loading.Opt('discovery-endpoint', ),
loading.Opt(
'discovery-endpoint',
help='OpenID Connect Discovery Document URL. ' help='OpenID Connect Discovery Document URL. '
'The discovery document will be used to obtain the ' 'The discovery document will be used to obtain the '
'values of the access token endpoint and the ' 'values of the access token endpoint and the '
'authentication endpoint. This URL should look like ' 'authentication endpoint. This URL should look like '
'https://idp.example.org/.well-known/' 'https://idp.example.org/.well-known/'
'openid-configuration'), 'openid-configuration',
loading.Opt('access-token-type', ),
loading.Opt(
'access-token-type',
help='OAuth 2.0 Authorization Server Introspection ' help='OAuth 2.0 Authorization Server Introspection '
'token type, it is used to decide which type ' 'token type, it is used to decide which type '
'of token will be used when processing token ' 'of token will be used when processing token '
'introspection. Valid values are: ' 'introspection. Valid values are: '
'"access_token" or "id_token"'), '"access_token" or "id_token"',
]) ),
]
)
return options return options
class OpenIDConnectClientCredentials(_OpenIDConnectBase): class OpenIDConnectClientCredentials(_OpenIDConnectBase):
@property @property
def plugin_class(self): def plugin_class(self):
return identity.V3OidcClientCredentials return identity.V3OidcClientCredentials
def get_options(self): def get_options(self):
options = super(OpenIDConnectClientCredentials, self).get_options() options = super().get_options()
return options return options
class OpenIDConnectPassword(_OpenIDConnectBase): class OpenIDConnectPassword(_OpenIDConnectBase):
@property @property
def plugin_class(self): def plugin_class(self):
return identity.V3OidcPassword return identity.V3OidcPassword
def get_options(self): def get_options(self):
options = super(OpenIDConnectPassword, self).get_options() options = super().get_options()
options.extend([ options.extend(
[
loading.Opt('username', help='Username', required=True), loading.Opt('username', help='Username', required=True),
loading.Opt('password', secret=True, loading.Opt(
help='Password', required=True), 'password', secret=True, help='Password', required=True
loading.Opt('idp_otp_key', ),
loading.Opt(
'idp_otp_key',
help='A key to be used in the Identity Provider access' help='A key to be used in the Identity Provider access'
' token endpoint to pass the OTP value. ' ' token endpoint to pass the OTP value. '
'E.g. totp'), 'E.g. totp',
]) ),
]
)
return options return options
class OpenIDConnectAuthorizationCode(_OpenIDConnectBase): class OpenIDConnectAuthorizationCode(_OpenIDConnectBase):
@property @property
def plugin_class(self): def plugin_class(self):
return identity.V3OidcAuthorizationCode return identity.V3OidcAuthorizationCode
def get_options(self): def get_options(self):
options = super(OpenIDConnectAuthorizationCode, self).get_options() options = super().get_options()
options.extend([ options.extend(
loading.Opt('redirect-uri', help='OpenID Connect Redirect URL'), [
loading.Opt('code', secret=True, required=True, loading.Opt(
'redirect-uri', help='OpenID Connect Redirect URL'
),
loading.Opt(
'code',
secret=True,
required=True,
deprecated=[loading.Opt('authorization-code')], deprecated=[loading.Opt('authorization-code')],
help='OAuth 2.0 Authorization Code'), help='OAuth 2.0 Authorization Code',
]) ),
]
)
return options return options
class OpenIDConnectAccessToken(loading.BaseFederationLoader): class OpenIDConnectAccessToken(loading.BaseFederationLoader):
@property @property
def plugin_class(self): def plugin_class(self):
return identity.V3OidcAccessToken return identity.V3OidcAccessToken
def get_options(self): def get_options(self):
options = super(OpenIDConnectAccessToken, self).get_options() options = super().get_options()
options.extend([ options.extend(
loading.Opt('access-token', secret=True, required=True, [
help='OAuth 2.0 Access Token'), loading.Opt(
]) 'access-token',
secret=True,
required=True,
help='OAuth 2.0 Access Token',
)
]
)
return options return options
class OpenIDConnectDeviceAuthorization(_OpenIDConnectBase): class OpenIDConnectDeviceAuthorization(_OpenIDConnectBase):
@property @property
def plugin_class(self): def plugin_class(self):
return identity.V3OidcDeviceAuthorization return identity.V3OidcDeviceAuthorization
def get_options(self): def get_options(self):
options = super(OpenIDConnectDeviceAuthorization, self).get_options() options = super().get_options()
# RFC 8628 doesn't support id_token # RFC 8628 doesn't support id_token
options = [opt for opt in options if opt.name != 'access-token-type'] options = [opt for opt in options if opt.name != 'access-token-type']
options.extend([ options.extend(
loading.Opt('device-authorization-endpoint', [
loading.Opt(
'device-authorization-endpoint',
help='OAuth 2.0 Device Authorization Endpoint. Note ' help='OAuth 2.0 Device Authorization Endpoint. Note '
'that if a discovery document is being passed this ' 'that if a discovery document is being passed this '
'option will override the endpoint provided by the ' 'option will override the endpoint provided by the '
'server in the discovery document.'), 'server in the discovery document.',
loading.Opt('code-challenge-method', ),
help='PKCE Challenge Method (RFC 7636)'), loading.Opt(
]) 'code-challenge-method',
help='PKCE Challenge Method (RFC 7636)',
),
]
)
return options return options
class TOTP(loading.BaseV3Loader): class TOTP(loading.BaseV3Loader):
@property @property
def plugin_class(self): def plugin_class(self):
return identity.V3TOTP return identity.V3TOTP
def get_options(self): def get_options(self):
options = super(TOTP, self).get_options() options = super().get_options()
_add_common_identity_options(options) _add_common_identity_options(options)
options.extend([ options.extend(
[
loading.Opt( loading.Opt(
'passcode', 'passcode',
secret=True, secret=True,
prompt='TOTP passcode: ', prompt='TOTP passcode: ',
help="User's TOTP passcode"), help="User's TOTP passcode",
]) )
]
)
return options return options
def load_from_options(self, **kwargs): def load_from_options(self, **kwargs):
_assert_identity_options(kwargs) _assert_identity_options(kwargs)
return super(TOTP, self).load_from_options(**kwargs) return super().load_from_options(**kwargs)
class TokenlessAuth(loading.BaseLoader): class TokenlessAuth(loading.BaseLoader):
@property @property
def plugin_class(self): def plugin_class(self):
return identity.V3TokenlessAuth return identity.V3TokenlessAuth
def get_options(self): def get_options(self):
options = super(TokenlessAuth, self).get_options() options = super().get_options()
options.extend([ options.extend(
loading.Opt('auth-url', required=True, [
help='Authentication URL'), loading.Opt(
'auth-url', required=True, help='Authentication URL'
),
loading.Opt('domain-id', help='Domain ID to scope to'), loading.Opt('domain-id', help='Domain ID to scope to'),
loading.Opt('domain-name', help='Domain name to scope to'), loading.Opt('domain-name', help='Domain name to scope to'),
loading.Opt('project-id', help='Project ID to scope to'), loading.Opt('project-id', help='Project ID to scope to'),
loading.Opt('project-name', help='Project name to scope to'), loading.Opt('project-name', help='Project name to scope to'),
loading.Opt('project-domain-id', loading.Opt(
help='Domain ID containing project'), 'project-domain-id', help='Domain ID containing project'
loading.Opt('project-domain-name', ),
help='Domain name containing project'), loading.Opt(
]) 'project-domain-name',
help='Domain name containing project',
),
]
)
return options return options
def load_from_options(self, **kwargs): def load_from_options(self, **kwargs):
if (not kwargs.get('domain_id') and if (
not kwargs.get('domain_name') and not kwargs.get('domain_id')
not kwargs.get('project_id') and and not kwargs.get('domain_name')
not kwargs.get('project_name') or and not kwargs.get('project_id')
(kwargs.get('project_name') and and not kwargs.get('project_name')
not (kwargs.get('project_domain_name') or or (
kwargs.get('project_domain_id')))): kwargs.get('project_name')
m = ('You need to provide either a domain_name, domain_id, ' and not (
kwargs.get('project_domain_name')
or kwargs.get('project_domain_id')
)
)
):
m = (
'You need to provide either a domain_name, domain_id, '
'project_id or project_name. ' 'project_id or project_name. '
'If you have provided a project_name, in the V3 identity ' 'If you have provided a project_name, in the V3 identity '
'API a project_name is only unique within a domain so ' 'API a project_name is only unique within a domain so '
'you must also provide either a project_domain_id or ' 'you must also provide either a project_domain_id or '
'project_domain_name.') 'project_domain_name.'
)
raise exceptions.OptionError(m) raise exceptions.OptionError(m)
return super(TokenlessAuth, self).load_from_options(**kwargs) return super().load_from_options(**kwargs)
class ApplicationCredential(loading.BaseV3Loader): class ApplicationCredential(loading.BaseV3Loader):
@property @property
def plugin_class(self): def plugin_class(self):
return identity.V3ApplicationCredential return identity.V3ApplicationCredential
def get_options(self): def get_options(self):
options = super(ApplicationCredential, self).get_options() options = super().get_options()
_add_common_identity_options(options) _add_common_identity_options(options)
options.extend([ options.extend(
loading.Opt('application_credential_secret', secret=True, [
loading.Opt(
'application_credential_secret',
secret=True,
required=True, required=True,
help="Application credential auth secret"), help="Application credential auth secret",
loading.Opt('application_credential_id', ),
help='Application credential ID'), loading.Opt(
loading.Opt('application_credential_name', 'application_credential_id',
help='Application credential name'), help='Application credential ID',
]) ),
loading.Opt(
'application_credential_name',
help='Application credential name',
),
]
)
return options return options
def load_from_options(self, **kwargs): def load_from_options(self, **kwargs):
_assert_identity_options(kwargs) _assert_identity_options(kwargs)
if (not kwargs.get('application_credential_id') and if not kwargs.get('application_credential_id') and not kwargs.get(
not kwargs.get('application_credential_name')): 'application_credential_name'
m = ('You must provide either an application credential ID or an ' ):
'application credential name and user.') m = (
'You must provide either an application credential ID or an '
'application credential name and user.'
)
raise exceptions.OptionError(m) raise exceptions.OptionError(m)
if not kwargs.get('application_credential_secret'): if not kwargs.get('application_credential_secret'):
m = ('You must provide an auth secret.') m = 'You must provide an auth secret.'
raise exceptions.OptionError(m) raise exceptions.OptionError(m)
return super(ApplicationCredential, self).load_from_options(**kwargs) return super().load_from_options(**kwargs)
class MultiFactor(loading.BaseV3Loader): class MultiFactor(loading.BaseV3Loader):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self._methods = None self._methods = None
return super(MultiFactor, self).__init__(*args, **kwargs) return super().__init__(*args, **kwargs)
@property @property
def plugin_class(self): def plugin_class(self):
return identity.V3MultiFactor return identity.V3MultiFactor
def get_options(self): def get_options(self):
options = super(MultiFactor, self).get_options() options = super().get_options()
options.extend([ options.extend(
[
loading.Opt( loading.Opt(
'auth_methods', 'auth_methods',
required=True, required=True,
help="Methods to authenticate with."), help="Methods to authenticate with.",
]) )
]
)
if self._methods: if self._methods:
options_dict = {o.name: o for o in options} options_dict = {o.name: o for o in options}
@ -362,29 +437,36 @@ class MultiFactor(loading.BaseV3Loader):
self._methods = kwargs['auth_methods'] self._methods = kwargs['auth_methods']
return super(MultiFactor, self).load_from_options(**kwargs) return super().load_from_options(**kwargs)
class OAuth2ClientCredential(loading.BaseV3Loader): class OAuth2ClientCredential(loading.BaseV3Loader):
@property @property
def plugin_class(self): def plugin_class(self):
return identity.V3OAuth2ClientCredential return identity.V3OAuth2ClientCredential
def get_options(self): def get_options(self):
options = super(OAuth2ClientCredential, self).get_options() options = super().get_options()
options.extend([ options.extend(
loading.Opt('oauth2_endpoint', [
loading.Opt(
'oauth2_endpoint',
required=True, required=True,
help='Endpoint for OAuth2.0'), help='Endpoint for OAuth2.0',
loading.Opt('oauth2_client_id', ),
loading.Opt(
'oauth2_client_id',
required=True, required=True,
help='Client id for OAuth2.0'), help='Client id for OAuth2.0',
loading.Opt('oauth2_client_secret', ),
loading.Opt(
'oauth2_client_secret',
secret=True, secret=True,
required=True, required=True,
help='Client secret for OAuth2.0'), help='Client secret for OAuth2.0',
]) ),
]
)
return options return options
@ -399,26 +481,31 @@ class OAuth2ClientCredential(loading.BaseV3Loader):
m = 'You must provide an OAuth2.0 client credential auth secret.' m = 'You must provide an OAuth2.0 client credential auth secret.'
raise exceptions.OptionError(m) raise exceptions.OptionError(m)
return super(OAuth2ClientCredential, self).load_from_options(**kwargs) return super().load_from_options(**kwargs)
class OAuth2mTlsClientCredential(loading.BaseV3Loader): class OAuth2mTlsClientCredential(loading.BaseV3Loader):
@property @property
def plugin_class(self): def plugin_class(self):
return identity.V3OAuth2mTlsClientCredential return identity.V3OAuth2mTlsClientCredential
def get_options(self): def get_options(self):
options = super(OAuth2mTlsClientCredential, self).get_options() options = super().get_options()
options.extend([ options.extend(
loading.Opt('oauth2-endpoint', [
loading.Opt(
'oauth2-endpoint',
required=True, required=True,
help='Endpoint for OAuth2.0 Mutual-TLS Authorization'), help='Endpoint for OAuth2.0 Mutual-TLS Authorization',
loading.Opt('oauth2-client-id', ),
loading.Opt(
'oauth2-client-id',
required=True, required=True,
help='Client credential ID for OAuth2.0 Mutual-TLS ' help='Client credential ID for OAuth2.0 Mutual-TLS '
'Authorization') 'Authorization',
]) ),
]
)
return options return options
def load_from_options(self, **kwargs): def load_from_options(self, **kwargs):
@ -426,8 +513,9 @@ class OAuth2mTlsClientCredential(loading.BaseV3Loader):
m = 'You must provide an OAuth2.0 Mutual-TLS endpoint.' m = 'You must provide an OAuth2.0 Mutual-TLS endpoint.'
raise exceptions.OptionError(m) raise exceptions.OptionError(m)
if not kwargs.get('oauth2_client_id'): if not kwargs.get('oauth2_client_id'):
m = ('You must provide an client credential ID for ' m = (
'OAuth2.0 Mutual-TLS Authorization.') 'You must provide an client credential ID for '
'OAuth2.0 Mutual-TLS Authorization.'
)
raise exceptions.OptionError(m) raise exceptions.OptionError(m)
return super(OAuth2mTlsClientCredential, return super().load_from_options(**kwargs)
self).load_from_options(**kwargs)

View File

@ -31,11 +31,14 @@ class NoAuth(loading.BaseLoader):
return noauth.NoAuth return noauth.NoAuth
def get_options(self): def get_options(self):
options = super(NoAuth, self).get_options() options = super().get_options()
options.extend([ options.extend(
loading.Opt('endpoint', [
help='The endpoint that will always be used'), loading.Opt(
]) 'endpoint', help='The endpoint that will always be used'
)
]
)
return options return options

View File

@ -32,9 +32,11 @@ def get_oslo_config():
cfg = _NOT_FOUND cfg = _NOT_FOUND
if cfg is _NOT_FOUND: if cfg is _NOT_FOUND:
raise ImportError("oslo.config is not an automatic dependency of " raise ImportError(
"oslo.config is not an automatic dependency of "
"keystoneauth. If you wish to use oslo.config " "keystoneauth. If you wish to use oslo.config "
"you need to import it into your application's " "you need to import it into your application's "
"requirements file. ") "requirements file. "
)
return cfg return cfg

View File

@ -15,15 +15,16 @@ from keystoneauth1.loading import _utils
from keystoneauth1.loading import base from keystoneauth1.loading import base
__all__ = ('register_argparse_arguments', __all__ = (
'register_argparse_arguments',
'register_service_argparse_arguments', 'register_service_argparse_arguments',
'register_conf_options', 'register_conf_options',
'load_from_conf_options', 'load_from_conf_options',
'get_conf_options') 'get_conf_options',
)
class Adapter(base.BaseLoader): class Adapter(base.BaseLoader):
@property @property
def plugin_class(self): def plugin_class(self):
return adapter.Adapter return adapter.Adapter
@ -76,7 +77,7 @@ class Adapter(base.BaseLoader):
the new ``endpoint_override`` option name:: the new ``endpoint_override`` option name::
old_opt = oslo_cfg.DeprecatedOpt('api_endpoint', 'old_group') old_opt = oslo_cfg.DeprecatedOpt('api_endpoint', 'old_group')
deprecated_opts={'endpoint_override': [old_opt]} deprecated_opts = {'endpoint_override': [old_opt]}
:returns: A list of oslo_config options. :returns: A list of oslo_config options.
""" """
@ -86,106 +87,127 @@ class Adapter(base.BaseLoader):
deprecated_opts = {} deprecated_opts = {}
# This is goofy, but need to support hyphens *or* underscores # This is goofy, but need to support hyphens *or* underscores
deprecated_opts = {name.replace('_', '-'): opt deprecated_opts = {
for name, opt in deprecated_opts.items()} name.replace('_', '-'): opt
for name, opt in deprecated_opts.items()
}
opts = [cfg.StrOpt('service-type', opts = [
cfg.StrOpt(
'service-type',
deprecated_opts=deprecated_opts.get('service-type'), deprecated_opts=deprecated_opts.get('service-type'),
help='The default service_type for endpoint URL ' help='The default service_type for endpoint URL discovery.',
'discovery.'), ),
cfg.StrOpt('service-name', cfg.StrOpt(
'service-name',
deprecated_opts=deprecated_opts.get('service-name'), deprecated_opts=deprecated_opts.get('service-name'),
help='The default service_name for endpoint URL ' help='The default service_name for endpoint URL discovery.',
'discovery.'), ),
cfg.ListOpt('valid-interfaces', cfg.ListOpt(
deprecated_opts=deprecated_opts.get( 'valid-interfaces',
'valid-interfaces'), deprecated_opts=deprecated_opts.get('valid-interfaces'),
help='List of interfaces, in order of preference, ' help='List of interfaces, in order of preference, '
'for endpoint URL.'), 'for endpoint URL.',
cfg.StrOpt('region-name', ),
cfg.StrOpt(
'region-name',
deprecated_opts=deprecated_opts.get('region-name'), deprecated_opts=deprecated_opts.get('region-name'),
help='The default region_name for endpoint URL ' help='The default region_name for endpoint URL discovery.',
'discovery.'), ),
cfg.StrOpt('endpoint-override', cfg.StrOpt(
deprecated_opts=deprecated_opts.get( 'endpoint-override',
'endpoint-override'), deprecated_opts=deprecated_opts.get('endpoint-override'),
help='Always use this endpoint URL for requests ' help='Always use this endpoint URL for requests '
'for this client. NOTE: The unversioned ' 'for this client. NOTE: The unversioned '
'endpoint should be specified here; to ' 'endpoint should be specified here; to '
'request a particular API version, use the ' 'request a particular API version, use the '
'`version`, `min-version`, and/or ' '`version`, `min-version`, and/or '
'`max-version` options.'), '`max-version` options.',
cfg.StrOpt('version', ),
cfg.StrOpt(
'version',
deprecated_opts=deprecated_opts.get('version'), deprecated_opts=deprecated_opts.get('version'),
help='Minimum Major API version within a given ' help='Minimum Major API version within a given '
'Major API version for endpoint URL ' 'Major API version for endpoint URL '
'discovery. Mutually exclusive with ' 'discovery. Mutually exclusive with '
'min_version and max_version'), 'min_version and max_version',
cfg.StrOpt('min-version', ),
cfg.StrOpt(
'min-version',
deprecated_opts=deprecated_opts.get('min-version'), deprecated_opts=deprecated_opts.get('min-version'),
help='The minimum major version of a given API, ' help='The minimum major version of a given API, '
'intended to be used as the lower bound of a ' 'intended to be used as the lower bound of a '
'range with max_version. Mutually exclusive ' 'range with max_version. Mutually exclusive '
'with version. If min_version is given with ' 'with version. If min_version is given with '
'no max_version it is as if max version is ' 'no max_version it is as if max version is '
'"latest".'), '"latest".',
cfg.StrOpt('max-version', ),
cfg.StrOpt(
'max-version',
deprecated_opts=deprecated_opts.get('max-version'), deprecated_opts=deprecated_opts.get('max-version'),
help='The maximum major version of a given API, ' help='The maximum major version of a given API, '
'intended to be used as the upper bound of a ' 'intended to be used as the upper bound of a '
'range with min_version. Mutually exclusive ' 'range with min_version. Mutually exclusive '
'with version.'), 'with version.',
cfg.IntOpt('connect-retries', ),
deprecated_opts=deprecated_opts.get( cfg.IntOpt(
'connect-retries'), 'connect-retries',
deprecated_opts=deprecated_opts.get('connect-retries'),
help='The maximum number of retries that should be ' help='The maximum number of retries that should be '
'attempted for connection errors.'), 'attempted for connection errors.',
cfg.FloatOpt('connect-retry-delay', ),
deprecated_opts=deprecated_opts.get( cfg.FloatOpt(
'connect-retry-delay'), 'connect-retry-delay',
deprecated_opts=deprecated_opts.get('connect-retry-delay'),
help='Delay (in seconds) between two retries ' help='Delay (in seconds) between two retries '
'for connection errors. If not set, ' 'for connection errors. If not set, '
'exponential retry starting with 0.5 ' 'exponential retry starting with 0.5 '
'seconds up to a maximum of 60 seconds ' 'seconds up to a maximum of 60 seconds '
'is used.'), 'is used.',
cfg.IntOpt('status-code-retries', ),
deprecated_opts=deprecated_opts.get( cfg.IntOpt(
'status-code-retries'), 'status-code-retries',
deprecated_opts=deprecated_opts.get('status-code-retries'),
help='The maximum number of retries that should be ' help='The maximum number of retries that should be '
'attempted for retriable HTTP status codes.'), 'attempted for retriable HTTP status codes.',
cfg.FloatOpt('status-code-retry-delay', ),
deprecated_opts=deprecated_opts.get( cfg.FloatOpt(
'status-code-retry-delay'), 'status-code-retry-delay',
deprecated_opts=deprecated_opts.get('status-code-retry-delay'),
help='Delay (in seconds) between two retries ' help='Delay (in seconds) between two retries '
'for retriable status codes. If not set, ' 'for retriable status codes. If not set, '
'exponential retry starting with 0.5 ' 'exponential retry starting with 0.5 '
'seconds up to a maximum of 60 seconds ' 'seconds up to a maximum of 60 seconds '
'is used.'), 'is used.',
cfg.ListOpt('retriable-status-codes', ),
deprecated_opts=deprecated_opts.get( cfg.ListOpt(
'retriable-status-codes'), 'retriable-status-codes',
deprecated_opts=deprecated_opts.get('retriable-status-codes'),
item_type=cfg.types.Integer(), item_type=cfg.types.Integer(),
help='List of retriable HTTP status codes that ' help='List of retriable HTTP status codes that '
'should be retried. If not set default to ' 'should be retried. If not set default to '
' [503]' ' [503]',
), ),
] ]
if include_deprecated: if include_deprecated:
opts += [ opts += [
cfg.StrOpt('interface', cfg.StrOpt(
'interface',
help='The default interface for endpoint URL ' help='The default interface for endpoint URL '
'discovery.', 'discovery.',
deprecated_for_removal=True, deprecated_for_removal=True,
deprecated_reason='Using valid-interfaces is' deprecated_reason='Using valid-interfaces is'
' preferrable because it is' ' preferrable because it is'
' capable of accepting a list of' ' capable of accepting a list of'
' possible interfaces.'), ' possible interfaces.',
)
] ]
return opts return opts
def register_conf_options(self, conf, group, include_deprecated=True, def register_conf_options(
deprecated_opts=None): self, conf, group, include_deprecated=True, deprecated_opts=None
):
"""Register the oslo_config options that are needed for an Adapter. """Register the oslo_config options that are needed for an Adapter.
The options that are set are: The options that are set are:
@ -231,12 +253,14 @@ class Adapter(base.BaseLoader):
the new ``endpoint_override`` option name:: the new ``endpoint_override`` option name::
old_opt = oslo_cfg.DeprecatedOpt('api_endpoint', 'old_group') old_opt = oslo_cfg.DeprecatedOpt('api_endpoint', 'old_group')
deprecated_opts={'endpoint_override': [old_opt]} deprecated_opts = {'endpoint_override': [old_opt]}
:returns: The list of options that was registered. :returns: The list of options that was registered.
""" """
opts = self.get_conf_options(include_deprecated=include_deprecated, opts = self.get_conf_options(
deprecated_opts=deprecated_opts) include_deprecated=include_deprecated,
deprecated_opts=deprecated_opts,
)
conf.register_group(_utils.get_oslo_config().OptGroup(group)) conf.register_group(_utils.get_oslo_config().OptGroup(group))
conf.register_opts(opts, group=group) conf.register_opts(opts, group=group)
return opts return opts
@ -270,16 +294,19 @@ def process_conf_options(confgrp, kwargs):
:raise TypeError: If invalid conf option values or combinations are found. :raise TypeError: If invalid conf option values or combinations are found.
""" """
if confgrp.valid_interfaces and getattr(confgrp, 'interface', None): if confgrp.valid_interfaces and getattr(confgrp, 'interface', None):
raise TypeError("interface and valid_interfaces are mutually" raise TypeError(
" exclusive. Please use valid_interfaces.") "interface and valid_interfaces are mutually"
" exclusive. Please use valid_interfaces."
)
if confgrp.valid_interfaces: if confgrp.valid_interfaces:
for iface in confgrp.valid_interfaces: for iface in confgrp.valid_interfaces:
if iface not in ('public', 'internal', 'admin'): if iface not in ('public', 'internal', 'admin'):
# TODO(efried): s/valies/values/ - are we allowed to fix this? # TODO(efried): s/valies/values/ - are we allowed to fix this?
raise TypeError("'{iface}' is not a valid value for" raise TypeError(
f"'{iface}' is not a valid value for"
" valid_interfaces. Valid valies are" " valid_interfaces. Valid valies are"
" public, internal or admin".format( " public, internal or admin"
iface=iface)) )
kwargs.setdefault('interface', confgrp.valid_interfaces) kwargs.setdefault('interface', confgrp.valid_interfaces)
elif hasattr(confgrp, 'interface'): elif hasattr(confgrp, 'interface'):
kwargs.setdefault('interface', confgrp.interface) kwargs.setdefault('interface', confgrp.interface)
@ -290,16 +317,16 @@ def process_conf_options(confgrp, kwargs):
kwargs.setdefault('version', confgrp.version) kwargs.setdefault('version', confgrp.version)
kwargs.setdefault('min_version', confgrp.min_version) kwargs.setdefault('min_version', confgrp.min_version)
kwargs.setdefault('max_version', confgrp.max_version) kwargs.setdefault('max_version', confgrp.max_version)
if kwargs['version'] and ( if kwargs['version'] and (kwargs['max_version'] or kwargs['min_version']):
kwargs['max_version'] or kwargs['min_version']):
raise TypeError( raise TypeError(
"version is mutually exclusive with min_version and" "version is mutually exclusive with min_version and max_version"
" max_version") )
kwargs.setdefault('connect_retries', confgrp.connect_retries) kwargs.setdefault('connect_retries', confgrp.connect_retries)
kwargs.setdefault('connect_retry_delay', confgrp.connect_retry_delay) kwargs.setdefault('connect_retry_delay', confgrp.connect_retry_delay)
kwargs.setdefault('status_code_retries', confgrp.status_code_retries) kwargs.setdefault('status_code_retries', confgrp.status_code_retries)
kwargs.setdefault('status_code_retry_delay', kwargs.setdefault(
confgrp.status_code_retry_delay) 'status_code_retry_delay', confgrp.status_code_retry_delay
)
kwargs.setdefault('retriable_status_codes', confgrp.retriable_status_codes) kwargs.setdefault('retriable_status_codes', confgrp.retriable_status_codes)

View File

@ -19,12 +19,14 @@ from keystoneauth1 import exceptions
PLUGIN_NAMESPACE = 'keystoneauth1.plugin' PLUGIN_NAMESPACE = 'keystoneauth1.plugin'
__all__ = ('get_available_plugin_names', __all__ = (
'get_available_plugin_names',
'get_available_plugin_loaders', 'get_available_plugin_loaders',
'get_plugin_loader', 'get_plugin_loader',
'get_plugin_options', 'get_plugin_options',
'BaseLoader', 'BaseLoader',
'PLUGIN_NAMESPACE') 'PLUGIN_NAMESPACE',
)
def _auth_plugin_available(ext): def _auth_plugin_available(ext):
@ -41,10 +43,12 @@ def get_available_plugin_names():
:returns: A list of names. :returns: A list of names.
:rtype: frozenset :rtype: frozenset
""" """
mgr = stevedore.EnabledExtensionManager(namespace=PLUGIN_NAMESPACE, mgr = stevedore.EnabledExtensionManager(
namespace=PLUGIN_NAMESPACE,
check_func=_auth_plugin_available, check_func=_auth_plugin_available,
invoke_on_load=True, invoke_on_load=True,
propagate_map_exceptions=True) propagate_map_exceptions=True,
)
return frozenset(mgr.names()) return frozenset(mgr.names())
@ -55,10 +59,12 @@ def get_available_plugin_loaders():
loader as the value. loader as the value.
:rtype: dict :rtype: dict
""" """
mgr = stevedore.EnabledExtensionManager(namespace=PLUGIN_NAMESPACE, mgr = stevedore.EnabledExtensionManager(
namespace=PLUGIN_NAMESPACE,
check_func=_auth_plugin_available, check_func=_auth_plugin_available,
invoke_on_load=True, invoke_on_load=True,
propagate_map_exceptions=True) propagate_map_exceptions=True,
)
return dict(mgr.map(lambda ext: (ext.entry_point.name, ext.obj))) return dict(mgr.map(lambda ext: (ext.entry_point.name, ext.obj)))
@ -75,9 +81,9 @@ def get_plugin_loader(name):
if a plugin cannot be created. if a plugin cannot be created.
""" """
try: try:
mgr = stevedore.DriverManager(namespace=PLUGIN_NAMESPACE, mgr = stevedore.DriverManager(
invoke_on_load=True, namespace=PLUGIN_NAMESPACE, invoke_on_load=True, name=name
name=name) )
except RuntimeError: except RuntimeError:
raise exceptions.NoMatchingPlugin(name) raise exceptions.NoMatchingPlugin(name)
@ -99,7 +105,6 @@ def get_plugin_options(name):
class BaseLoader(metaclass=abc.ABCMeta): class BaseLoader(metaclass=abc.ABCMeta):
@property @property
def plugin_class(self): def plugin_class(self):
raise NotImplementedError() raise NotImplementedError()
@ -153,8 +158,11 @@ class BaseLoader(metaclass=abc.ABCMeta):
handle differences between the registered options and what is required handle differences between the registered options and what is required
to create the plugin. to create the plugin.
""" """
missing_required = [o for o in self.get_options() missing_required = [
if o.required and kwargs.get(o.dest) is None] o
for o in self.get_options()
if o.required and kwargs.get(o.dest) is None
]
if missing_required: if missing_required:
raise exceptions.MissingRequiredOptions(missing_required) raise exceptions.MissingRequiredOptions(missing_required)

View File

@ -16,17 +16,18 @@ import os
from keystoneauth1.loading import base from keystoneauth1.loading import base
__all__ = ('register_argparse_arguments', __all__ = ('register_argparse_arguments', 'load_from_argparse_arguments')
'load_from_argparse_arguments')
def _register_plugin_argparse_arguments(parser, plugin): def _register_plugin_argparse_arguments(parser, plugin):
for opt in plugin.get_options(): for opt in plugin.get_options():
parser.add_argument(*opt.argparse_args, parser.add_argument(
*opt.argparse_args,
default=opt.argparse_default, default=opt.argparse_default,
metavar=opt.metavar, metavar=opt.metavar,
help=opt.help, help=opt.help,
dest='os_%s' % opt.dest) dest=f'os_{opt.dest}',
)
def register_argparse_arguments(parser, argv, default=None): def register_argparse_arguments(parser, argv, default=None):
@ -48,14 +49,17 @@ def register_argparse_arguments(parser, argv, default=None):
if a plugin cannot be created. if a plugin cannot be created.
""" """
in_parser = argparse.ArgumentParser(add_help=False) in_parser = argparse.ArgumentParser(add_help=False)
env_plugin = os.environ.get('OS_AUTH_TYPE', env_plugin = os.environ.get(
os.environ.get('OS_AUTH_PLUGIN', default)) 'OS_AUTH_TYPE', os.environ.get('OS_AUTH_PLUGIN', default)
)
for p in (in_parser, parser): for p in (in_parser, parser):
p.add_argument('--os-auth-type', p.add_argument(
'--os-auth-type',
'--os-auth-plugin', '--os-auth-plugin',
metavar='<name>', metavar='<name>',
default=env_plugin, default=env_plugin,
help='Authentication type to use') help='Authentication type to use',
)
options, _args = in_parser.parse_known_args(argv) options, _args = in_parser.parse_known_args(argv)
@ -66,7 +70,7 @@ def register_argparse_arguments(parser, argv, default=None):
msg = 'Default Authentication options' msg = 'Default Authentication options'
plugin = options.os_auth_type plugin = options.os_auth_type
else: else:
msg = 'Options specific to the %s plugin.' % options.os_auth_type msg = f'Options specific to the {options.os_auth_type} plugin.'
plugin = base.get_plugin_loader(options.os_auth_type) plugin = base.get_plugin_loader(options.os_auth_type)
group = parser.add_argument_group('Authentication Options', msg) group = parser.add_argument_group('Authentication Options', msg)
@ -97,6 +101,6 @@ def load_from_argparse_arguments(namespace, **kwargs):
plugin = base.get_plugin_loader(namespace.os_auth_type) plugin = base.get_plugin_loader(namespace.os_auth_type)
def _getter(opt): def _getter(opt):
return getattr(namespace, 'os_%s' % opt.dest) return getattr(namespace, f'os_{opt.dest}')
return plugin.load_from_options_getter(_getter, **kwargs) return plugin.load_from_options_getter(_getter, **kwargs)

View File

@ -13,18 +13,22 @@
from keystoneauth1.loading import base from keystoneauth1.loading import base
from keystoneauth1.loading import opts from keystoneauth1.loading import opts
_AUTH_TYPE_OPT = opts.Opt('auth_type', _AUTH_TYPE_OPT = opts.Opt(
'auth_type',
deprecated=[opts.Opt('auth_plugin')], deprecated=[opts.Opt('auth_plugin')],
help='Authentication type to load') help='Authentication type to load',
)
_section_help = 'Config Section from which to load plugin specific options' _section_help = 'Config Section from which to load plugin specific options'
_AUTH_SECTION_OPT = opts.Opt('auth_section', help=_section_help) _AUTH_SECTION_OPT = opts.Opt('auth_section', help=_section_help)
__all__ = ('get_common_conf_options', __all__ = (
'get_common_conf_options',
'get_plugin_conf_options', 'get_plugin_conf_options',
'register_conf_options', 'register_conf_options',
'load_from_conf_options') 'load_from_conf_options',
)
def get_common_conf_options(): def get_common_conf_options():

View File

@ -14,11 +14,13 @@ from keystoneauth1 import exceptions
from keystoneauth1.loading import base from keystoneauth1.loading import base
from keystoneauth1.loading import opts from keystoneauth1.loading import opts
__all__ = ('BaseIdentityLoader', __all__ = (
'BaseIdentityLoader',
'BaseV2Loader', 'BaseV2Loader',
'BaseV3Loader', 'BaseV3Loader',
'BaseFederationLoader', 'BaseFederationLoader',
'BaseGenericLoader') 'BaseGenericLoader',
)
class BaseIdentityLoader(base.BaseLoader): class BaseIdentityLoader(base.BaseLoader):
@ -31,13 +33,11 @@ class BaseIdentityLoader(base.BaseLoader):
""" """
def get_options(self): def get_options(self):
options = super(BaseIdentityLoader, self).get_options() options = super().get_options()
options.extend([ options.extend(
opts.Opt('auth-url', [opts.Opt('auth-url', required=True, help='Authentication URL')]
required=True, )
help='Authentication URL'),
])
return options return options
@ -51,14 +51,17 @@ class BaseV2Loader(BaseIdentityLoader):
""" """
def get_options(self): def get_options(self):
options = super(BaseV2Loader, self).get_options() options = super().get_options()
options.extend([ options.extend(
[
opts.Opt('tenant-id', help='Tenant ID'), opts.Opt('tenant-id', help='Tenant ID'),
opts.Opt('tenant-name', help='Tenant Name'), opts.Opt('tenant-name', help='Tenant Name'),
opts.Opt('trust-id', opts.Opt(
help='ID of the trust to use as a trustee use'), 'trust-id', help='ID of the trust to use as a trustee use'
]) ),
]
)
return options return options
@ -72,35 +75,44 @@ class BaseV3Loader(BaseIdentityLoader):
""" """
def get_options(self): def get_options(self):
options = super(BaseV3Loader, self).get_options() options = super().get_options()
options.extend([ options.extend(
[
opts.Opt('system-scope', help='Scope for system operations'), opts.Opt('system-scope', help='Scope for system operations'),
opts.Opt('domain-id', help='Domain ID to scope to'), opts.Opt('domain-id', help='Domain ID to scope to'),
opts.Opt('domain-name', help='Domain name to scope to'), opts.Opt('domain-name', help='Domain name to scope to'),
opts.Opt('project-id', help='Project ID to scope to'), opts.Opt('project-id', help='Project ID to scope to'),
opts.Opt('project-name', help='Project name to scope to'), opts.Opt('project-name', help='Project name to scope to'),
opts.Opt('project-domain-id', opts.Opt(
help='Domain ID containing project'), 'project-domain-id', help='Domain ID containing project'
opts.Opt('project-domain-name', ),
help='Domain name containing project'), opts.Opt(
opts.Opt('trust-id', 'project-domain-name',
help='ID of the trust to use as a trustee use'), help='Domain name containing project',
]) ),
opts.Opt(
'trust-id', help='ID of the trust to use as a trustee use'
),
]
)
return options return options
def load_from_options(self, **kwargs): def load_from_options(self, **kwargs):
if (kwargs.get('project_name') and if kwargs.get('project_name') and not (
not (kwargs.get('project_domain_name') or kwargs.get('project_domain_name')
kwargs.get('project_domain_id'))): or kwargs.get('project_domain_id')
m = "You have provided a project_name. In the V3 identity API a " \ ):
"project_name is only unique within a domain so you must " \ m = (
"also provide either a project_domain_id or " \ "You have provided a project_name. In the V3 identity API a "
"project_name is only unique within a domain so you must "
"also provide either a project_domain_id or "
"project_domain_name." "project_domain_name."
)
raise exceptions.OptionError(m) raise exceptions.OptionError(m)
return super(BaseV3Loader, self).load_from_options(**kwargs) return super().load_from_options(**kwargs)
class BaseFederationLoader(BaseV3Loader): class BaseFederationLoader(BaseV3Loader):
@ -112,16 +124,22 @@ class BaseFederationLoader(BaseV3Loader):
""" """
def get_options(self): def get_options(self):
options = super(BaseFederationLoader, self).get_options() options = super().get_options()
options.extend([ options.extend(
opts.Opt('identity-provider', [
opts.Opt(
'identity-provider',
help="Identity Provider's name", help="Identity Provider's name",
required=True), required=True,
opts.Opt('protocol', ),
opts.Opt(
'protocol',
help='Protocol for federated plugin', help='Protocol for federated plugin',
required=True), required=True,
]) ),
]
)
return options return options
@ -136,32 +154,48 @@ class BaseGenericLoader(BaseIdentityLoader):
""" """
def get_options(self): def get_options(self):
options = super(BaseGenericLoader, self).get_options() options = super().get_options()
options.extend([ options.extend(
[
opts.Opt('system-scope', help='Scope for system operations'), opts.Opt('system-scope', help='Scope for system operations'),
opts.Opt('domain-id', help='Domain ID to scope to'), opts.Opt('domain-id', help='Domain ID to scope to'),
opts.Opt('domain-name', help='Domain name to scope to'), opts.Opt('domain-name', help='Domain name to scope to'),
opts.Opt('project-id', help='Project ID to scope to', opts.Opt(
deprecated=[opts.Opt('tenant-id')]), 'project-id',
opts.Opt('project-name', help='Project name to scope to', help='Project ID to scope to',
deprecated=[opts.Opt('tenant-name')]), deprecated=[opts.Opt('tenant-id')],
opts.Opt('project-domain-id', ),
help='Domain ID containing project'), opts.Opt(
opts.Opt('project-domain-name', 'project-name',
help='Domain name containing project'), help='Project name to scope to',
opts.Opt('trust-id', deprecated=[opts.Opt('tenant-name')],
help='ID of the trust to use as a trustee use'), ),
opts.Opt('default-domain-id', opts.Opt(
'project-domain-id', help='Domain ID containing project'
),
opts.Opt(
'project-domain-name',
help='Domain name containing project',
),
opts.Opt(
'trust-id', help='ID of the trust to use as a trustee use'
),
opts.Opt(
'default-domain-id',
help='Optional domain ID to use with v3 and v2 ' help='Optional domain ID to use with v3 and v2 '
'parameters. It will be used for both the user ' 'parameters. It will be used for both the user '
'and project domain in v3 and ignored in ' 'and project domain in v3 and ignored in '
'v2 authentication.'), 'v2 authentication.',
opts.Opt('default-domain-name', ),
opts.Opt(
'default-domain-name',
help='Optional domain name to use with v3 API and v2 ' help='Optional domain name to use with v3 API and v2 '
'parameters. It will be used for both the user ' 'parameters. It will be used for both the user '
'and project domain in v3 and ignored in ' 'and project domain in v3 and ignored in '
'v2 authentication.'), 'v2 authentication.',
]) ),
]
)
return options return options

View File

@ -19,7 +19,7 @@ from keystoneauth1.loading import _utils
__all__ = ('Opt',) __all__ = ('Opt',)
class Opt(object): class Opt:
"""An option required by an authentication plugin. """An option required by an authentication plugin.
Opts provide a means for authentication plugins that are going to be Opts provide a means for authentication plugins that are going to be
@ -60,7 +60,8 @@ class Opt(object):
appropriate) set the string that should be used to prompt with. appropriate) set the string that should be used to prompt with.
""" """
def __init__(self, def __init__(
self,
name, name,
type=str, type=str,
help=None, help=None,
@ -70,7 +71,8 @@ class Opt(object):
default=None, default=None,
metavar=None, metavar=None,
required=False, required=False,
prompt=None): prompt=None,
):
if not callable(type): if not callable(type):
raise TypeError('type must be callable') raise TypeError('type must be callable')
@ -95,33 +97,37 @@ class Opt(object):
def __repr__(self): def __repr__(self):
"""Return string representation of option name.""" """Return string representation of option name."""
return '<Opt: %s>' % self.name return f'<Opt: {self.name}>'
def _to_oslo_opt(self): def _to_oslo_opt(self):
cfg = _utils.get_oslo_config() cfg = _utils.get_oslo_config()
deprecated_opts = [cfg.DeprecatedOpt(o.name) for o in self.deprecated] deprecated_opts = [cfg.DeprecatedOpt(o.name) for o in self.deprecated]
return cfg.Opt(name=self.name, return cfg.Opt(
name=self.name,
type=self.type, type=self.type,
help=self.help, help=self.help,
secret=self.secret, secret=self.secret,
required=self.required, required=self.required,
dest=self.dest, dest=self.dest,
deprecated_opts=deprecated_opts, deprecated_opts=deprecated_opts,
metavar=self.metavar) metavar=self.metavar,
)
def __eq__(self, other): def __eq__(self, other):
"""Define equality operator on option parameters.""" """Define equality operator on option parameters."""
return (type(self) is type(other) and return (
self.name == other.name and type(self) is type(other)
self.type == other.type and and self.name == other.name
self.help == other.help and and self.type == other.type
self.secret == other.secret and and self.help == other.help
self.required == other.required and and self.secret == other.secret
self.dest == other.dest and and self.required == other.required
self.deprecated == other.deprecated and and self.dest == other.dest
self.default == other.default and and self.deprecated == other.deprecated
self.metavar == other.metavar) and self.default == other.default
and self.metavar == other.metavar
)
# NOTE: This function is only needed by Python 2. If we get to point where # NOTE: This function is only needed by Python 2. If we get to point where
# we don't support Python 2 anymore, this function should be removed. # we don't support Python 2 anymore, this function should be removed.
@ -135,13 +141,15 @@ class Opt(object):
@property @property
def argparse_args(self): def argparse_args(self):
return ['--os-%s' % o.name for o in self._all_opts] return [f'--os-{o.name}' for o in self._all_opts]
@property @property
def argparse_default(self): def argparse_default(self):
# select the first ENV that is not false-y or return None # select the first ENV that is not false-y or return None
for o in self._all_opts: for o in self._all_opts:
v = os.environ.get('OS_%s' % o.name.replace('-', '_').upper()) v = os.environ.get(
'OS_{}'.format(o.name.replace('-', '_').upper())
)
if v: if v:
return v return v

View File

@ -18,11 +18,13 @@ from keystoneauth1.loading import base
from keystoneauth1 import session from keystoneauth1 import session
__all__ = ('register_argparse_arguments', __all__ = (
'register_argparse_arguments',
'load_from_argparse_arguments', 'load_from_argparse_arguments',
'register_conf_options', 'register_conf_options',
'load_from_conf_options', 'load_from_conf_options',
'get_conf_options') 'get_conf_options',
)
def _positive_non_zero_float(argument_value): def _positive_non_zero_float(argument_value):
@ -31,16 +33,15 @@ def _positive_non_zero_float(argument_value):
try: try:
value = float(argument_value) value = float(argument_value)
except ValueError: except ValueError:
msg = "%s must be a float" % argument_value msg = f"{argument_value} must be a float"
raise argparse.ArgumentTypeError(msg) raise argparse.ArgumentTypeError(msg)
if value <= 0: if value <= 0:
msg = "%s must be greater than 0" % argument_value msg = f"{argument_value} must be greater than 0"
raise argparse.ArgumentTypeError(msg) raise argparse.ArgumentTypeError(msg)
return value return value
class Session(base.BaseLoader): class Session(base.BaseLoader):
@property @property
def plugin_class(self): def plugin_class(self):
return session.Session return session.Session
@ -48,13 +49,15 @@ class Session(base.BaseLoader):
def get_options(self): def get_options(self):
return [] return []
def load_from_options(self, def load_from_options(
self,
insecure=False, insecure=False,
verify=None, verify=None,
cacert=None, cacert=None,
cert=None, cert=None,
key=None, key=None,
**kwargs): **kwargs,
):
"""Create a session with individual certificate parameters. """Create a session with individual certificate parameters.
Some parameters used to create a session don't lend themselves to be Some parameters used to create a session don't lend themselves to be
@ -72,14 +75,13 @@ class Session(base.BaseLoader):
# requests lib form of having the cert and key as a tuple # requests lib form of having the cert and key as a tuple
cert = (cert, key) cert = (cert, key)
return super(Session, self).load_from_options(verify=verify, return super().load_from_options(verify=verify, cert=cert, **kwargs)
cert=cert,
**kwargs)
def register_argparse_arguments(self, parser): def register_argparse_arguments(self, parser):
session_group = parser.add_argument_group( session_group = parser.add_argument_group(
'API Connection Options', 'API Connection Options',
'Options controlling the HTTP API Connections') 'Options controlling the HTTP API Connections',
)
session_group.add_argument( session_group.add_argument(
'--insecure', '--insecure',
@ -89,7 +91,8 @@ class Session(base.BaseLoader):
'"insecure" TLS (https) requests. The ' '"insecure" TLS (https) requests. The '
'server\'s certificate will not be verified ' 'server\'s certificate will not be verified '
'against any certificate authorities. This ' 'against any certificate authorities. This '
'option should be used with caution.') 'option should be used with caution.',
)
session_group.add_argument( session_group.add_argument(
'--os-cacert', '--os-cacert',
@ -97,7 +100,8 @@ class Session(base.BaseLoader):
default=os.environ.get('OS_CACERT'), default=os.environ.get('OS_CACERT'),
help='Specify a CA bundle file to use in ' help='Specify a CA bundle file to use in '
'verifying a TLS (https) server certificate. ' 'verifying a TLS (https) server certificate. '
'Defaults to env[OS_CACERT].') 'Defaults to env[OS_CACERT].',
)
session_group.add_argument( session_group.add_argument(
'--os-cert', '--os-cert',
@ -105,7 +109,8 @@ class Session(base.BaseLoader):
default=os.environ.get('OS_CERT'), default=os.environ.get('OS_CERT'),
help='The location for the keystore (PEM formatted) ' help='The location for the keystore (PEM formatted) '
'containing the public key of this client. ' 'containing the public key of this client. '
'Defaults to env[OS_CERT].') 'Defaults to env[OS_CERT].',
)
session_group.add_argument( session_group.add_argument(
'--os-key', '--os-key',
@ -113,20 +118,23 @@ class Session(base.BaseLoader):
default=os.environ.get('OS_KEY'), default=os.environ.get('OS_KEY'),
help='The location for the keystore (PEM formatted) ' help='The location for the keystore (PEM formatted) '
'containing the private key of this client. ' 'containing the private key of this client. '
'Defaults to env[OS_KEY].') 'Defaults to env[OS_KEY].',
)
session_group.add_argument( session_group.add_argument(
'--timeout', '--timeout',
default=600, default=600,
type=_positive_non_zero_float, type=_positive_non_zero_float,
metavar='<seconds>', metavar='<seconds>',
help='Set request timeout (in seconds).') help='Set request timeout (in seconds).',
)
session_group.add_argument( session_group.add_argument(
'--collect-timing', '--collect-timing',
default=False, default=False,
action='store_true', action='store_true',
help='Collect per-API call timing information.') help='Collect per-API call timing information.',
)
def load_from_argparse_arguments(self, namespace, **kwargs): def load_from_argparse_arguments(self, namespace, **kwargs):
kwargs.setdefault('insecure', namespace.insecure) kwargs.setdefault('insecure', namespace.insecure)
@ -162,7 +170,7 @@ class Session(base.BaseLoader):
``cafile`` option name:: ``cafile`` option name::
old_opt = oslo_cfg.DeprecatedOpt('ca_file', 'old_group') old_opt = oslo_cfg.DeprecatedOpt('ca_file', 'old_group')
deprecated_opts={'cafile': [old_opt]} deprecated_opts = {'cafile': [old_opt]}
:returns: A list of oslo_config options. :returns: A list of oslo_config options.
""" """
@ -171,33 +179,46 @@ class Session(base.BaseLoader):
if deprecated_opts is None: if deprecated_opts is None:
deprecated_opts = {} deprecated_opts = {}
return [cfg.StrOpt('cafile', return [
cfg.StrOpt(
'cafile',
deprecated_opts=deprecated_opts.get('cafile'), deprecated_opts=deprecated_opts.get('cafile'),
help='PEM encoded Certificate Authority to use ' help='PEM encoded Certificate Authority to use '
'when verifying HTTPs connections.'), 'when verifying HTTPs connections.',
cfg.StrOpt('certfile', ),
cfg.StrOpt(
'certfile',
deprecated_opts=deprecated_opts.get('certfile'), deprecated_opts=deprecated_opts.get('certfile'),
help='PEM encoded client certificate cert file'), help='PEM encoded client certificate cert file',
cfg.StrOpt('keyfile', ),
cfg.StrOpt(
'keyfile',
deprecated_opts=deprecated_opts.get('keyfile'), deprecated_opts=deprecated_opts.get('keyfile'),
help='PEM encoded client certificate key file'), help='PEM encoded client certificate key file',
cfg.BoolOpt('insecure', ),
cfg.BoolOpt(
'insecure',
default=False, default=False,
deprecated_opts=deprecated_opts.get('insecure'), deprecated_opts=deprecated_opts.get('insecure'),
help='Verify HTTPS connections.'), help='Verify HTTPS connections.',
cfg.IntOpt('timeout', ),
cfg.IntOpt(
'timeout',
deprecated_opts=deprecated_opts.get('timeout'), deprecated_opts=deprecated_opts.get('timeout'),
help='Timeout value for http requests'), help='Timeout value for http requests',
cfg.BoolOpt('collect-timing', ),
deprecated_opts=deprecated_opts.get( cfg.BoolOpt(
'collect-timing'), 'collect-timing',
deprecated_opts=deprecated_opts.get('collect-timing'),
default=False, default=False,
help='Collect per-API call timing information.'), help='Collect per-API call timing information.',
cfg.BoolOpt('split-loggers', ),
deprecated_opts=deprecated_opts.get( cfg.BoolOpt(
'split-loggers'), 'split-loggers',
deprecated_opts=deprecated_opts.get('split-loggers'),
default=False, default=False,
help='Log requests to multiple loggers.') help='Log requests to multiple loggers.',
),
] ]
def register_conf_options(self, conf, group, deprecated_opts=None): def register_conf_options(self, conf, group, deprecated_opts=None):
@ -223,7 +244,7 @@ class Session(base.BaseLoader):
``cafile`` option name:: ``cafile`` option name::
old_opt = oslo_cfg.DeprecatedOpt('ca_file', 'old_group') old_opt = oslo_cfg.DeprecatedOpt('ca_file', 'old_group')
deprecated_opts={'cafile': [old_opt]} deprecated_opts = {'cafile': [old_opt]}
:returns: The list of options that was registered. :returns: The list of options that was registered.
""" """

View File

@ -20,7 +20,7 @@ AUTH_INTERFACE = object()
IDENTITY_AUTH_HEADER_NAME = 'X-Auth-Token' IDENTITY_AUTH_HEADER_NAME = 'X-Auth-Token'
class BaseAuthPlugin(object): class BaseAuthPlugin:
"""The basic structure of an authentication plugin. """The basic structure of an authentication plugin.
.. note:: .. note::
@ -110,10 +110,9 @@ class BaseAuthPlugin(object):
return {IDENTITY_AUTH_HEADER_NAME: token} return {IDENTITY_AUTH_HEADER_NAME: token}
def get_endpoint_data(self, session, def get_endpoint_data(
endpoint_override=None, self, session, endpoint_override=None, discover_versions=True, **kwargs
discover_versions=True, ):
**kwargs):
"""Return a valid endpoint data for a the service. """Return a valid endpoint data for a the service.
:param session: A session object that can be used for communication. :param session: A session object that can be used for communication.
@ -140,8 +139,10 @@ class BaseAuthPlugin(object):
return endpoint_data return endpoint_data
return endpoint_data.get_versioned_data( return endpoint_data.get_versioned_data(
session, cache=self._discovery_cache, session,
discover_versions=discover_versions) cache=self._discovery_cache,
discover_versions=discover_versions,
)
def get_api_major_version(self, session, endpoint_override=None, **kwargs): def get_api_major_version(self, session, endpoint_override=None, **kwargs):
"""Get the major API version from the endpoint. """Get the major API version from the endpoint.
@ -158,16 +159,22 @@ class BaseAuthPlugin(object):
:rtype: `keystoneauth1.discover.EndpointData` or None :rtype: `keystoneauth1.discover.EndpointData` or None
""" """
endpoint_data = self.get_endpoint_data( endpoint_data = self.get_endpoint_data(
session, endpoint_override=endpoint_override, session,
discover_versions=False, **kwargs) endpoint_override=endpoint_override,
discover_versions=False,
**kwargs,
)
if endpoint_data is None: if endpoint_data is None:
return return
if endpoint_data.api_version is None: if endpoint_data.api_version is None:
# No version detected from the URL, trying full discovery. # No version detected from the URL, trying full discovery.
endpoint_data = self.get_endpoint_data( endpoint_data = self.get_endpoint_data(
session, endpoint_override=endpoint_override, session,
discover_versions=True, **kwargs) endpoint_override=endpoint_override,
discover_versions=True,
**kwargs,
)
if endpoint_data and endpoint_data.api_version: if endpoint_data and endpoint_data.api_version:
return endpoint_data.api_version return endpoint_data.api_version
@ -195,7 +202,8 @@ class BaseAuthPlugin(object):
:rtype: string :rtype: string
""" """
endpoint_data = self.get_endpoint_data( endpoint_data = self.get_endpoint_data(
session, discover_versions=False, **kwargs) session, discover_versions=False, **kwargs
)
if not endpoint_data: if not endpoint_data:
return None return None
return endpoint_data.url return endpoint_data.url
@ -340,7 +348,7 @@ class FixedEndpointPlugin(BaseAuthPlugin):
"""A base class for plugins that have one fixed endpoint.""" """A base class for plugins that have one fixed endpoint."""
def __init__(self, endpoint=None): def __init__(self, endpoint=None):
super(FixedEndpointPlugin, self).__init__() super().__init__()
self.endpoint = endpoint self.endpoint = endpoint
def get_endpoint(self, session, **kwargs): def get_endpoint(self, session, **kwargs):
@ -352,10 +360,9 @@ class FixedEndpointPlugin(BaseAuthPlugin):
""" """
return kwargs.get('endpoint_override') or self.endpoint return kwargs.get('endpoint_override') or self.endpoint
def get_endpoint_data(self, session, def get_endpoint_data(
endpoint_override=None, self, session, endpoint_override=None, discover_versions=True, **kwargs
discover_versions=True, ):
**kwargs):
"""Return a valid endpoint data for a the service. """Return a valid endpoint data for a the service.
:param session: A session object that can be used for communication. :param session: A session object that can be used for communication.
@ -374,8 +381,9 @@ class FixedEndpointPlugin(BaseAuthPlugin):
:return: Valid EndpointData or None if not available. :return: Valid EndpointData or None if not available.
:rtype: `keystoneauth1.discover.EndpointData` or None :rtype: `keystoneauth1.discover.EndpointData` or None
""" """
return super(FixedEndpointPlugin, self).get_endpoint_data( return super().get_endpoint_data(
session, session,
endpoint_override=endpoint_override or self.endpoint, endpoint_override=endpoint_override or self.endpoint,
discover_versions=discover_versions, discover_versions=discover_versions,
**kwargs) **kwargs,
)

View File

@ -18,9 +18,8 @@ __all__ = ('ServiceTokenAuthWrapper',)
class ServiceTokenAuthWrapper(plugin.BaseAuthPlugin): class ServiceTokenAuthWrapper(plugin.BaseAuthPlugin):
def __init__(self, user_auth, service_auth): def __init__(self, user_auth, service_auth):
super(ServiceTokenAuthWrapper, self).__init__() super().__init__()
self.user_auth = user_auth self.user_auth = user_auth
self.service_auth = service_auth self.service_auth = service_auth

View File

@ -40,14 +40,12 @@ try:
except ImportError: except ImportError:
osprofiler_web = None osprofiler_web = None
DEFAULT_USER_AGENT = 'keystoneauth1/%s %s %s/%s' % ( DEFAULT_USER_AGENT = f'keystoneauth1/{keystoneauth1.__version__} {requests.utils.default_user_agent()} {platform.python_implementation()}/{platform.python_version()}'
keystoneauth1.__version__, requests.utils.default_user_agent(),
platform.python_implementation(), platform.python_version())
# NOTE(jamielennox): Clients will likely want to print more than json. Please # NOTE(jamielennox): Clients will likely want to print more than json. Please
# propose a patch if you have a content type you think is reasonable to print # propose a patch if you have a content type you think is reasonable to print
# here and we'll add it to the list as required. # here and we'll add it to the list as required.
_LOG_CONTENT_TYPES = set(['application/json', 'text/plain']) _LOG_CONTENT_TYPES = {'application/json', 'text/plain'}
_MAX_RETRY_INTERVAL = 60.0 _MAX_RETRY_INTERVAL = 60.0
_EXPONENTIAL_DELAY_START = 0.5 _EXPONENTIAL_DELAY_START = 0.5
@ -101,7 +99,7 @@ def _sanitize_headers(headers):
return str_dict return str_dict
class NoOpSemaphore(object): class NoOpSemaphore:
"""Empty context manager for use as a default semaphore.""" """Empty context manager for use as a default semaphore."""
def __enter__(self): def __enter__(self):
@ -114,7 +112,6 @@ class NoOpSemaphore(object):
class _JSONEncoder(json.JSONEncoder): class _JSONEncoder(json.JSONEncoder):
def default(self, o): def default(self, o):
if isinstance(o, datetime.datetime): if isinstance(o, datetime.datetime):
return o.isoformat() return o.isoformat()
@ -123,10 +120,10 @@ class _JSONEncoder(json.JSONEncoder):
if netaddr and isinstance(o, netaddr.IPAddress): if netaddr and isinstance(o, netaddr.IPAddress):
return str(o) return str(o)
return super(_JSONEncoder, self).default(o) return super().default(o)
class _StringFormatter(object): class _StringFormatter:
"""A String formatter that fetches values on demand.""" """A String formatter that fetches values on demand."""
def __init__(self, session, auth): def __init__(self, session, auth):
@ -142,8 +139,10 @@ class _StringFormatter(object):
raise AttributeError(item) raise AttributeError(item)
if not value: if not value:
raise ValueError("This type of authentication does not provide a " raise ValueError(
"%s that can be substituted" % item) "This type of authentication does not provide a "
f"{item} that can be substituted"
)
return value return value
@ -159,8 +158,11 @@ def _determine_calling_package():
# because sys.modules can change during iteration, which results # because sys.modules can change during iteration, which results
# in a RuntimeError # in a RuntimeError
# https://docs.python.org/3/library/sys.html#sys.modules # https://docs.python.org/3/library/sys.html#sys.modules
mod_lookup = dict((m.__file__, n) for n, m in sys.modules.copy().items() mod_lookup = {
if hasattr(m, '__file__')) m.__file__: n
for n, m in sys.modules.copy().items()
if hasattr(m, '__file__')
}
# NOTE(shaleh): these are not useful because they hide the real # NOTE(shaleh): these are not useful because they hide the real
# user of the code. debtcollector did not import keystoneauth but # user of the code. debtcollector did not import keystoneauth but
@ -205,7 +207,7 @@ def _determine_user_agent():
# NOTE(shaleh): mod_wsgi is not any more useful than just # NOTE(shaleh): mod_wsgi is not any more useful than just
# reporting "keystoneauth". Ignore it and perform the package name # reporting "keystoneauth". Ignore it and perform the package name
# heuristic. # heuristic.
ignored = ('mod_wsgi', ) ignored = ('mod_wsgi',)
try: try:
name = sys.argv[0] name = sys.argv[0]
@ -222,7 +224,7 @@ def _determine_user_agent():
return name return name
class RequestTiming(object): class RequestTiming:
"""Contains timing information for an HTTP interaction.""" """Contains timing information for an HTTP interaction."""
#: HTTP method used for the call (GET, POST, etc) #: HTTP method used for the call (GET, POST, etc)
@ -240,7 +242,7 @@ class RequestTiming(object):
self.elapsed = elapsed self.elapsed = elapsed
class _Retries(object): class _Retries:
__slots__ = ('_fixed_delay', '_current') __slots__ = ('_fixed_delay', '_current')
def __init__(self, fixed_delay=None): def __init__(self, fixed_delay=None):
@ -263,7 +265,7 @@ class _Retries(object):
next = __next__ next = __next__
class Session(object): class Session:
"""Maintains client communication state and common functionality. """Maintains client communication state and common functionality.
As much as possible the parameters to this class reflect and are passed As much as possible the parameters to this class reflect and are passed
@ -341,14 +343,26 @@ class Session(object):
_DEFAULT_REDIRECT_LIMIT = 30 _DEFAULT_REDIRECT_LIMIT = 30
def __init__(self, auth=None, session=None, original_ip=None, verify=True, def __init__(
cert=None, timeout=None, user_agent=None, self,
redirect=_DEFAULT_REDIRECT_LIMIT, additional_headers=None, auth=None,
app_name=None, app_version=None, additional_user_agent=None, session=None,
discovery_cache=None, split_loggers=None, original_ip=None,
collect_timing=False, rate_semaphore=None, verify=True,
connect_retries=0): cert=None,
timeout=None,
user_agent=None,
redirect=_DEFAULT_REDIRECT_LIMIT,
additional_headers=None,
app_name=None,
app_version=None,
additional_user_agent=None,
discovery_cache=None,
split_loggers=None,
collect_timing=False,
rate_semaphore=None,
connect_retries=0,
):
self.auth = auth self.auth = auth
self.session = _construct_session(session) self.session = _construct_session(session)
# NOTE(mwhahaha): keep a reference to the session object so we can # NOTE(mwhahaha): keep a reference to the session object so we can
@ -383,7 +397,7 @@ class Session(object):
self.timeout = float(timeout) self.timeout = float(timeout)
if user_agent is not None: if user_agent is not None:
self.user_agent = "%s %s" % (user_agent, DEFAULT_USER_AGENT) self.user_agent = f"{user_agent} {DEFAULT_USER_AGENT}"
self._json = _JSONEncoder() self._json = _JSONEncoder()
@ -431,13 +445,17 @@ class Session(object):
@staticmethod @staticmethod
def _process_header(header): def _process_header(header):
"""Redact the secure headers to be logged.""" """Redact the secure headers to be logged."""
secure_headers = ('authorization', 'x-auth-token', secure_headers = (
'x-subject-token', 'x-service-token') 'authorization',
'x-auth-token',
'x-subject-token',
'x-service-token',
)
if header[0].lower() in secure_headers: if header[0].lower() in secure_headers:
token_hasher = hashlib.sha256() token_hasher = hashlib.sha256()
token_hasher.update(header[1].encode('utf-8')) token_hasher.update(header[1].encode('utf-8'))
token_hash = token_hasher.hexdigest() token_hash = token_hasher.hexdigest()
return (header[0], '{SHA256}%s' % token_hash) return (header[0], f'{{SHA256}}{token_hash}')
return header return header
def _get_split_loggers(self, split_loggers): def _get_split_loggers(self, split_loggers):
@ -458,9 +476,17 @@ class Session(object):
split_loggers = False split_loggers = False
return split_loggers return split_loggers
def _http_log_request(self, url, method=None, data=None, def _http_log_request(
json=None, headers=None, query_params=None, self,
logger=None, split_loggers=None): url,
method=None,
data=None,
json=None,
headers=None,
query_params=None,
logger=None,
split_loggers=None,
):
string_parts = [] string_parts = []
if self._get_split_loggers(split_loggers): if self._get_split_loggers(split_loggers):
@ -484,7 +510,7 @@ class Session(object):
if self.verify is False: if self.verify is False:
string_parts.append('--insecure') string_parts.append('--insecure')
elif isinstance(self.verify, str): elif isinstance(self.verify, str):
string_parts.append('--cacert "%s"' % self.verify) string_parts.append(f'--cacert "{self.verify}"')
if method: if method:
string_parts.extend(['-X', method]) string_parts.extend(['-X', method])
@ -495,15 +521,16 @@ class Session(object):
url = url + '?' + urllib.parse.urlencode(query_params) url = url + '?' + urllib.parse.urlencode(query_params)
# URLs with query strings need to be wrapped in quotes in order # URLs with query strings need to be wrapped in quotes in order
# for the CURL command to run properly. # for the CURL command to run properly.
string_parts.append('"%s"' % url) string_parts.append(f'"{url}"')
else: else:
string_parts.append(url) string_parts.append(url)
if headers: if headers:
# Sort headers so that testing can work consistently. # Sort headers so that testing can work consistently.
for header in sorted(headers.items()): for header in sorted(headers.items()):
string_parts.append('-H "%s: %s"' string_parts.append(
% self._process_header(header)) '-H "{}: {}"'.format(*self._process_header(header))
)
if json: if json:
data = self._json.encode(json) data = self._json.encode(json)
if data: if data:
@ -512,13 +539,20 @@ class Session(object):
data = data.decode("ascii") data = data.decode("ascii")
except UnicodeDecodeError: except UnicodeDecodeError:
data = "<binary_data>" data = "<binary_data>"
string_parts.append("-d '%s'" % data) string_parts.append(f"-d '{data}'")
logger.debug(' '.join(string_parts)) logger.debug(' '.join(string_parts))
def _http_log_response(self, response=None, json=None, def _http_log_response(
status_code=None, headers=None, text=None, self,
logger=None, split_loggers=True): response=None,
json=None,
status_code=None,
headers=None,
text=None,
logger=None,
split_loggers=True,
):
string_parts = [] string_parts = []
body_parts = [] body_parts = []
if self._get_split_loggers(split_loggers): if self._get_split_loggers(split_loggers):
@ -540,11 +574,13 @@ class Session(object):
headers = response.headers headers = response.headers
if status_code: if status_code:
string_parts.append('[%s]' % status_code) string_parts.append(f'[{status_code}]')
if headers: if headers:
# Sort headers so that testing can work consistently. # Sort headers so that testing can work consistently.
for header in sorted(headers.items()): for header in sorted(headers.items()):
string_parts.append('%s: %s' % self._process_header(header)) string_parts.append(
'{}: {}'.format(*self._process_header(header))
)
logger.debug(' '.join(string_parts)) logger.debug(' '.join(string_parts))
if not body_logger.isEnabledFor(logging.DEBUG): if not body_logger.isEnabledFor(logging.DEBUG):
@ -565,12 +601,15 @@ class Session(object):
# [1] https://www.w3.org/Protocols/rfc1341/4_Content-Type.html # [1] https://www.w3.org/Protocols/rfc1341/4_Content-Type.html
for log_type in _LOG_CONTENT_TYPES: for log_type in _LOG_CONTENT_TYPES:
if content_type is not None and content_type.startswith( if content_type is not None and content_type.startswith(
log_type): log_type
):
text = self._remove_service_catalog(response.text) text = self._remove_service_catalog(response.text)
break break
else: else:
text = ('Omitted, Content-Type is set to %s. Only ' text = (
'%s responses have their bodies logged.') 'Omitted, Content-Type is set to %s. Only '
'%s responses have their bodies logged.'
)
text = text % (content_type, ', '.join(_LOG_CONTENT_TYPES)) text = text % (content_type, ', '.join(_LOG_CONTENT_TYPES))
if json: if json:
text = self._json.encode(json) text = self._json.encode(json)
@ -581,7 +620,8 @@ class Session(object):
@staticmethod @staticmethod
def _set_microversion_headers( def _set_microversion_headers(
headers, microversion, service_type, endpoint_filter): headers, microversion, service_type, endpoint_filter
):
# We're converting it to normalized version number for two reasons. # We're converting it to normalized version number for two reasons.
# First, to validate it's a real version number. Second, so that in # First, to validate it's a real version number. Second, so that in
# the future we can pre-validate that it is within the range of # the future we can pre-validate that it is within the range of
@ -592,27 +632,32 @@ class Session(object):
# with the microversion range we found in discovery. # with the microversion range we found in discovery.
microversion = discover.normalize_version_number(microversion) microversion = discover.normalize_version_number(microversion)
# Can't specify a M.latest microversion # Can't specify a M.latest microversion
if (microversion[0] != discover.LATEST and if (
discover.LATEST in microversion[1:]): microversion[0] != discover.LATEST
and discover.LATEST in microversion[1:]
):
raise TypeError( raise TypeError(
"Specifying a '{major}.latest' microversion is not allowed.") "Specifying a '{major}.latest' microversion is not allowed."
)
microversion = discover.version_to_string(microversion) microversion = discover.version_to_string(microversion)
if not service_type: if not service_type:
if endpoint_filter and 'service_type' in endpoint_filter: if endpoint_filter and 'service_type' in endpoint_filter:
service_type = endpoint_filter['service_type'] service_type = endpoint_filter['service_type']
else: else:
raise TypeError( raise TypeError(
"microversion {microversion} was requested but no" f"microversion {microversion} was requested but no"
" service_type information is available. Either provide a" " service_type information is available. Either provide a"
" service_type in endpoint_filter or pass" " service_type in endpoint_filter or pass"
" microversion_service_type as an argument.".format( " microversion_service_type as an argument."
microversion=microversion)) )
# TODO(mordred) cinder uses volume in its microversion header. This # TODO(mordred) cinder uses volume in its microversion header. This
# logic should be handled in the future by os-service-types but for # logic should be handled in the future by os-service-types but for
# now hard-code for cinder. # now hard-code for cinder.
if (service_type.startswith('volume') or if (
service_type == 'block-storage'): service_type.startswith('volume')
or service_type == 'block-storage'
):
service_type = 'volume' service_type = 'volume'
elif service_type.startswith('share'): elif service_type.startswith('share'):
# NOTE(gouthamr) manila doesn't honor the "OpenStack-API-Version" # NOTE(gouthamr) manila doesn't honor the "OpenStack-API-Version"
@ -622,25 +667,44 @@ class Session(object):
# service catalog # service catalog
service_type = 'shared-file-system' service_type = 'shared-file-system'
headers.setdefault('OpenStack-API-Version', headers.setdefault(
'{service_type} {microversion}'.format( 'OpenStack-API-Version', f'{service_type} {microversion}'
service_type=service_type, )
microversion=microversion))
header_names = _mv_legacy_headers_for_service(service_type) header_names = _mv_legacy_headers_for_service(service_type)
for h in header_names: for h in header_names:
headers.setdefault(h, microversion) headers.setdefault(h, microversion)
def request(self, url, method, json=None, original_ip=None, def request(
user_agent=None, redirect=None, authenticated=None, self,
endpoint_filter=None, auth=None, requests_auth=None, url,
raise_exc=True, allow_reauth=True, log=True, method,
endpoint_override=None, connect_retries=None, logger=None, json=None,
allow=None, client_name=None, client_version=None, original_ip=None,
microversion=None, microversion_service_type=None, user_agent=None,
status_code_retries=0, retriable_status_codes=None, redirect=None,
rate_semaphore=None, global_request_id=None, authenticated=None,
connect_retry_delay=None, status_code_retry_delay=None, endpoint_filter=None,
**kwargs): auth=None,
requests_auth=None,
raise_exc=True,
allow_reauth=True,
log=True,
endpoint_override=None,
connect_retries=None,
logger=None,
allow=None,
client_name=None,
client_version=None,
microversion=None,
microversion_service_type=None,
status_code_retries=0,
retriable_status_codes=None,
rate_semaphore=None,
global_request_id=None,
connect_retry_delay=None,
status_code_retry_delay=None,
**kwargs,
):
"""Send an HTTP request with the specified characteristics. """Send an HTTP request with the specified characteristics.
Wrapper around `requests.Session.request` to handle tasks such as Wrapper around `requests.Session.request` to handle tasks such as
@ -766,21 +830,26 @@ class Session(object):
# case insensitive. # case insensitive.
if kwargs.get('headers'): if kwargs.get('headers'):
kwargs['headers'] = requests.structures.CaseInsensitiveDict( kwargs['headers'] = requests.structures.CaseInsensitiveDict(
kwargs['headers']) kwargs['headers']
)
else: else:
kwargs['headers'] = requests.structures.CaseInsensitiveDict() kwargs['headers'] = requests.structures.CaseInsensitiveDict()
if connect_retries is None: if connect_retries is None:
connect_retries = self._connect_retries connect_retries = self._connect_retries
# HTTP 503 - Service Unavailable # HTTP 503 - Service Unavailable
retriable_status_codes = retriable_status_codes or \ retriable_status_codes = (
_RETRIABLE_STATUS_CODES retriable_status_codes or _RETRIABLE_STATUS_CODES
)
rate_semaphore = rate_semaphore or self._rate_semaphore rate_semaphore = rate_semaphore or self._rate_semaphore
headers = kwargs.setdefault('headers', dict()) headers = kwargs.setdefault('headers', {})
if microversion: if microversion:
self._set_microversion_headers( self._set_microversion_headers(
headers, microversion, microversion_service_type, headers,
endpoint_filter) microversion,
microversion_service_type,
endpoint_filter,
)
if authenticated is None: if authenticated is None:
authenticated = bool(auth or self.auth) authenticated = bool(auth or self.auth)
@ -807,13 +876,14 @@ class Session(object):
if endpoint_override: if endpoint_override:
base_url = endpoint_override % _StringFormatter(self, auth) base_url = endpoint_override % _StringFormatter(self, auth)
elif endpoint_filter: elif endpoint_filter:
base_url = self.get_endpoint(auth, allow=allow, base_url = self.get_endpoint(
**endpoint_filter) auth, allow=allow, **endpoint_filter
)
if not base_url: if not base_url:
raise exceptions.EndpointNotFound() raise exceptions.EndpointNotFound()
url = '%s/%s' % (base_url.rstrip('/'), url.lstrip('/')) url = '{}/{}'.format(base_url.rstrip('/'), url.lstrip('/'))
if self.cert: if self.cert:
kwargs.setdefault('cert', self.cert) kwargs.setdefault('cert', self.cert)
@ -835,17 +905,17 @@ class Session(object):
agent = [] agent = []
if self.app_name and self.app_version: if self.app_name and self.app_version:
agent.append('%s/%s' % (self.app_name, self.app_version)) agent.append(f'{self.app_name}/{self.app_version}')
elif self.app_name: elif self.app_name:
agent.append(self.app_name) agent.append(self.app_name)
if client_name and client_version: if client_name and client_version:
agent.append('%s/%s' % (client_name, client_version)) agent.append(f'{client_name}/{client_version}')
elif client_name: elif client_name:
agent.append(client_name) agent.append(client_name)
for additional in self.additional_user_agent: for additional in self.additional_user_agent:
agent.append('%s/%s' % additional) agent.append('{}/{}'.format(*additional))
if not agent: if not agent:
# NOTE(jamielennox): determine_user_agent will return an empty # NOTE(jamielennox): determine_user_agent will return an empty
@ -861,8 +931,9 @@ class Session(object):
user_agent = headers.setdefault('User-Agent', ' '.join(agent)) user_agent = headers.setdefault('User-Agent', ' '.join(agent))
if self.original_ip: if self.original_ip:
headers.setdefault('Forwarded', headers.setdefault(
'for=%s;by=%s' % (self.original_ip, user_agent)) 'Forwarded', f'for={self.original_ip};by={user_agent}'
)
if json is not None: if json is not None:
headers.setdefault('Content-Type', 'application/json') headers.setdefault('Content-Type', 'application/json')
@ -890,14 +961,18 @@ class Session(object):
# be logged properly, but those sent in the `params` parameter # be logged properly, but those sent in the `params` parameter
# (which the requests library handles) need to be explicitly # (which the requests library handles) need to be explicitly
# picked out so they can be included in the URL that gets loggged. # picked out so they can be included in the URL that gets loggged.
query_params = kwargs.get('params', dict()) query_params = kwargs.get('params', {})
if log: if log:
self._http_log_request(url, method=method, self._http_log_request(
url,
method=method,
data=kwargs.get('data'), data=kwargs.get('data'),
headers=headers, headers=headers,
query_params=query_params, query_params=query_params,
logger=logger, split_loggers=split_loggers) logger=logger,
split_loggers=split_loggers,
)
# Force disable requests redirect handling. We will manage this below. # Force disable requests redirect handling. We will manage this below.
kwargs['allow_redirects'] = False kwargs['allow_redirects'] = False
@ -908,12 +983,21 @@ class Session(object):
connect_retry_delays = _Retries(connect_retry_delay) connect_retry_delays = _Retries(connect_retry_delay)
status_code_retry_delays = _Retries(status_code_retry_delay) status_code_retry_delays = _Retries(status_code_retry_delay)
send = functools.partial(self._send_request, send = functools.partial(
url, method, redirect, log, logger, self._send_request,
split_loggers, connect_retries, url,
status_code_retries, retriable_status_codes, method,
rate_semaphore, connect_retry_delays, redirect,
status_code_retry_delays) log,
logger,
split_loggers,
connect_retries,
status_code_retries,
retriable_status_codes,
rate_semaphore,
connect_retry_delays,
status_code_retry_delays,
)
try: try:
connection_params = self.get_auth_connection_params(auth=auth) connection_params = self.get_auth_connection_params(auth=auth)
@ -942,8 +1026,9 @@ class Session(object):
# Nova uses 'x-compute-request-id' and other services like # Nova uses 'x-compute-request-id' and other services like
# Glance, Cinder etc are using 'x-openstack-request-id' to store # Glance, Cinder etc are using 'x-openstack-request-id' to store
# request-id in the header # request-id in the header
request_id = (resp.headers.get('x-openstack-request-id') or request_id = resp.headers.get(
resp.headers.get('x-compute-request-id')) 'x-openstack-request-id'
) or resp.headers.get('x-compute-request-id')
if request_id: if request_id:
if self._get_split_loggers(split_loggers): if self._get_split_loggers(split_loggers):
id_logger = utils.get_logger(__name__ + '.request-id') id_logger = utils.get_logger(__name__ + '.request-id')
@ -953,21 +1038,25 @@ class Session(object):
id_logger.debug( id_logger.debug(
'%(method)s call to %(service_name)s for ' '%(method)s call to %(service_name)s for '
'%(url)s used request id ' '%(url)s used request id '
'%(response_request_id)s', { '%(response_request_id)s',
{
'method': resp.request.method, 'method': resp.request.method,
'service_name': service_name, 'service_name': service_name,
'url': resp.url, 'url': resp.url,
'response_request_id': request_id 'response_request_id': request_id,
}) },
)
else: else:
id_logger.debug( id_logger.debug(
'%(method)s call to ' '%(method)s call to '
'%(url)s used request id ' '%(url)s used request id '
'%(response_request_id)s', { '%(response_request_id)s',
{
'method': resp.request.method, 'method': resp.request.method,
'url': resp.url, 'url': resp.url,
'response_request_id': request_id 'response_request_id': request_id,
}) },
)
# handle getting a 401 Unauthorized response by invalidating the plugin # handle getting a 401 Unauthorized response by invalidating the plugin
# and then retrying the request. This is only tried once. # and then retrying the request. This is only tried once.
@ -980,30 +1069,46 @@ class Session(object):
resp = send(**kwargs) resp = send(**kwargs)
if raise_exc and resp.status_code >= 400: if raise_exc and resp.status_code >= 400:
logger.debug('Request returned failure status: %s', logger.debug(
resp.status_code) 'Request returned failure status: %s', resp.status_code
)
raise exceptions.from_response(resp, method, url) raise exceptions.from_response(resp, method, url)
if self._collect_timing: if self._collect_timing:
for h in resp.history: for h in resp.history:
self._api_times.append(RequestTiming( self._api_times.append(
RequestTiming(
method=h.request.method, method=h.request.method,
url=h.request.url, url=h.request.url,
elapsed=h.elapsed, elapsed=h.elapsed,
)) )
self._api_times.append(RequestTiming( )
self._api_times.append(
RequestTiming(
method=resp.request.method, method=resp.request.method,
url=resp.request.url, url=resp.request.url,
elapsed=resp.elapsed, elapsed=resp.elapsed,
)) )
)
return resp return resp
def _send_request(self, url, method, redirect, log, logger, split_loggers, def _send_request(
connect_retries, status_code_retries, self,
retriable_status_codes, rate_semaphore, url,
connect_retry_delays, status_code_retry_delays, method,
**kwargs): redirect,
log,
logger,
split_loggers,
connect_retries,
status_code_retries,
retriable_status_codes,
rate_semaphore,
connect_retry_delays,
status_code_retry_delays,
**kwargs,
):
# NOTE(jamielennox): We handle redirection manually because the # NOTE(jamielennox): We handle redirection manually because the
# requests lib follows some browser patterns where it will redirect # requests lib follows some browser patterns where it will redirect
# POSTs as GETs for certain statuses which is not want we want for an # POSTs as GETs for certain statuses which is not want we want for an
@ -1020,11 +1125,10 @@ class Session(object):
with rate_semaphore: with rate_semaphore:
resp = self.session.request(method, url, **kwargs) resp = self.session.request(method, url, **kwargs)
except requests.exceptions.SSLError as e: except requests.exceptions.SSLError as e:
msg = 'SSL exception connecting to %(url)s: %(error)s' % { msg = f'SSL exception connecting to {url}: {e}'
'url': url, 'error': e}
raise exceptions.SSLError(msg) raise exceptions.SSLError(msg)
except requests.exceptions.Timeout: except requests.exceptions.Timeout:
msg = 'Request to %s timed out' % url msg = f'Request to {url} timed out'
raise exceptions.ConnectTimeout(msg) raise exceptions.ConnectTimeout(msg)
except requests.exceptions.ConnectionError as e: except requests.exceptions.ConnectionError as e:
# NOTE(sdague): urllib3/requests connection error is a # NOTE(sdague): urllib3/requests connection error is a
@ -1033,11 +1137,10 @@ class Session(object):
# level message is often really important in figuring # level message is often really important in figuring
# out the difference between network misconfigurations # out the difference between network misconfigurations
# and firewall blocking. # and firewall blocking.
msg = 'Unable to establish connection to %s: %s' % (url, e) msg = f'Unable to establish connection to {url}: {e}'
raise exceptions.ConnectFailure(msg) raise exceptions.ConnectFailure(msg)
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
msg = 'Unexpected exception for %(url)s: %(error)s' % { msg = f'Unexpected exception for {url}: {e}'
'url': url, 'error': e}
raise exceptions.UnknownConnectionError(msg, e) raise exceptions.UnknownConnectionError(msg, e)
except exceptions.RetriableConnectionFailure as e: except exceptions.RetriableConnectionFailure as e:
@ -1045,26 +1148,33 @@ class Session(object):
raise raise
delay = next(connect_retry_delays) delay = next(connect_retry_delays)
logger.warning('Failure: %(e)s. Retrying in %(delay).1fs.' logger.warning(
'Failure: %(e)s. Retrying in %(delay).1fs.'
'%(retries)s retries left', '%(retries)s retries left',
{'e': e, 'delay': delay, {'e': e, 'delay': delay, 'retries': connect_retries},
'retries': connect_retries}) )
time.sleep(delay) time.sleep(delay)
return self._send_request( return self._send_request(
url, method, redirect, log, logger, split_loggers, url,
method,
redirect,
log,
logger,
split_loggers,
status_code_retries=status_code_retries, status_code_retries=status_code_retries,
retriable_status_codes=retriable_status_codes, retriable_status_codes=retriable_status_codes,
rate_semaphore=rate_semaphore, rate_semaphore=rate_semaphore,
connect_retries=connect_retries - 1, connect_retries=connect_retries - 1,
connect_retry_delays=connect_retry_delays, connect_retry_delays=connect_retry_delays,
status_code_retry_delays=status_code_retry_delays, status_code_retry_delays=status_code_retry_delays,
**kwargs) **kwargs,
)
if log: if log:
self._http_log_response( self._http_log_response(
response=resp, logger=logger, response=resp, logger=logger, split_loggers=split_loggers
split_loggers=split_loggers) )
if resp.status_code in self._REDIRECT_STATUSES: if resp.status_code in self._REDIRECT_STATUSES:
# be careful here in python True == 1 and False == 0 # be careful here in python True == 1 and False == 0
@ -1080,8 +1190,11 @@ class Session(object):
try: try:
location = resp.headers['location'] location = resp.headers['location']
except KeyError: except KeyError:
logger.warning("Failed to redirect request to %s as new " logger.warning(
"location was not provided.", resp.url) "Failed to redirect request to %s as new "
"location was not provided.",
resp.url,
)
else: else:
# NOTE(TheJulia): Location redirects generally should have # NOTE(TheJulia): Location redirects generally should have
# URI's to the destination. # URI's to the destination.
@ -1090,50 +1203,69 @@ class Session(object):
kwargs['params'] = {} kwargs['params'] = {}
if 'x-openstack-request-id' in resp.headers: if 'x-openstack-request-id' in resp.headers:
kwargs['headers'].setdefault('x-openstack-request-id', kwargs['headers'].setdefault(
resp.headers[ 'x-openstack-request-id',
'x-openstack-request-id']) resp.headers['x-openstack-request-id'],
)
# NOTE(jamielennox): We don't keep increasing delays. # NOTE(jamielennox): We don't keep increasing delays.
# This request actually worked so we can reset the delay count. # This request actually worked so we can reset the delay count.
connect_retry_delays.reset() connect_retry_delays.reset()
status_code_retry_delays.reset() status_code_retry_delays.reset()
new_resp = self._send_request( new_resp = self._send_request(
location, method, redirect, log, logger, split_loggers, location,
method,
redirect,
log,
logger,
split_loggers,
rate_semaphore=rate_semaphore, rate_semaphore=rate_semaphore,
connect_retries=connect_retries, connect_retries=connect_retries,
status_code_retries=status_code_retries, status_code_retries=status_code_retries,
retriable_status_codes=retriable_status_codes, retriable_status_codes=retriable_status_codes,
connect_retry_delays=connect_retry_delays, connect_retry_delays=connect_retry_delays,
status_code_retry_delays=status_code_retry_delays, status_code_retry_delays=status_code_retry_delays,
**kwargs) **kwargs,
)
if not isinstance(new_resp.history, list): if not isinstance(new_resp.history, list):
new_resp.history = list(new_resp.history) new_resp.history = list(new_resp.history)
new_resp.history.insert(0, resp) new_resp.history.insert(0, resp)
resp = new_resp resp = new_resp
elif (resp.status_code in retriable_status_codes and elif (
status_code_retries > 0): resp.status_code in retriable_status_codes
and status_code_retries > 0
):
delay = next(status_code_retry_delays) delay = next(status_code_retry_delays)
logger.warning('Retriable status code %(code)s. Retrying in ' logger.warning(
'Retriable status code %(code)s. Retrying in '
'%(delay).1fs. %(retries)s retries left', '%(delay).1fs. %(retries)s retries left',
{'code': resp.status_code, 'delay': delay, {
'retries': status_code_retries}) 'code': resp.status_code,
'delay': delay,
'retries': status_code_retries,
},
)
time.sleep(delay) time.sleep(delay)
# NOTE(jamielennox): We don't keep increasing connection delays. # NOTE(jamielennox): We don't keep increasing connection delays.
# This request actually worked so we can reset the delay count. # This request actually worked so we can reset the delay count.
connect_retry_delays.reset() connect_retry_delays.reset()
return self._send_request( return self._send_request(
url, method, redirect, log, logger, split_loggers, url,
method,
redirect,
log,
logger,
split_loggers,
connect_retries=connect_retries, connect_retries=connect_retries,
status_code_retries=status_code_retries - 1, status_code_retries=status_code_retries - 1,
retriable_status_codes=retriable_status_codes, retriable_status_codes=retriable_status_codes,
rate_semaphore=rate_semaphore, rate_semaphore=rate_semaphore,
connect_retry_delays=connect_retry_delays, connect_retry_delays=connect_retry_delays,
status_code_retry_delays=status_code_retry_delays, status_code_retry_delays=status_code_retry_delays,
**kwargs) **kwargs,
)
return resp return resp
@ -1288,9 +1420,14 @@ class Session(object):
auth = self._auth_required(auth, 'determine endpoint URL') auth = self._auth_required(auth, 'determine endpoint URL')
return auth.get_api_major_version(self, **kwargs) return auth.get_api_major_version(self, **kwargs)
def get_all_version_data(self, auth=None, interface='public', def get_all_version_data(
region_name=None, service_type=None, self,
**kwargs): auth=None,
interface='public',
region_name=None,
service_type=None,
**kwargs,
):
"""Get version data for all services in the catalog. """Get version data for all services in the catalog.
:param auth: :param auth:
@ -1318,7 +1455,8 @@ class Session(object):
interface=interface, interface=interface,
region_name=region_name, region_name=region_name,
service_type=service_type, service_type=service_type,
**kwargs) **kwargs,
)
def get_auth_connection_params(self, auth=None, **kwargs): def get_auth_connection_params(self, auth=None, **kwargs):
"""Return auth connection params as provided by the auth plugin. """Return auth connection params as provided by the auth plugin.
@ -1461,17 +1599,19 @@ class TCPKeepAliveAdapter(requests.adapters.HTTPAdapter):
] ]
# Windows subsystem for Linux does not support this feature # Windows subsystem for Linux does not support this feature
if (hasattr(socket, 'TCP_KEEPCNT') and if (
not utils.is_windows_linux_subsystem): hasattr(socket, 'TCP_KEEPCNT')
and not utils.is_windows_linux_subsystem
):
socket_options += [ socket_options += [
# Set the maximum number of keep-alive probes # Set the maximum number of keep-alive probes
(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4), (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
] ]
if hasattr(socket, 'TCP_KEEPINTVL'): if hasattr(socket, 'TCP_KEEPINTVL'):
socket_options += [ socket_options += [
# Send keep-alive probes every 15 seconds # Send keep-alive probes every 15 seconds
(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15), (socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
] ]
# After waiting 60 seconds, and then sending a probe once every 15 # After waiting 60 seconds, and then sending a probe once every 15
@ -1479,4 +1619,4 @@ class TCPKeepAliveAdapter(requests.adapters.HTTPAdapter):
# hands for no longer than 2 minutes before a ConnectionError is # hands for no longer than 2 minutes before a ConnectionError is
# raised. # raised.
kwargs['socket_options'] = socket_options kwargs['socket_options'] = socket_options
super(TCPKeepAliveAdapter, self).init_poolmanager(*args, **kwargs) super().init_poolmanager(*args, **kwargs)

View File

@ -21,7 +21,6 @@ from keystoneauth1.tests.unit import utils
class AccessV2Test(utils.TestCase): class AccessV2Test(utils.TestCase):
def test_building_unscoped_accessinfo(self): def test_building_unscoped_accessinfo(self):
token = fixture.V2Token(expires='2012-10-03T16:58:01Z') token = fixture.V2Token(expires='2012-10-03T16:58:01Z')
@ -115,13 +114,10 @@ class AccessV2Test(utils.TestCase):
'user': { 'user': {
'id': 'user_id1', 'id': 'user_id1',
'name': 'user_name1', 'name': 'user_name1',
'roles': [ 'roles': [{'name': 'role1'}, {'name': 'role2'}],
{'name': 'role1'},
{'name': 'role2'},
],
},
}, },
} }
}
auth_ref = access.create(body=diablo_token) auth_ref = access.create(body=diablo_token)
self.assertIsInstance(auth_ref, access.AccessInfoV2) self.assertIsInstance(auth_ref, access.AccessInfoV2)
@ -148,13 +144,10 @@ class AccessV2Test(utils.TestCase):
'name': 'user_name1', 'name': 'user_name1',
'tenantId': 'tenant_id1', 'tenantId': 'tenant_id1',
'tenantName': 'tenant_name1', 'tenantName': 'tenant_name1',
'roles': [ 'roles': [{'name': 'role1'}, {'name': 'role2'}],
{'name': 'role1'},
{'name': 'role2'},
],
},
}, },
} }
}
auth_ref = access.create(body=grizzly_token) auth_ref = access.create(body=grizzly_token)
self.assertIsInstance(auth_ref, access.AccessInfoV2) self.assertIsInstance(auth_ref, access.AccessInfoV2)
@ -179,11 +172,13 @@ class AccessV2Test(utils.TestCase):
self.assertIsInstance(auth_ref, access.AccessInfoV2) self.assertIsInstance(auth_ref, access.AccessInfoV2)
self.assertEqual([role_id], auth_ref.role_ids) self.assertEqual([role_id], auth_ref.role_ids)
self.assertEqual([role_id], self.assertEqual(
auth_ref._data['access']['metadata']['roles']) [role_id], auth_ref._data['access']['metadata']['roles']
)
self.assertEqual([role_name], auth_ref.role_names) self.assertEqual([role_name], auth_ref.role_names)
self.assertEqual([{'name': role_name}], self.assertEqual(
auth_ref._data['access']['user']['roles']) [{'name': role_name}], auth_ref._data['access']['user']['roles']
)
def test_trusts(self): def test_trusts(self):
user_id = uuid.uuid4().hex user_id = uuid.uuid4().hex

View File

@ -20,7 +20,7 @@ from keystoneauth1.tests.unit import utils
class ServiceCatalogTest(utils.TestCase): class ServiceCatalogTest(utils.TestCase):
def setUp(self): def setUp(self):
super(ServiceCatalogTest, self).setUp() super().setUp()
self.AUTH_RESPONSE_BODY = fixture.V2Token( self.AUTH_RESPONSE_BODY = fixture.V2Token(
token_id='ab48a9efdfedb23ty3494', token_id='ab48a9efdfedb23ty3494',
@ -29,18 +29,21 @@ class ServiceCatalogTest(utils.TestCase):
tenant_name='My Project', tenant_name='My Project',
user_id='123', user_id='123',
user_name='jqsmith', user_name='jqsmith',
audit_chain_id=uuid.uuid4().hex) audit_chain_id=uuid.uuid4().hex,
)
self.AUTH_RESPONSE_BODY.add_role(id='234', name='compute:admin') self.AUTH_RESPONSE_BODY.add_role(id='234', name='compute:admin')
role = self.AUTH_RESPONSE_BODY.add_role(id='235', role = self.AUTH_RESPONSE_BODY.add_role(
name='object-store:admin') id='235', name='object-store:admin'
)
role['tenantId'] = '1' role['tenantId'] = '1'
s = self.AUTH_RESPONSE_BODY.add_service('compute', 'Cloud Servers') s = self.AUTH_RESPONSE_BODY.add_service('compute', 'Cloud Servers')
endpoint = s.add_endpoint( endpoint = s.add_endpoint(
public='https://compute.north.host/v1/1234', public='https://compute.north.host/v1/1234',
internal='https://compute.north.host/v1/1234', internal='https://compute.north.host/v1/1234',
region='North') region='North',
)
endpoint['tenantId'] = '1' endpoint['tenantId'] = '1'
endpoint['versionId'] = '1.0' endpoint['versionId'] = '1.0'
endpoint['versionInfo'] = 'https://compute.north.host/v1.0/' endpoint['versionInfo'] = 'https://compute.north.host/v1.0/'
@ -49,16 +52,19 @@ class ServiceCatalogTest(utils.TestCase):
endpoint = s.add_endpoint( endpoint = s.add_endpoint(
public='https://compute.north.host/v1.1/3456', public='https://compute.north.host/v1.1/3456',
internal='https://compute.north.host/v1.1/3456', internal='https://compute.north.host/v1.1/3456',
region='North') region='North',
)
endpoint['tenantId'] = '2' endpoint['tenantId'] = '2'
endpoint['versionId'] = '1.1' endpoint['versionId'] = '1.1'
endpoint['versionInfo'] = 'https://compute.north.host/v1.1/' endpoint['versionInfo'] = 'https://compute.north.host/v1.1/'
endpoint['versionList'] = 'https://compute.north.host/' endpoint['versionList'] = 'https://compute.north.host/'
s = self.AUTH_RESPONSE_BODY.add_service('object-store', 'Cloud Files') s = self.AUTH_RESPONSE_BODY.add_service('object-store', 'Cloud Files')
endpoint = s.add_endpoint(public='https://swift.north.host/v1/blah', endpoint = s.add_endpoint(
public='https://swift.north.host/v1/blah',
internal='https://swift.north.host/v1/blah', internal='https://swift.north.host/v1/blah',
region='South') region='South',
)
endpoint['tenantId'] = '11' endpoint['tenantId'] = '11'
endpoint['versionId'] = '1.0' endpoint['versionId'] = '1.0'
endpoint['versionInfo'] = 'uri' endpoint['versionInfo'] = 'uri'
@ -67,48 +73,62 @@ class ServiceCatalogTest(utils.TestCase):
endpoint = s.add_endpoint( endpoint = s.add_endpoint(
public='https://swift.north.host/v1.1/blah', public='https://swift.north.host/v1.1/blah',
internal='https://compute.north.host/v1.1/blah', internal='https://compute.north.host/v1.1/blah',
region='South') region='South',
)
endpoint['tenantId'] = '2' endpoint['tenantId'] = '2'
endpoint['versionId'] = '1.1' endpoint['versionId'] = '1.1'
endpoint['versionInfo'] = 'https://swift.north.host/v1.1/' endpoint['versionInfo'] = 'https://swift.north.host/v1.1/'
endpoint['versionList'] = 'https://swift.north.host/' endpoint['versionList'] = 'https://swift.north.host/'
s = self.AUTH_RESPONSE_BODY.add_service('image', 'Image Servers') s = self.AUTH_RESPONSE_BODY.add_service('image', 'Image Servers')
s.add_endpoint(public='https://image.north.host/v1/', s.add_endpoint(
public='https://image.north.host/v1/',
internal='https://image-internal.north.host/v1/', internal='https://image-internal.north.host/v1/',
region='North') region='North',
s.add_endpoint(public='https://image.south.host/v1/', )
s.add_endpoint(
public='https://image.south.host/v1/',
internal='https://image-internal.south.host/v1/', internal='https://image-internal.south.host/v1/',
region='South') region='South',
)
def test_building_a_service_catalog(self): def test_building_a_service_catalog(self):
auth_ref = access.create(body=self.AUTH_RESPONSE_BODY) auth_ref = access.create(body=self.AUTH_RESPONSE_BODY)
sc = auth_ref.service_catalog sc = auth_ref.service_catalog
self.assertEqual(sc.url_for(service_type='compute'), self.assertEqual(
"https://compute.north.host/v1/1234") sc.url_for(service_type='compute'),
self.assertRaises(exceptions.EndpointNotFound, "https://compute.north.host/v1/1234",
)
self.assertRaises(
exceptions.EndpointNotFound,
sc.url_for, sc.url_for,
region_name="South", region_name="South",
service_type='compute') service_type='compute',
)
def test_service_catalog_endpoints(self): def test_service_catalog_endpoints(self):
auth_ref = access.create(body=self.AUTH_RESPONSE_BODY) auth_ref = access.create(body=self.AUTH_RESPONSE_BODY)
sc = auth_ref.service_catalog sc = auth_ref.service_catalog
public_ep = sc.get_endpoints(service_type='compute', public_ep = sc.get_endpoints(
interface='publicURL') service_type='compute', interface='publicURL'
)
self.assertEqual(public_ep['compute'][1]['tenantId'], '2') self.assertEqual(public_ep['compute'][1]['tenantId'], '2')
self.assertEqual(public_ep['compute'][1]['versionId'], '1.1') self.assertEqual(public_ep['compute'][1]['versionId'], '1.1')
self.assertEqual(public_ep['compute'][1]['internalURL'], self.assertEqual(
"https://compute.north.host/v1.1/3456") public_ep['compute'][1]['internalURL'],
"https://compute.north.host/v1.1/3456",
)
def test_service_catalog_empty(self): def test_service_catalog_empty(self):
self.AUTH_RESPONSE_BODY['access']['serviceCatalog'] = [] self.AUTH_RESPONSE_BODY['access']['serviceCatalog'] = []
auth_ref = access.create(body=self.AUTH_RESPONSE_BODY) auth_ref = access.create(body=self.AUTH_RESPONSE_BODY)
self.assertRaises(exceptions.EmptyCatalog, self.assertRaises(
exceptions.EmptyCatalog,
auth_ref.service_catalog.url_for, auth_ref.service_catalog.url_for,
service_type='image', service_type='image',
interface='internalURL') interface='internalURL',
)
def test_service_catalog_get_endpoints_region_names(self): def test_service_catalog_get_endpoints_region_names(self):
auth_ref = access.create(body=self.AUTH_RESPONSE_BODY) auth_ref = access.create(body=self.AUTH_RESPONSE_BODY)
@ -116,23 +136,27 @@ class ServiceCatalogTest(utils.TestCase):
endpoints = sc.get_endpoints(service_type='image', region_name='North') endpoints = sc.get_endpoints(service_type='image', region_name='North')
self.assertEqual(len(endpoints), 1) self.assertEqual(len(endpoints), 1)
self.assertEqual(endpoints['image'][0]['publicURL'], self.assertEqual(
'https://image.north.host/v1/') endpoints['image'][0]['publicURL'], 'https://image.north.host/v1/'
)
endpoints = sc.get_endpoints(service_type='image', region_name='South') endpoints = sc.get_endpoints(service_type='image', region_name='South')
self.assertEqual(len(endpoints), 1) self.assertEqual(len(endpoints), 1)
self.assertEqual(endpoints['image'][0]['publicURL'], self.assertEqual(
'https://image.south.host/v1/') endpoints['image'][0]['publicURL'], 'https://image.south.host/v1/'
)
endpoints = sc.get_endpoints(service_type='compute') endpoints = sc.get_endpoints(service_type='compute')
self.assertEqual(len(endpoints['compute']), 2) self.assertEqual(len(endpoints['compute']), 2)
endpoints = sc.get_endpoints(service_type='compute', endpoints = sc.get_endpoints(
region_name='North') service_type='compute', region_name='North'
)
self.assertEqual(len(endpoints['compute']), 2) self.assertEqual(len(endpoints['compute']), 2)
endpoints = sc.get_endpoints(service_type='compute', endpoints = sc.get_endpoints(
region_name='West') service_type='compute', region_name='West'
)
self.assertEqual(len(endpoints['compute']), 0) self.assertEqual(len(endpoints['compute']), 0)
def test_service_catalog_url_for_region_names(self): def test_service_catalog_url_for_region_names(self):
@ -145,8 +169,12 @@ class ServiceCatalogTest(utils.TestCase):
url = sc.url_for(service_type='image', region_name='South') url = sc.url_for(service_type='image', region_name='South')
self.assertEqual(url, 'https://image.south.host/v1/') self.assertEqual(url, 'https://image.south.host/v1/')
self.assertRaises(exceptions.EndpointNotFound, sc.url_for, self.assertRaises(
service_type='image', region_name='West') exceptions.EndpointNotFound,
sc.url_for,
service_type='image',
region_name='West',
)
def test_servcie_catalog_get_url_region_names(self): def test_servcie_catalog_get_url_region_names(self):
auth_ref = access.create(body=self.AUTH_RESPONSE_BODY) auth_ref = access.create(body=self.AUTH_RESPONSE_BODY)
@ -170,21 +198,33 @@ class ServiceCatalogTest(utils.TestCase):
auth_ref = access.create(body=self.AUTH_RESPONSE_BODY) auth_ref = access.create(body=self.AUTH_RESPONSE_BODY)
sc = auth_ref.service_catalog sc = auth_ref.service_catalog
url = sc.url_for(service_name='Image Servers', interface='public', url = sc.url_for(
service_type='image', region_name='North') service_name='Image Servers',
interface='public',
service_type='image',
region_name='North',
)
self.assertEqual('https://image.north.host/v1/', url) self.assertEqual('https://image.north.host/v1/', url)
self.assertRaises(exceptions.EndpointNotFound, sc.url_for, self.assertRaises(
service_name='Image Servers', service_type='compute') exceptions.EndpointNotFound,
sc.url_for,
service_name='Image Servers',
service_type='compute',
)
urls = sc.get_urls(service_type='image', service_name='Image Servers', urls = sc.get_urls(
interface='public') service_type='image',
service_name='Image Servers',
interface='public',
)
self.assertIn('https://image.north.host/v1/', urls) self.assertIn('https://image.north.host/v1/', urls)
self.assertIn('https://image.south.host/v1/', urls) self.assertIn('https://image.south.host/v1/', urls)
urls = sc.get_urls(service_type='image', service_name='Servers', urls = sc.get_urls(
interface='public') service_type='image', service_name='Servers', interface='public'
)
self.assertEqual(0, len(urls)) self.assertEqual(0, len(urls))
@ -194,23 +234,28 @@ class ServiceCatalogTest(utils.TestCase):
for i in range(3): for i in range(3):
s = token.add_service('compute') s = token.add_service('compute')
s.add_endpoint(public='public-%d' % i, s.add_endpoint(
public='public-%d' % i,
admin='admin-%d' % i, admin='admin-%d' % i,
internal='internal-%d' % i, internal='internal-%d' % i,
region='region-%d' % i) region='region-%d' % i,
)
auth_ref = access.create(body=token) auth_ref = access.create(body=token)
urls = auth_ref.service_catalog.get_urls(service_type='compute', urls = auth_ref.service_catalog.get_urls(
interface='publicURL') service_type='compute', interface='publicURL'
)
self.assertEqual(set(['public-0', 'public-1', 'public-2']), set(urls)) self.assertEqual({'public-0', 'public-1', 'public-2'}, set(urls))
urls = auth_ref.service_catalog.get_urls(service_type='compute', urls = auth_ref.service_catalog.get_urls(
service_type='compute',
interface='publicURL', interface='publicURL',
region_name='region-1') region_name='region-1',
)
self.assertEqual(('public-1', ), urls) self.assertEqual(('public-1',), urls)
def test_service_catalog_endpoint_id(self): def test_service_catalog_endpoint_id(self):
token = fixture.V2Token() token = fixture.V2Token()
@ -228,23 +273,27 @@ class ServiceCatalogTest(utils.TestCase):
urls = auth_ref.service_catalog.get_urls(interface='public') urls = auth_ref.service_catalog.get_urls(interface='public')
self.assertEqual(2, len(urls)) self.assertEqual(2, len(urls))
urls = auth_ref.service_catalog.get_urls(endpoint_id=endpoint_id, urls = auth_ref.service_catalog.get_urls(
interface='public') endpoint_id=endpoint_id, interface='public'
)
self.assertEqual((public_url, ), urls) self.assertEqual((public_url,), urls)
# with bad endpoint_id nothing should be found # with bad endpoint_id nothing should be found
urls = auth_ref.service_catalog.get_urls(endpoint_id=uuid.uuid4().hex, urls = auth_ref.service_catalog.get_urls(
interface='public') endpoint_id=uuid.uuid4().hex, interface='public'
)
self.assertEqual(0, len(urls)) self.assertEqual(0, len(urls))
# we ignore a service_id because v2 doesn't know what it is # we ignore a service_id because v2 doesn't know what it is
urls = auth_ref.service_catalog.get_urls(endpoint_id=endpoint_id, urls = auth_ref.service_catalog.get_urls(
endpoint_id=endpoint_id,
service_id=uuid.uuid4().hex, service_id=uuid.uuid4().hex,
interface='public') interface='public',
)
self.assertEqual((public_url, ), urls) self.assertEqual((public_url,), urls)
def test_service_catalog_without_service_type(self): def test_service_catalog_without_service_type(self):
token = fixture.V2Token() token = fixture.V2Token()
@ -260,8 +309,9 @@ class ServiceCatalogTest(utils.TestCase):
s.add_endpoint(public=public_url) s.add_endpoint(public=public_url)
auth_ref = access.create(body=token) auth_ref = access.create(body=token)
urls = auth_ref.service_catalog.get_urls(service_type=None, urls = auth_ref.service_catalog.get_urls(
interface='public') service_type=None, interface='public'
)
self.assertEqual(3, len(urls)) self.assertEqual(3, len(urls))

View File

@ -21,7 +21,6 @@ from keystoneauth1.tests.unit import utils
class AccessV3Test(utils.TestCase): class AccessV3Test(utils.TestCase):
def test_building_unscoped_accessinfo(self): def test_building_unscoped_accessinfo(self):
token = fixture.V3Token() token = fixture.V3Token()
token_id = uuid.uuid4().hex token_id = uuid.uuid4().hex
@ -52,10 +51,14 @@ class AccessV3Test(utils.TestCase):
self.assertIsNone(auth_ref.project_domain_id) self.assertIsNone(auth_ref.project_domain_id)
self.assertIsNone(auth_ref.project_domain_name) self.assertIsNone(auth_ref.project_domain_name)
self.assertEqual(auth_ref.expires, timeutils.parse_isotime( self.assertEqual(
token['token']['expires_at'])) auth_ref.expires,
self.assertEqual(auth_ref.issued, timeutils.parse_isotime( timeutils.parse_isotime(token['token']['expires_at']),
token['token']['issued_at'])) )
self.assertEqual(
auth_ref.issued,
timeutils.parse_isotime(token['token']['issued_at']),
)
self.assertEqual(auth_ref.expires, token.expires) self.assertEqual(auth_ref.expires, token.expires)
self.assertEqual(auth_ref.issued, token.issued) self.assertEqual(auth_ref.issued, token.issued)
@ -197,8 +200,9 @@ class AccessV3Test(utils.TestCase):
self.assertEqual(auth_ref.tenant_id, auth_ref.project_id) self.assertEqual(auth_ref.tenant_id, auth_ref.project_id)
self.assertEqual(token.project_domain_id, auth_ref.project_domain_id) self.assertEqual(token.project_domain_id, auth_ref.project_domain_id)
self.assertEqual(token.project_domain_name, self.assertEqual(
auth_ref.project_domain_name) token.project_domain_name, auth_ref.project_domain_name
)
self.assertEqual(token.user_domain_id, auth_ref.user_domain_id) self.assertEqual(token.user_domain_id, auth_ref.user_domain_id)
self.assertEqual(token.user_domain_name, auth_ref.user_domain_name) self.assertEqual(token.user_domain_name, auth_ref.user_domain_name)
@ -243,8 +247,9 @@ class AccessV3Test(utils.TestCase):
self.assertEqual(auth_ref.tenant_id, auth_ref.project_id) self.assertEqual(auth_ref.tenant_id, auth_ref.project_id)
self.assertEqual(token.project_domain_id, auth_ref.project_domain_id) self.assertEqual(token.project_domain_id, auth_ref.project_domain_id)
self.assertEqual(token.project_domain_name, self.assertEqual(
auth_ref.project_domain_name) token.project_domain_name, auth_ref.project_domain_name
)
self.assertEqual(token.user_domain_id, auth_ref.user_domain_id) self.assertEqual(token.user_domain_id, auth_ref.user_domain_id)
self.assertEqual(token.user_domain_name, auth_ref.user_domain_name) self.assertEqual(token.user_domain_name, auth_ref.user_domain_name)
@ -262,19 +267,22 @@ class AccessV3Test(utils.TestCase):
token = fixture.V3Token() token = fixture.V3Token()
token.set_project_scope() token.set_project_scope()
token.set_oauth(access_token_id=access_token_id, token.set_oauth(
consumer_id=consumer_id) access_token_id=access_token_id, consumer_id=consumer_id
)
auth_ref = access.create(body=token) auth_ref = access.create(body=token)
self.assertEqual(consumer_id, auth_ref.oauth_consumer_id) self.assertEqual(consumer_id, auth_ref.oauth_consumer_id)
self.assertEqual(access_token_id, auth_ref.oauth_access_token_id) self.assertEqual(access_token_id, auth_ref.oauth_access_token_id)
self.assertEqual(consumer_id, self.assertEqual(
auth_ref._data['token']['OS-OAUTH1']['consumer_id']) consumer_id, auth_ref._data['token']['OS-OAUTH1']['consumer_id']
)
self.assertEqual( self.assertEqual(
access_token_id, access_token_id,
auth_ref._data['token']['OS-OAUTH1']['access_token_id']) auth_ref._data['token']['OS-OAUTH1']['access_token_id'],
)
def test_federated_property_standard_token(self): def test_federated_property_standard_token(self):
"""Check if is_federated property returns expected value.""" """Check if is_federated property returns expected value."""

View File

@ -20,10 +20,11 @@ from keystoneauth1.tests.unit import utils
class ServiceCatalogTest(utils.TestCase): class ServiceCatalogTest(utils.TestCase):
def setUp(self): def setUp(self):
super(ServiceCatalogTest, self).setUp() super().setUp()
self.AUTH_RESPONSE_BODY = fixture.V3Token( self.AUTH_RESPONSE_BODY = fixture.V3Token(
audit_chain_id=uuid.uuid4().hex) audit_chain_id=uuid.uuid4().hex
)
self.AUTH_RESPONSE_BODY.set_project_scope() self.AUTH_RESPONSE_BODY.set_project_scope()
self.AUTH_RESPONSE_BODY.add_role(name='admin') self.AUTH_RESPONSE_BODY.add_role(name='admin')
@ -34,207 +35,251 @@ class ServiceCatalogTest(utils.TestCase):
public='https://compute.north.host/novapi/public', public='https://compute.north.host/novapi/public',
internal='https://compute.north.host/novapi/internal', internal='https://compute.north.host/novapi/internal',
admin='https://compute.north.host/novapi/admin', admin='https://compute.north.host/novapi/admin',
region='North') region='North',
)
s = self.AUTH_RESPONSE_BODY.add_service('object-store', name='swift') s = self.AUTH_RESPONSE_BODY.add_service('object-store', name='swift')
s.add_standard_endpoints( s.add_standard_endpoints(
public='http://swift.north.host/swiftapi/public', public='http://swift.north.host/swiftapi/public',
internal='http://swift.north.host/swiftapi/internal', internal='http://swift.north.host/swiftapi/internal',
admin='http://swift.north.host/swiftapi/admin', admin='http://swift.north.host/swiftapi/admin',
region='South') region='South',
)
s = self.AUTH_RESPONSE_BODY.add_service('image', name='glance') s = self.AUTH_RESPONSE_BODY.add_service('image', name='glance')
s.add_standard_endpoints( s.add_standard_endpoints(
public='http://glance.north.host/glanceapi/public', public='http://glance.north.host/glanceapi/public',
internal='http://glance.north.host/glanceapi/internal', internal='http://glance.north.host/glanceapi/internal',
admin='http://glance.north.host/glanceapi/admin', admin='http://glance.north.host/glanceapi/admin',
region='North') region='North',
)
s.add_standard_endpoints( s.add_standard_endpoints(
public='http://glance.south.host/glanceapi/public', public='http://glance.south.host/glanceapi/public',
internal='http://glance.south.host/glanceapi/internal', internal='http://glance.south.host/glanceapi/internal',
admin='http://glance.south.host/glanceapi/admin', admin='http://glance.south.host/glanceapi/admin',
region='South') region='South',
)
s = self.AUTH_RESPONSE_BODY.add_service('block-storage', name='cinder') s = self.AUTH_RESPONSE_BODY.add_service('block-storage', name='cinder')
s.add_standard_endpoints( s.add_standard_endpoints(
public='http://cinder.north.host/cinderapi/public', public='http://cinder.north.host/cinderapi/public',
internal='http://cinder.north.host/cinderapi/internal', internal='http://cinder.north.host/cinderapi/internal',
admin='http://cinder.north.host/cinderapi/admin', admin='http://cinder.north.host/cinderapi/admin',
region='North') region='North',
)
s = self.AUTH_RESPONSE_BODY.add_service('volumev2', name='cinder') s = self.AUTH_RESPONSE_BODY.add_service('volumev2', name='cinder')
s.add_standard_endpoints( s.add_standard_endpoints(
public='http://cinder.south.host/cinderapi/public/v2', public='http://cinder.south.host/cinderapi/public/v2',
internal='http://cinder.south.host/cinderapi/internal/v2', internal='http://cinder.south.host/cinderapi/internal/v2',
admin='http://cinder.south.host/cinderapi/admin/v2', admin='http://cinder.south.host/cinderapi/admin/v2',
region='South') region='South',
)
s = self.AUTH_RESPONSE_BODY.add_service('volumev3', name='cinder') s = self.AUTH_RESPONSE_BODY.add_service('volumev3', name='cinder')
s.add_standard_endpoints( s.add_standard_endpoints(
public='http://cinder.south.host/cinderapi/public/v3', public='http://cinder.south.host/cinderapi/public/v3',
internal='http://cinder.south.host/cinderapi/internal/v3', internal='http://cinder.south.host/cinderapi/internal/v3',
admin='http://cinder.south.host/cinderapi/admin/v3', admin='http://cinder.south.host/cinderapi/admin/v3',
region='South') region='South',
)
self.north_endpoints = {'public': self.north_endpoints = {
'http://glance.north.host/glanceapi/public', 'public': 'http://glance.north.host/glanceapi/public',
'internal': 'internal': 'http://glance.north.host/glanceapi/internal',
'http://glance.north.host/glanceapi/internal', 'admin': 'http://glance.north.host/glanceapi/admin',
'admin': }
'http://glance.north.host/glanceapi/admin'}
self.south_endpoints = {'public': self.south_endpoints = {
'http://glance.south.host/glanceapi/public', 'public': 'http://glance.south.host/glanceapi/public',
'internal': 'internal': 'http://glance.south.host/glanceapi/internal',
'http://glance.south.host/glanceapi/internal', 'admin': 'http://glance.south.host/glanceapi/admin',
'admin': }
'http://glance.south.host/glanceapi/admin'}
def test_building_a_service_catalog(self): def test_building_a_service_catalog(self):
auth_ref = access.create(auth_token=uuid.uuid4().hex, auth_ref = access.create(
body=self.AUTH_RESPONSE_BODY) auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
)
sc = auth_ref.service_catalog sc = auth_ref.service_catalog
self.assertEqual(sc.url_for(service_type='compute'), self.assertEqual(
"https://compute.north.host/novapi/public") sc.url_for(service_type='compute'),
self.assertEqual(sc.url_for(service_type='compute', "https://compute.north.host/novapi/public",
interface='internal'), )
"https://compute.north.host/novapi/internal") self.assertEqual(
sc.url_for(service_type='compute', interface='internal'),
"https://compute.north.host/novapi/internal",
)
self.assertRaises(exceptions.EndpointNotFound, self.assertRaises(
exceptions.EndpointNotFound,
sc.url_for, sc.url_for,
region_name='South', region_name='South',
service_type='compute') service_type='compute',
)
def test_service_catalog_endpoints(self): def test_service_catalog_endpoints(self):
auth_ref = access.create(auth_token=uuid.uuid4().hex, auth_ref = access.create(
body=self.AUTH_RESPONSE_BODY) auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
)
sc = auth_ref.service_catalog sc = auth_ref.service_catalog
public_ep = sc.get_endpoints(service_type='compute', public_ep = sc.get_endpoints(
interface='public') service_type='compute', interface='public'
)
self.assertEqual(public_ep['compute'][0]['region'], 'North') self.assertEqual(public_ep['compute'][0]['region'], 'North')
self.assertEqual(public_ep['compute'][0]['url'], self.assertEqual(
"https://compute.north.host/novapi/public") public_ep['compute'][0]['url'],
"https://compute.north.host/novapi/public",
)
def test_service_catalog_alias_find_official(self): def test_service_catalog_alias_find_official(self):
auth_ref = access.create(auth_token=uuid.uuid4().hex, auth_ref = access.create(
body=self.AUTH_RESPONSE_BODY) auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
)
sc = auth_ref.service_catalog sc = auth_ref.service_catalog
# Tests that we find the block-storage endpoint when we request # Tests that we find the block-storage endpoint when we request
# the volume endpoint. # the volume endpoint.
public_ep = sc.get_endpoints(service_type='volume', public_ep = sc.get_endpoints(
interface='public', service_type='volume', interface='public', region_name='North'
region_name='North') )
self.assertEqual(public_ep['block-storage'][0]['region'], 'North') self.assertEqual(public_ep['block-storage'][0]['region'], 'North')
self.assertEqual(public_ep['block-storage'][0]['url'], self.assertEqual(
"http://cinder.north.host/cinderapi/public") public_ep['block-storage'][0]['url'],
"http://cinder.north.host/cinderapi/public",
)
def test_service_catalog_alias_find_exact_match(self): def test_service_catalog_alias_find_exact_match(self):
auth_ref = access.create(auth_token=uuid.uuid4().hex, auth_ref = access.create(
body=self.AUTH_RESPONSE_BODY) auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
)
sc = auth_ref.service_catalog sc = auth_ref.service_catalog
# Tests that we find the volumev3 endpoint when we request it. # Tests that we find the volumev3 endpoint when we request it.
public_ep = sc.get_endpoints(service_type='volumev3', public_ep = sc.get_endpoints(
interface='public') service_type='volumev3', interface='public'
)
self.assertEqual(public_ep['volumev3'][0]['region'], 'South') self.assertEqual(public_ep['volumev3'][0]['region'], 'South')
self.assertEqual(public_ep['volumev3'][0]['url'], self.assertEqual(
"http://cinder.south.host/cinderapi/public/v3") public_ep['volumev3'][0]['url'],
"http://cinder.south.host/cinderapi/public/v3",
)
def test_service_catalog_alias_find_best_match(self): def test_service_catalog_alias_find_best_match(self):
auth_ref = access.create(auth_token=uuid.uuid4().hex, auth_ref = access.create(
body=self.AUTH_RESPONSE_BODY) auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
)
sc = auth_ref.service_catalog sc = auth_ref.service_catalog
# Tests that we find the volumev3 endpoint when we request # Tests that we find the volumev3 endpoint when we request
# block-storage when only volumev2 and volumev3 are present since # block-storage when only volumev2 and volumev3 are present since
# volumev3 comes first in the list. # volumev3 comes first in the list.
public_ep = sc.get_endpoints(service_type='block-storage', public_ep = sc.get_endpoints(
service_type='block-storage',
interface='public', interface='public',
region_name='South') region_name='South',
)
self.assertEqual(public_ep['volumev3'][0]['region'], 'South') self.assertEqual(public_ep['volumev3'][0]['region'], 'South')
self.assertEqual(public_ep['volumev3'][0]['url'], self.assertEqual(
"http://cinder.south.host/cinderapi/public/v3") public_ep['volumev3'][0]['url'],
"http://cinder.south.host/cinderapi/public/v3",
)
def test_service_catalog_alias_all_by_name(self): def test_service_catalog_alias_all_by_name(self):
auth_ref = access.create(auth_token=uuid.uuid4().hex, auth_ref = access.create(
body=self.AUTH_RESPONSE_BODY) auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
)
sc = auth_ref.service_catalog sc = auth_ref.service_catalog
# Tests that we find all the cinder endpoints since we request # Tests that we find all the cinder endpoints since we request
# them by name and that no filtering related to aliases happens. # them by name and that no filtering related to aliases happens.
public_ep = sc.get_endpoints(service_name='cinder', public_ep = sc.get_endpoints(service_name='cinder', interface='public')
interface='public')
self.assertEqual(public_ep['volumev2'][0]['region'], 'South') self.assertEqual(public_ep['volumev2'][0]['region'], 'South')
self.assertEqual(public_ep['volumev2'][0]['url'], self.assertEqual(
"http://cinder.south.host/cinderapi/public/v2") public_ep['volumev2'][0]['url'],
"http://cinder.south.host/cinderapi/public/v2",
)
self.assertEqual(public_ep['volumev3'][0]['region'], 'South') self.assertEqual(public_ep['volumev3'][0]['region'], 'South')
self.assertEqual(public_ep['volumev3'][0]['url'], self.assertEqual(
"http://cinder.south.host/cinderapi/public/v3") public_ep['volumev3'][0]['url'],
"http://cinder.south.host/cinderapi/public/v3",
)
self.assertEqual(public_ep['block-storage'][0]['region'], 'North') self.assertEqual(public_ep['block-storage'][0]['region'], 'North')
self.assertEqual(public_ep['block-storage'][0]['url'], self.assertEqual(
"http://cinder.north.host/cinderapi/public") public_ep['block-storage'][0]['url'],
"http://cinder.north.host/cinderapi/public",
)
def test_service_catalog_regions(self): def test_service_catalog_regions(self):
self.AUTH_RESPONSE_BODY['token']['region_name'] = "North" self.AUTH_RESPONSE_BODY['token']['region_name'] = "North"
auth_ref = access.create(auth_token=uuid.uuid4().hex, auth_ref = access.create(
body=self.AUTH_RESPONSE_BODY) auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
)
sc = auth_ref.service_catalog sc = auth_ref.service_catalog
url = sc.url_for(service_type='image', interface='public') url = sc.url_for(service_type='image', interface='public')
self.assertEqual(url, "http://glance.north.host/glanceapi/public") self.assertEqual(url, "http://glance.north.host/glanceapi/public")
self.AUTH_RESPONSE_BODY['token']['region_name'] = "South" self.AUTH_RESPONSE_BODY['token']['region_name'] = "South"
auth_ref = access.create(auth_token=uuid.uuid4().hex, auth_ref = access.create(
body=self.AUTH_RESPONSE_BODY) auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
)
sc = auth_ref.service_catalog sc = auth_ref.service_catalog
url = sc.url_for(service_type='image', url = sc.url_for(
region_name="South", service_type='image', region_name="South", interface='internal'
interface='internal') )
self.assertEqual(url, "http://glance.south.host/glanceapi/internal") self.assertEqual(url, "http://glance.south.host/glanceapi/internal")
def test_service_catalog_empty(self): def test_service_catalog_empty(self):
self.AUTH_RESPONSE_BODY['token']['catalog'] = [] self.AUTH_RESPONSE_BODY['token']['catalog'] = []
auth_ref = access.create(auth_token=uuid.uuid4().hex, auth_ref = access.create(
body=self.AUTH_RESPONSE_BODY) auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
self.assertRaises(exceptions.EmptyCatalog, )
self.assertRaises(
exceptions.EmptyCatalog,
auth_ref.service_catalog.url_for, auth_ref.service_catalog.url_for,
service_type='image', service_type='image',
interface='internalURL') interface='internalURL',
)
def test_service_catalog_get_endpoints_region_names(self): def test_service_catalog_get_endpoints_region_names(self):
sc = access.create(auth_token=uuid.uuid4().hex, sc = access.create(
body=self.AUTH_RESPONSE_BODY).service_catalog auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
).service_catalog
endpoints = sc.get_endpoints(service_type='image', region_name='North') endpoints = sc.get_endpoints(service_type='image', region_name='North')
self.assertEqual(len(endpoints), 1) self.assertEqual(len(endpoints), 1)
for endpoint in endpoints['image']: for endpoint in endpoints['image']:
self.assertEqual(endpoint['url'], self.assertEqual(
self.north_endpoints[endpoint['interface']]) endpoint['url'], self.north_endpoints[endpoint['interface']]
)
endpoints = sc.get_endpoints(service_type='image', region_name='South') endpoints = sc.get_endpoints(service_type='image', region_name='South')
self.assertEqual(len(endpoints), 1) self.assertEqual(len(endpoints), 1)
for endpoint in endpoints['image']: for endpoint in endpoints['image']:
self.assertEqual(endpoint['url'], self.assertEqual(
self.south_endpoints[endpoint['interface']]) endpoint['url'], self.south_endpoints[endpoint['interface']]
)
endpoints = sc.get_endpoints(service_type='compute') endpoints = sc.get_endpoints(service_type='compute')
self.assertEqual(len(endpoints['compute']), 3) self.assertEqual(len(endpoints['compute']), 3)
endpoints = sc.get_endpoints(service_type='compute', endpoints = sc.get_endpoints(
region_name='North') service_type='compute', region_name='North'
)
self.assertEqual(len(endpoints['compute']), 3) self.assertEqual(len(endpoints['compute']), 3)
endpoints = sc.get_endpoints(service_type='compute', endpoints = sc.get_endpoints(
region_name='West') service_type='compute', region_name='West'
)
self.assertEqual(len(endpoints['compute']), 0) self.assertEqual(len(endpoints['compute']), 0)
def test_service_catalog_url_for_region_names(self): def test_service_catalog_url_for_region_names(self):
sc = access.create(auth_token=uuid.uuid4().hex, sc = access.create(
body=self.AUTH_RESPONSE_BODY).service_catalog auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
).service_catalog
url = sc.url_for(service_type='image', region_name='North') url = sc.url_for(service_type='image', region_name='North')
self.assertEqual(url, self.north_endpoints['public']) self.assertEqual(url, self.north_endpoints['public'])
@ -242,12 +287,17 @@ class ServiceCatalogTest(utils.TestCase):
url = sc.url_for(service_type='image', region_name='South') url = sc.url_for(service_type='image', region_name='South')
self.assertEqual(url, self.south_endpoints['public']) self.assertEqual(url, self.south_endpoints['public'])
self.assertRaises(exceptions.EndpointNotFound, sc.url_for, self.assertRaises(
service_type='image', region_name='West') exceptions.EndpointNotFound,
sc.url_for,
service_type='image',
region_name='West',
)
def test_service_catalog_get_url_region_names(self): def test_service_catalog_get_url_region_names(self):
sc = access.create(auth_token=uuid.uuid4().hex, sc = access.create(
body=self.AUTH_RESPONSE_BODY).service_catalog auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
).service_catalog
urls = sc.get_urls(service_type='image') urls = sc.get_urls(service_type='image')
self.assertEqual(len(urls), 2) self.assertEqual(len(urls), 2)
@ -264,29 +314,43 @@ class ServiceCatalogTest(utils.TestCase):
self.assertEqual(len(urls), 0) self.assertEqual(len(urls), 0)
def test_service_catalog_service_name(self): def test_service_catalog_service_name(self):
sc = access.create(auth_token=uuid.uuid4().hex, sc = access.create(
body=self.AUTH_RESPONSE_BODY).service_catalog auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
).service_catalog
url = sc.url_for(service_name='glance', interface='public', url = sc.url_for(
service_type='image', region_name='North') service_name='glance',
interface='public',
service_type='image',
region_name='North',
)
self.assertEqual('http://glance.north.host/glanceapi/public', url) self.assertEqual('http://glance.north.host/glanceapi/public', url)
url = sc.url_for(service_name='glance', interface='public', url = sc.url_for(
service_type='image', region_name='South') service_name='glance',
interface='public',
service_type='image',
region_name='South',
)
self.assertEqual('http://glance.south.host/glanceapi/public', url) self.assertEqual('http://glance.south.host/glanceapi/public', url)
self.assertRaises(exceptions.EndpointNotFound, sc.url_for, self.assertRaises(
service_name='glance', service_type='compute') exceptions.EndpointNotFound,
sc.url_for,
service_name='glance',
service_type='compute',
)
urls = sc.get_urls(service_type='image', service_name='glance', urls = sc.get_urls(
interface='public') service_type='image', service_name='glance', interface='public'
)
self.assertIn('http://glance.north.host/glanceapi/public', urls) self.assertIn('http://glance.north.host/glanceapi/public', urls)
self.assertIn('http://glance.south.host/glanceapi/public', urls) self.assertIn('http://glance.south.host/glanceapi/public', urls)
urls = sc.get_urls(service_type='image', urls = sc.get_urls(
service_name='Servers', service_type='image', service_name='Servers', interface='public'
interface='public') )
self.assertEqual(0, len(urls)) self.assertEqual(0, len(urls))
@ -304,81 +368,102 @@ class ServiceCatalogTest(utils.TestCase):
s = f.add_service('volume') s = f.add_service('volume')
s.add_standard_endpoints( s.add_standard_endpoints(
public='http://public.com:8776/v1/%s' % tenant, public=f'http://public.com:8776/v1/{tenant}',
internal='http://internal:8776/v1/%s' % tenant, internal=f'http://internal:8776/v1/{tenant}',
admin='http://admin:8776/v1/%s' % tenant, admin=f'http://admin:8776/v1/{tenant}',
region=region) region=region,
)
s = f.add_service('image') s = f.add_service('image')
s.add_standard_endpoints(public='http://public.com:9292/v1', s.add_standard_endpoints(
public='http://public.com:9292/v1',
internal='http://internal:9292/v1', internal='http://internal:9292/v1',
admin='http://admin:9292/v1', admin='http://admin:9292/v1',
region=region) region=region,
)
s = f.add_service('compute') s = f.add_service('compute')
s.add_standard_endpoints( s.add_standard_endpoints(
public='http://public.com:8774/v2/%s' % tenant, public=f'http://public.com:8774/v2/{tenant}',
internal='http://internal:8774/v2/%s' % tenant, internal=f'http://internal:8774/v2/{tenant}',
admin='http://admin:8774/v2/%s' % tenant, admin=f'http://admin:8774/v2/{tenant}',
region=region) region=region,
)
s = f.add_service('ec2') s = f.add_service('ec2')
s.add_standard_endpoints( s.add_standard_endpoints(
public='http://public.com:8773/services/Cloud', public='http://public.com:8773/services/Cloud',
internal='http://internal:8773/services/Cloud', internal='http://internal:8773/services/Cloud',
admin='http://admin:8773/services/Admin', admin='http://admin:8773/services/Admin',
region=region) region=region,
)
s = f.add_service('identity') s = f.add_service('identity')
s.add_standard_endpoints(public='http://public.com:5000/v3', s.add_standard_endpoints(
public='http://public.com:5000/v3',
internal='http://internal:5000/v3', internal='http://internal:5000/v3',
admin='http://admin:35357/v3', admin='http://admin:35357/v3',
region=region) region=region,
)
pr_auth_ref = access.create(body=f) pr_auth_ref = access.create(body=f)
pr_sc = pr_auth_ref.service_catalog pr_sc = pr_auth_ref.service_catalog
# this will work because there are no service names on that token # this will work because there are no service names on that token
url_ref = 'http://public.com:8774/v2/225da22d3ce34b15877ea70b2a575f58' url_ref = 'http://public.com:8774/v2/225da22d3ce34b15877ea70b2a575f58'
url = pr_sc.url_for(service_type='compute', service_name='NotExist', url = pr_sc.url_for(
interface='public') service_type='compute', service_name='NotExist', interface='public'
)
self.assertEqual(url_ref, url) self.assertEqual(url_ref, url)
ab_auth_ref = access.create(body=self.AUTH_RESPONSE_BODY) ab_auth_ref = access.create(body=self.AUTH_RESPONSE_BODY)
ab_sc = ab_auth_ref.service_catalog ab_sc = ab_auth_ref.service_catalog
# this won't work because there is a name and it's not this one # this won't work because there is a name and it's not this one
self.assertRaises(exceptions.EndpointNotFound, ab_sc.url_for, self.assertRaises(
service_type='compute', service_name='NotExist', exceptions.EndpointNotFound,
interface='public') ab_sc.url_for,
service_type='compute',
service_name='NotExist',
interface='public',
)
class ServiceCatalogV3Test(ServiceCatalogTest): class ServiceCatalogV3Test(ServiceCatalogTest):
def test_building_a_service_catalog(self): def test_building_a_service_catalog(self):
sc = access.create(auth_token=uuid.uuid4().hex, sc = access.create(
body=self.AUTH_RESPONSE_BODY).service_catalog auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
).service_catalog
self.assertEqual(sc.url_for(service_type='compute'), self.assertEqual(
'https://compute.north.host/novapi/public') sc.url_for(service_type='compute'),
self.assertEqual(sc.url_for(service_type='compute', 'https://compute.north.host/novapi/public',
interface='internal'), )
'https://compute.north.host/novapi/internal') self.assertEqual(
sc.url_for(service_type='compute', interface='internal'),
'https://compute.north.host/novapi/internal',
)
self.assertRaises(exceptions.EndpointNotFound, self.assertRaises(
exceptions.EndpointNotFound,
sc.url_for, sc.url_for,
region_name='South', region_name='South',
service_type='compute') service_type='compute',
)
def test_service_catalog_endpoints(self): def test_service_catalog_endpoints(self):
sc = access.create(auth_token=uuid.uuid4().hex, sc = access.create(
body=self.AUTH_RESPONSE_BODY).service_catalog auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
).service_catalog
public_ep = sc.get_endpoints(service_type='compute', public_ep = sc.get_endpoints(
interface='public') service_type='compute', interface='public'
)
self.assertEqual(public_ep['compute'][0]['region_id'], 'North') self.assertEqual(public_ep['compute'][0]['region_id'], 'North')
self.assertEqual(public_ep['compute'][0]['url'], self.assertEqual(
'https://compute.north.host/novapi/public') public_ep['compute'][0]['url'],
'https://compute.north.host/novapi/public',
)
def test_service_catalog_multiple_service_types(self): def test_service_catalog_multiple_service_types(self):
token = fixture.V3Token() token = fixture.V3Token()
@ -386,23 +471,26 @@ class ServiceCatalogV3Test(ServiceCatalogTest):
for i in range(3): for i in range(3):
s = token.add_service('compute') s = token.add_service('compute')
s.add_standard_endpoints(public='public-%d' % i, s.add_standard_endpoints(
public='public-%d' % i,
admin='admin-%d' % i, admin='admin-%d' % i,
internal='internal-%d' % i, internal='internal-%d' % i,
region='region-%d' % i) region='region-%d' % i,
)
auth_ref = access.create(resp=None, body=token) auth_ref = access.create(resp=None, body=token)
urls = auth_ref.service_catalog.get_urls(service_type='compute', urls = auth_ref.service_catalog.get_urls(
interface='public') service_type='compute', interface='public'
)
self.assertEqual(set(['public-0', 'public-1', 'public-2']), set(urls)) self.assertEqual({'public-0', 'public-1', 'public-2'}, set(urls))
urls = auth_ref.service_catalog.get_urls(service_type='compute', urls = auth_ref.service_catalog.get_urls(
interface='public', service_type='compute', interface='public', region_name='region-1'
region_name='region-1') )
self.assertEqual(('public-1', ), urls) self.assertEqual(('public-1',), urls)
def test_service_catalog_endpoint_id(self): def test_service_catalog_endpoint_id(self):
token = fixture.V3Token() token = fixture.V3Token()
@ -419,37 +507,42 @@ class ServiceCatalogV3Test(ServiceCatalogTest):
auth_ref = access.create(body=token) auth_ref = access.create(body=token)
# initially assert that we get back all our urls for a simple filter # initially assert that we get back all our urls for a simple filter
urls = auth_ref.service_catalog.get_urls(service_type='compute', urls = auth_ref.service_catalog.get_urls(
interface='public') service_type='compute', interface='public'
)
self.assertEqual(2, len(urls)) self.assertEqual(2, len(urls))
# with bad endpoint_id nothing should be found # with bad endpoint_id nothing should be found
urls = auth_ref.service_catalog.get_urls(service_type='compute', urls = auth_ref.service_catalog.get_urls(
service_type='compute',
endpoint_id=uuid.uuid4().hex, endpoint_id=uuid.uuid4().hex,
interface='public') interface='public',
)
self.assertEqual(0, len(urls)) self.assertEqual(0, len(urls))
# with service_id we get back both public endpoints # with service_id we get back both public endpoints
urls = auth_ref.service_catalog.get_urls(service_type='compute', urls = auth_ref.service_catalog.get_urls(
service_id=service_id, service_type='compute', service_id=service_id, interface='public'
interface='public') )
self.assertEqual(2, len(urls)) self.assertEqual(2, len(urls))
# with service_id and endpoint_id we get back the url we want # with service_id and endpoint_id we get back the url we want
urls = auth_ref.service_catalog.get_urls(service_type='compute', urls = auth_ref.service_catalog.get_urls(
service_type='compute',
service_id=service_id, service_id=service_id,
endpoint_id=endpoint_id, endpoint_id=endpoint_id,
interface='public') interface='public',
)
self.assertEqual((public_url, ), urls) self.assertEqual((public_url,), urls)
# with service_id and endpoint_id we get back the url we want # with service_id and endpoint_id we get back the url we want
urls = auth_ref.service_catalog.get_urls(service_type='compute', urls = auth_ref.service_catalog.get_urls(
endpoint_id=endpoint_id, service_type='compute', endpoint_id=endpoint_id, interface='public'
interface='public') )
self.assertEqual((public_url, ), urls) self.assertEqual((public_url,), urls)
def test_service_catalog_without_service_type(self): def test_service_catalog_without_service_type(self):
token = fixture.V3Token() token = fixture.V3Token()

View File

@ -23,7 +23,8 @@ def project_scoped_token():
project_id='225da22d3ce34b15877ea70b2a575f58', project_id='225da22d3ce34b15877ea70b2a575f58',
project_name='exampleproject', project_name='exampleproject',
project_domain_id='4e6893b7ba0b4006840c3845660b86ed', project_domain_id='4e6893b7ba0b4006840c3845660b86ed',
project_domain_name='exampledomain') project_domain_name='exampledomain',
)
fixture.add_role(id='76e72a', name='admin') fixture.add_role(id='76e72a', name='admin')
fixture.add_role(id='f4f392', name='member') fixture.add_role(id='f4f392', name='member')
@ -33,36 +34,43 @@ def project_scoped_token():
service = fixture.add_service('volume') service = fixture.add_service('volume')
service.add_standard_endpoints( service.add_standard_endpoints(
public='http://public.com:8776/v1/%s' % tenant, public=f'http://public.com:8776/v1/{tenant}',
internal='http://internal:8776/v1/%s' % tenant, internal=f'http://internal:8776/v1/{tenant}',
admin='http://admin:8776/v1/%s' % tenant, admin=f'http://admin:8776/v1/{tenant}',
region=region) region=region,
)
service = fixture.add_service('image') service = fixture.add_service('image')
service.add_standard_endpoints(public='http://public.com:9292/v1', service.add_standard_endpoints(
public='http://public.com:9292/v1',
internal='http://internal:9292/v1', internal='http://internal:9292/v1',
admin='http://admin:9292/v1', admin='http://admin:9292/v1',
region=region) region=region,
)
service = fixture.add_service('compute') service = fixture.add_service('compute')
service.add_standard_endpoints( service.add_standard_endpoints(
public='http://public.com:8774/v2/%s' % tenant, public=f'http://public.com:8774/v2/{tenant}',
internal='http://internal:8774/v2/%s' % tenant, internal=f'http://internal:8774/v2/{tenant}',
admin='http://admin:8774/v2/%s' % tenant, admin=f'http://admin:8774/v2/{tenant}',
region=region) region=region,
)
service = fixture.add_service('ec2') service = fixture.add_service('ec2')
service.add_standard_endpoints( service.add_standard_endpoints(
public='http://public.com:8773/services/Cloud', public='http://public.com:8773/services/Cloud',
internal='http://internal:8773/services/Cloud', internal='http://internal:8773/services/Cloud',
admin='http://admin:8773/services/Admin', admin='http://admin:8773/services/Admin',
region=region) region=region,
)
service = fixture.add_service('identity') service = fixture.add_service('identity')
service.add_standard_endpoints(public='http://public.com:5000/v3', service.add_standard_endpoints(
public='http://public.com:5000/v3',
internal='http://internal:5000/v3', internal='http://internal:5000/v3',
admin='http://admin:35357/v3', admin='http://admin:35357/v3',
region=region) region=region,
)
return fixture return fixture
@ -75,48 +83,56 @@ def domain_scoped_token():
user_domain_name='exampledomain', user_domain_name='exampledomain',
expires='2010-11-01T03:32:15-05:00', expires='2010-11-01T03:32:15-05:00',
domain_id='8e9283b7ba0b1038840c3842058b86ab', domain_id='8e9283b7ba0b1038840c3842058b86ab',
domain_name='anotherdomain') domain_name='anotherdomain',
)
fixture.add_role(id='76e72a', name='admin') fixture.add_role(id='76e72a', name='admin')
fixture.add_role(id='f4f392', name='member') fixture.add_role(id='f4f392', name='member')
region = 'RegionOne' region = 'RegionOne'
service = fixture.add_service('volume') service = fixture.add_service('volume')
service.add_standard_endpoints(public='http://public.com:8776/v1/None', service.add_standard_endpoints(
public='http://public.com:8776/v1/None',
internal='http://internal.com:8776/v1/None', internal='http://internal.com:8776/v1/None',
admin='http://admin.com:8776/v1/None', admin='http://admin.com:8776/v1/None',
region=region) region=region,
)
service = fixture.add_service('image') service = fixture.add_service('image')
service.add_standard_endpoints(public='http://public.com:9292/v1', service.add_standard_endpoints(
public='http://public.com:9292/v1',
internal='http://internal:9292/v1', internal='http://internal:9292/v1',
admin='http://admin:9292/v1', admin='http://admin:9292/v1',
region=region) region=region,
)
service = fixture.add_service('compute') service = fixture.add_service('compute')
service.add_standard_endpoints(public='http://public.com:8774/v1.1/None', service.add_standard_endpoints(
public='http://public.com:8774/v1.1/None',
internal='http://internal:8774/v1.1/None', internal='http://internal:8774/v1.1/None',
admin='http://admin:8774/v1.1/None', admin='http://admin:8774/v1.1/None',
region=region) region=region,
)
service = fixture.add_service('ec2') service = fixture.add_service('ec2')
service.add_standard_endpoints( service.add_standard_endpoints(
public='http://public.com:8773/services/Cloud', public='http://public.com:8773/services/Cloud',
internal='http://internal:8773/services/Cloud', internal='http://internal:8773/services/Cloud',
admin='http://admin:8773/services/Admin', admin='http://admin:8773/services/Admin',
region=region) region=region,
)
service = fixture.add_service('identity') service = fixture.add_service('identity')
service.add_standard_endpoints(public='http://public.com:5000/v3', service.add_standard_endpoints(
public='http://public.com:5000/v3',
internal='http://internal:5000/v3', internal='http://internal:5000/v3',
admin='http://admin:35357/v3', admin='http://admin:35357/v3',
region=region) region=region,
)
return fixture return fixture
AUTH_SUBJECT_TOKEN = '3e2813b7ba0b4006840c3825860b86ed' AUTH_SUBJECT_TOKEN = '3e2813b7ba0b4006840c3825860b86ed'
AUTH_RESPONSE_HEADERS = { AUTH_RESPONSE_HEADERS = {'X-Subject-Token': AUTH_SUBJECT_TOKEN}
'X-Subject-Token': AUTH_SUBJECT_TOKEN,
}

View File

@ -15,7 +15,6 @@ from keystoneauth1.tests.unit import utils
class ExceptionTests(utils.TestCase): class ExceptionTests(utils.TestCase):
def test_clientexception_with_message(self): def test_clientexception_with_message(self):
test_message = 'Unittest exception message.' test_message = 'Unittest exception message.'
exc = exceptions.ClientException(message=test_message) exc = exceptions.ClientException(message=test_message)
@ -23,10 +22,8 @@ class ExceptionTests(utils.TestCase):
def test_clientexception_with_no_message(self): def test_clientexception_with_no_message(self):
exc = exceptions.ClientException() exc = exceptions.ClientException()
self.assertEqual(exceptions.ClientException.__name__, self.assertEqual(exceptions.ClientException.__name__, exc.message)
exc.message)
def test_using_default_message(self): def test_using_default_message(self):
exc = exceptions.AuthorizationFailure() exc = exceptions.AuthorizationFailure()
self.assertEqual(exceptions.AuthorizationFailure.message, self.assertEqual(exceptions.AuthorizationFailure.message, exc.message)
exc.message)

View File

@ -17,8 +17,7 @@ from keystoneauth1.tests.unit.extras.kerberos import utils
from keystoneauth1.tests.unit import utils as test_utils from keystoneauth1.tests.unit import utils as test_utils
REQUEST = {'auth': {'identity': {'methods': ['kerberos'], REQUEST = {'auth': {'identity': {'methods': ['kerberos'], 'kerberos': {}}}}
'kerberos': {}}}}
class TestCase(test_utils.TestCase): class TestCase(test_utils.TestCase):
@ -27,7 +26,7 @@ class TestCase(test_utils.TestCase):
TEST_V3_URL = test_utils.TestCase.TEST_ROOT_URL + 'v3' TEST_V3_URL = test_utils.TestCase.TEST_ROOT_URL + 'v3'
def setUp(self): def setUp(self):
super(TestCase, self).setUp() super().setUp()
km = utils.KerberosMock(self.requests_mock) km = utils.KerberosMock(self.requests_mock)
self.kerberos_mock = self.useFixture(km) self.kerberos_mock = self.useFixture(km)

View File

@ -16,12 +16,14 @@ from keystoneauth1.tests.unit import utils as test_utils
class FedKerbLoadingTests(test_utils.TestCase): class FedKerbLoadingTests(test_utils.TestCase):
def test_options(self): def test_options(self):
opts = [o.name for o in opts = [
loading.get_plugin_loader('v3fedkerb').get_options()] o.name
for o in loading.get_plugin_loader('v3fedkerb').get_options()
]
allowed_opts = ['system-scope', allowed_opts = [
'system-scope',
'domain-id', 'domain-id',
'domain-name', 'domain-name',
'identity-provider', 'identity-provider',
@ -45,6 +47,6 @@ class FedKerbLoadingTests(test_utils.TestCase):
self.assertRaises(exceptions.MissingRequiredOptions, self.create) self.assertRaises(exceptions.MissingRequiredOptions, self.create)
def test_load(self): def test_load(self):
self.create(auth_url='auth_url', self.create(
identity_provider='idp', auth_url='auth_url', identity_provider='idp', protocol='protocol'
protocol='protocol') )

View File

@ -15,12 +15,14 @@ from keystoneauth1.tests.unit import utils as test_utils
class KerberosLoadingTests(test_utils.TestCase): class KerberosLoadingTests(test_utils.TestCase):
def test_options(self): def test_options(self):
opts = [o.name for o in opts = [
loading.get_plugin_loader('v3kerberos').get_options()] o.name
for o in loading.get_plugin_loader('v3kerberos').get_options()
]
allowed_opts = ['system-scope', allowed_opts = [
'system-scope',
'domain-id', 'domain-id',
'domain-name', 'domain-name',
'project-id', 'project-id',

View File

@ -19,12 +19,11 @@ from keystoneauth1.tests.unit.extras.kerberos import base
class TestMappedAuth(base.TestCase): class TestMappedAuth(base.TestCase):
def setUp(self): def setUp(self):
if kerberos.requests_kerberos is None: if kerberos.requests_kerberos is None:
self.skipTest("Kerberos support isn't available.") self.skipTest("Kerberos support isn't available.")
super(TestMappedAuth, self).setUp() super().setUp()
self.protocol = uuid.uuid4().hex self.protocol = uuid.uuid4().hex
self.identity_provider = uuid.uuid4().hex self.identity_provider = uuid.uuid4().hex
@ -32,18 +31,18 @@ class TestMappedAuth(base.TestCase):
@property @property
def token_url(self): def token_url(self):
fmt = '%s/OS-FEDERATION/identity_providers/%s/protocols/%s/auth' fmt = '%s/OS-FEDERATION/identity_providers/%s/protocols/%s/auth'
return fmt % ( return fmt % (self.TEST_V3_URL, self.identity_provider, self.protocol)
self.TEST_V3_URL,
self.identity_provider,
self.protocol)
def test_unscoped_mapped_auth(self): def test_unscoped_mapped_auth(self):
token_id, _ = self.kerberos_mock.mock_auth_success( token_id, _ = self.kerberos_mock.mock_auth_success(
url=self.token_url, method='GET') url=self.token_url, method='GET'
)
plugin = kerberos.MappedKerberos( plugin = kerberos.MappedKerberos(
auth_url=self.TEST_V3_URL, protocol=self.protocol, auth_url=self.TEST_V3_URL,
identity_provider=self.identity_provider) protocol=self.protocol,
identity_provider=self.identity_provider,
)
sess = session.Session() sess = session.Session()
tok = plugin.get_token(sess) tok = plugin.get_token(sess)
@ -51,23 +50,27 @@ class TestMappedAuth(base.TestCase):
self.assertEqual(token_id, tok) self.assertEqual(token_id, tok)
def test_project_scoped_mapped_auth(self): def test_project_scoped_mapped_auth(self):
self.kerberos_mock.mock_auth_success(url=self.token_url, self.kerberos_mock.mock_auth_success(url=self.token_url, method='GET')
method='GET')
scoped_id = uuid.uuid4().hex scoped_id = uuid.uuid4().hex
scoped_body = ks_fixture.V3Token() scoped_body = ks_fixture.V3Token()
scoped_body.set_project_scope() scoped_body.set_project_scope()
self.requests_mock.post( self.requests_mock.post(
'%s/auth/tokens' % self.TEST_V3_URL, f'{self.TEST_V3_URL}/auth/tokens',
json=scoped_body, json=scoped_body,
headers={'X-Subject-Token': scoped_id, headers={
'Content-Type': 'application/json'}) 'X-Subject-Token': scoped_id,
'Content-Type': 'application/json',
},
)
plugin = kerberos.MappedKerberos( plugin = kerberos.MappedKerberos(
auth_url=self.TEST_V3_URL, protocol=self.protocol, auth_url=self.TEST_V3_URL,
protocol=self.protocol,
identity_provider=self.identity_provider, identity_provider=self.identity_provider,
project_id=scoped_body.project_id) project_id=scoped_body.project_id,
)
sess = session.Session() sess = session.Session()
tok = plugin.get_token(sess) tok = plugin.get_token(sess)
@ -77,24 +80,28 @@ class TestMappedAuth(base.TestCase):
self.assertEqual(scoped_body.project_id, proj) self.assertEqual(scoped_body.project_id, proj)
def test_authenticate_with_mutual_authentication_required(self): def test_authenticate_with_mutual_authentication_required(self):
self.kerberos_mock.mock_auth_success(url=self.token_url, self.kerberos_mock.mock_auth_success(url=self.token_url, method='GET')
method='GET')
scoped_id = uuid.uuid4().hex scoped_id = uuid.uuid4().hex
scoped_body = ks_fixture.V3Token() scoped_body = ks_fixture.V3Token()
scoped_body.set_project_scope() scoped_body.set_project_scope()
self.requests_mock.post( self.requests_mock.post(
'%s/auth/tokens' % self.TEST_V3_URL, f'{self.TEST_V3_URL}/auth/tokens',
json=scoped_body, json=scoped_body,
headers={'X-Subject-Token': scoped_id, headers={
'Content-Type': 'application/json'}) 'X-Subject-Token': scoped_id,
'Content-Type': 'application/json',
},
)
plugin = kerberos.MappedKerberos( plugin = kerberos.MappedKerberos(
auth_url=self.TEST_V3_URL, protocol=self.protocol, auth_url=self.TEST_V3_URL,
protocol=self.protocol,
identity_provider=self.identity_provider, identity_provider=self.identity_provider,
project_id=scoped_body.project_id, project_id=scoped_body.project_id,
mutual_auth='required') mutual_auth='required',
)
sess = session.Session() sess = session.Session()
tok = plugin.get_token(sess) tok = plugin.get_token(sess)
@ -105,24 +112,28 @@ class TestMappedAuth(base.TestCase):
self.assertEqual(self.kerberos_mock.called_auth_server, True) self.assertEqual(self.kerberos_mock.called_auth_server, True)
def test_authenticate_with_mutual_authentication_disabled(self): def test_authenticate_with_mutual_authentication_disabled(self):
self.kerberos_mock.mock_auth_success(url=self.token_url, self.kerberos_mock.mock_auth_success(url=self.token_url, method='GET')
method='GET')
scoped_id = uuid.uuid4().hex scoped_id = uuid.uuid4().hex
scoped_body = ks_fixture.V3Token() scoped_body = ks_fixture.V3Token()
scoped_body.set_project_scope() scoped_body.set_project_scope()
self.requests_mock.post( self.requests_mock.post(
'%s/auth/tokens' % self.TEST_V3_URL, f'{self.TEST_V3_URL}/auth/tokens',
json=scoped_body, json=scoped_body,
headers={'X-Subject-Token': scoped_id, headers={
'Content-Type': 'application/json'}) 'X-Subject-Token': scoped_id,
'Content-Type': 'application/json',
},
)
plugin = kerberos.MappedKerberos( plugin = kerberos.MappedKerberos(
auth_url=self.TEST_V3_URL, protocol=self.protocol, auth_url=self.TEST_V3_URL,
protocol=self.protocol,
identity_provider=self.identity_provider, identity_provider=self.identity_provider,
project_id=scoped_body.project_id, project_id=scoped_body.project_id,
mutual_auth='disabled') mutual_auth='disabled',
)
sess = session.Session() sess = session.Session()
tok = plugin.get_token(sess) tok = plugin.get_token(sess)

View File

@ -16,12 +16,11 @@ from keystoneauth1.tests.unit.extras.kerberos import base
class TestKerberosAuth(base.TestCase): class TestKerberosAuth(base.TestCase):
def setUp(self): def setUp(self):
if kerberos.requests_kerberos is None: if kerberos.requests_kerberos is None:
self.skipTest("Kerberos support isn't available.") self.skipTest("Kerberos support isn't available.")
super(TestKerberosAuth, self).setUp() super().setUp()
def test_authenticate_with_kerberos_domain_scoped(self): def test_authenticate_with_kerberos_domain_scoped(self):
token_id, token_body = self.kerberos_mock.mock_auth_success() token_id, token_body = self.kerberos_mock.mock_auth_success()
@ -33,22 +32,25 @@ class TestKerberosAuth(base.TestCase):
self.assertRequestBody() self.assertRequestBody()
self.assertEqual( self.assertEqual(
self.kerberos_mock.challenge_header, self.kerberos_mock.challenge_header,
self.requests_mock.last_request.headers['Authorization']) self.requests_mock.last_request.headers['Authorization'],
)
self.assertEqual(token_id, a.auth_ref.auth_token) self.assertEqual(token_id, a.auth_ref.auth_token)
self.assertEqual(token_id, token) self.assertEqual(token_id, token)
def test_authenticate_with_kerberos_mutual_authentication_required(self): def test_authenticate_with_kerberos_mutual_authentication_required(self):
token_id, token_body = self.kerberos_mock.mock_auth_success() token_id, token_body = self.kerberos_mock.mock_auth_success()
a = kerberos.Kerberos(self.TEST_ROOT_URL + 'v3', a = kerberos.Kerberos(
mutual_auth='required') self.TEST_ROOT_URL + 'v3', mutual_auth='required'
)
s = session.Session(a) s = session.Session(a)
token = a.get_token(s) token = a.get_token(s)
self.assertRequestBody() self.assertRequestBody()
self.assertEqual( self.assertEqual(
self.kerberos_mock.challenge_header, self.kerberos_mock.challenge_header,
self.requests_mock.last_request.headers['Authorization']) self.requests_mock.last_request.headers['Authorization'],
)
self.assertEqual(token_id, a.auth_ref.auth_token) self.assertEqual(token_id, a.auth_ref.auth_token)
self.assertEqual(token_id, token) self.assertEqual(token_id, token)
self.assertEqual(self.kerberos_mock.called_auth_server, True) self.assertEqual(self.kerberos_mock.called_auth_server, True)
@ -56,15 +58,17 @@ class TestKerberosAuth(base.TestCase):
def test_authenticate_with_kerberos_mutual_authentication_disabled(self): def test_authenticate_with_kerberos_mutual_authentication_disabled(self):
token_id, token_body = self.kerberos_mock.mock_auth_success() token_id, token_body = self.kerberos_mock.mock_auth_success()
a = kerberos.Kerberos(self.TEST_ROOT_URL + 'v3', a = kerberos.Kerberos(
mutual_auth='disabled') self.TEST_ROOT_URL + 'v3', mutual_auth='disabled'
)
s = session.Session(a) s = session.Session(a)
token = a.get_token(s) token = a.get_token(s)
self.assertRequestBody() self.assertRequestBody()
self.assertEqual( self.assertEqual(
self.kerberos_mock.challenge_header, self.kerberos_mock.challenge_header,
self.requests_mock.last_request.headers['Authorization']) self.requests_mock.last_request.headers['Authorization'],
)
self.assertEqual(token_id, a.auth_ref.auth_token) self.assertEqual(token_id, a.auth_ref.auth_token)
self.assertEqual(token_id, token) self.assertEqual(token_id, token)
self.assertEqual(self.kerberos_mock.called_auth_server, False) self.assertEqual(self.kerberos_mock.called_auth_server, False)

View File

@ -13,6 +13,7 @@
import uuid import uuid
import fixtures import fixtures
try: try:
# requests_kerberos won't be available on py3, it doesn't work with py3. # requests_kerberos won't be available on py3, it doesn't work with py3.
import requests_kerberos import requests_kerberos
@ -24,29 +25,32 @@ from keystoneauth1.tests.unit import utils as test_utils
class KerberosMock(fixtures.Fixture): class KerberosMock(fixtures.Fixture):
def __init__(self, requests_mock): def __init__(self, requests_mock):
super(KerberosMock, self).__init__() super().__init__()
self.challenge_header = 'Negotiate %s' % uuid.uuid4().hex self.challenge_header = f'Negotiate {uuid.uuid4().hex}'
self.pass_header = 'Negotiate %s' % uuid.uuid4().hex self.pass_header = f'Negotiate {uuid.uuid4().hex}'
self.requests_mock = requests_mock self.requests_mock = requests_mock
def setUp(self): def setUp(self):
super(KerberosMock, self).setUp() super().setUp()
if requests_kerberos is None: if requests_kerberos is None:
return return
m = fixtures.MockPatchObject(requests_kerberos.HTTPKerberosAuth, m = fixtures.MockPatchObject(
requests_kerberos.HTTPKerberosAuth,
'generate_request_header', 'generate_request_header',
self._generate_request_header) self._generate_request_header,
)
self.header_fixture = self.useFixture(m) self.header_fixture = self.useFixture(m)
m = fixtures.MockPatchObject(requests_kerberos.HTTPKerberosAuth, m = fixtures.MockPatchObject(
requests_kerberos.HTTPKerberosAuth,
'authenticate_server', 'authenticate_server',
self._authenticate_server) self._authenticate_server,
)
self.authenticate_fixture = self.useFixture(m) self.authenticate_fixture = self.useFixture(m)
@ -62,7 +66,8 @@ class KerberosMock(fixtures.Fixture):
token_id=None, token_id=None,
token_body=None, token_body=None,
method='POST', method='POST',
url=test_utils.TestCase.TEST_ROOT_URL + 'v3/auth/tokens'): url=test_utils.TestCase.TEST_ROOT_URL + 'v3/auth/tokens',
):
if not token_id: if not token_id:
token_id = uuid.uuid4().hex token_id = uuid.uuid4().hex
if not token_body: if not token_body:
@ -70,17 +75,25 @@ class KerberosMock(fixtures.Fixture):
self.called_auth_server = False self.called_auth_server = False
response_list = [{'text': 'Fail', response_list = [
{
'text': 'Fail',
'status_code': 401, 'status_code': 401,
'headers': {'WWW-Authenticate': 'Negotiate'}}, 'headers': {'WWW-Authenticate': 'Negotiate'},
{'headers': {'X-Subject-Token': token_id, },
{
'headers': {
'X-Subject-Token': token_id,
'Content-Type': 'application/json', 'Content-Type': 'application/json',
'WWW-Authenticate': self.pass_header}, 'WWW-Authenticate': self.pass_header,
},
'status_code': 200, 'status_code': 200,
'json': token_body}] 'json': token_body,
},
]
self.requests_mock.register_uri(method, self.requests_mock.register_uri(
url, method, url, response_list=response_list
response_list=response_list) )
return token_id, token_body return token_id, token_body

View File

@ -22,17 +22,20 @@ from keystoneauth1.tests.unit import utils as test_utils
class OAuth1AuthTests(test_utils.TestCase): class OAuth1AuthTests(test_utils.TestCase):
TEST_ROOT_URL = 'http://127.0.0.1:5000/' TEST_ROOT_URL = 'http://127.0.0.1:5000/'
TEST_URL = '%s%s' % (TEST_ROOT_URL, 'v3') TEST_URL = '{}{}'.format(TEST_ROOT_URL, 'v3')
TEST_TOKEN = uuid.uuid4().hex TEST_TOKEN = uuid.uuid4().hex
def stub_auth(self, subject_token=None, **kwargs): def stub_auth(self, subject_token=None, **kwargs):
if not subject_token: if not subject_token:
subject_token = self.TEST_TOKEN subject_token = self.TEST_TOKEN
self.stub_url('POST', ['auth', 'tokens'], self.stub_url(
headers={'X-Subject-Token': subject_token}, **kwargs) 'POST',
['auth', 'tokens'],
headers={'X-Subject-Token': subject_token},
**kwargs,
)
def _validate_oauth_headers(self, auth_header, oauth_client): def _validate_oauth_headers(self, auth_header, oauth_client):
"""Validate data in the headers. """Validate data in the headers.
@ -42,22 +45,27 @@ class OAuth1AuthTests(test_utils.TestCase):
""" """
self.assertThat(auth_header, matchers.StartsWith('OAuth ')) self.assertThat(auth_header, matchers.StartsWith('OAuth '))
parameters = dict( parameters = dict(
oauth1.rfc5849.utils.parse_authorization_header(auth_header)) oauth1.rfc5849.utils.parse_authorization_header(auth_header)
)
self.assertEqual('HMAC-SHA1', parameters['oauth_signature_method']) self.assertEqual('HMAC-SHA1', parameters['oauth_signature_method'])
self.assertEqual('1.0', parameters['oauth_version']) self.assertEqual('1.0', parameters['oauth_version'])
self.assertIsInstance(parameters['oauth_nonce'], str) self.assertIsInstance(parameters['oauth_nonce'], str)
self.assertEqual(oauth_client.client_key, self.assertEqual(
parameters['oauth_consumer_key']) oauth_client.client_key, parameters['oauth_consumer_key']
)
if oauth_client.resource_owner_key: if oauth_client.resource_owner_key:
self.assertEqual(oauth_client.resource_owner_key, self.assertEqual(
parameters['oauth_token'],) oauth_client.resource_owner_key, parameters['oauth_token']
)
if oauth_client.verifier: if oauth_client.verifier:
self.assertEqual(oauth_client.verifier, self.assertEqual(
parameters['oauth_verifier']) oauth_client.verifier, parameters['oauth_verifier']
)
if oauth_client.callback_uri: if oauth_client.callback_uri:
self.assertEqual(oauth_client.callback_uri, self.assertEqual(
parameters['oauth_callback']) oauth_client.callback_uri, parameters['oauth_callback']
)
return parameters return parameters
def test_oauth_authenticate_success(self): def test_oauth_authenticate_success(self):
@ -66,18 +74,22 @@ class OAuth1AuthTests(test_utils.TestCase):
access_key = uuid.uuid4().hex access_key = uuid.uuid4().hex
access_secret = uuid.uuid4().hex access_secret = uuid.uuid4().hex
oauth_token = fixture.V3Token(methods=['oauth1'], oauth_token = fixture.V3Token(
methods=['oauth1'],
oauth_consumer_id=consumer_key, oauth_consumer_id=consumer_key,
oauth_access_token_id=access_key) oauth_access_token_id=access_key,
)
oauth_token.set_project_scope() oauth_token.set_project_scope()
self.stub_auth(json=oauth_token) self.stub_auth(json=oauth_token)
a = ksa_oauth1.V3OAuth1(self.TEST_URL, a = ksa_oauth1.V3OAuth1(
self.TEST_URL,
consumer_key=consumer_key, consumer_key=consumer_key,
consumer_secret=consumer_secret, consumer_secret=consumer_secret,
access_key=access_key, access_key=access_key,
access_secret=access_secret) access_secret=access_secret,
)
s = session.Session(auth=a) s = session.Session(auth=a)
t = s.get_token() t = s.get_token()
@ -85,32 +97,32 @@ class OAuth1AuthTests(test_utils.TestCase):
self.assertEqual(self.TEST_TOKEN, t) self.assertEqual(self.TEST_TOKEN, t)
OAUTH_REQUEST_BODY = { OAUTH_REQUEST_BODY = {
"auth": { "auth": {"identity": {"methods": ["oauth1"], "oauth1": {}}}
"identity": {
"methods": ["oauth1"],
"oauth1": {}
}
}
} }
self.assertRequestBodyIs(json=OAUTH_REQUEST_BODY) self.assertRequestBodyIs(json=OAUTH_REQUEST_BODY)
# Assert that the headers have the same oauthlib data # Assert that the headers have the same oauthlib data
req_headers = self.requests_mock.last_request.headers req_headers = self.requests_mock.last_request.headers
oauth_client = oauth1.Client(consumer_key, oauth_client = oauth1.Client(
consumer_key,
client_secret=consumer_secret, client_secret=consumer_secret,
resource_owner_key=access_key, resource_owner_key=access_key,
resource_owner_secret=access_secret, resource_owner_secret=access_secret,
signature_method=oauth1.SIGNATURE_HMAC) signature_method=oauth1.SIGNATURE_HMAC,
self._validate_oauth_headers(req_headers['Authorization'], )
oauth_client) self._validate_oauth_headers(
req_headers['Authorization'], oauth_client
)
def test_warning_dual_scope(self): def test_warning_dual_scope(self):
ksa_oauth1.V3OAuth1(self.TEST_URL, ksa_oauth1.V3OAuth1(
self.TEST_URL,
consumer_key=uuid.uuid4().hex, consumer_key=uuid.uuid4().hex,
consumer_secret=uuid.uuid4().hex, consumer_secret=uuid.uuid4().hex,
access_key=uuid.uuid4().hex, access_key=uuid.uuid4().hex,
access_secret=uuid.uuid4().hex, access_secret=uuid.uuid4().hex,
project_id=uuid.uuid4().hex) project_id=uuid.uuid4().hex,
)
self.assertIn('ignored by the identity server', self.logger.output) self.assertIn('ignored by the identity server', self.logger.output)

View File

@ -17,9 +17,8 @@ from keystoneauth1.tests.unit import utils as test_utils
class OAuth1LoadingTests(test_utils.TestCase): class OAuth1LoadingTests(test_utils.TestCase):
def setUp(self): def setUp(self):
super(OAuth1LoadingTests, self).setUp() super().setUp()
self.auth_url = uuid.uuid4().hex self.auth_url = uuid.uuid4().hex
def create(self, **kwargs): def create(self, **kwargs):
@ -33,10 +32,12 @@ class OAuth1LoadingTests(test_utils.TestCase):
consumer_key = uuid.uuid4().hex consumer_key = uuid.uuid4().hex
consumer_secret = uuid.uuid4().hex consumer_secret = uuid.uuid4().hex
p = self.create(access_key=access_key, p = self.create(
access_key=access_key,
access_secret=access_secret, access_secret=access_secret,
consumer_key=consumer_key, consumer_key=consumer_key,
consumer_secret=consumer_secret) consumer_secret=consumer_secret,
)
oauth_method = p.auth_methods[0] oauth_method = p.auth_methods[0]
@ -49,9 +50,13 @@ class OAuth1LoadingTests(test_utils.TestCase):
def test_options(self): def test_options(self):
options = loading.get_plugin_loader('v3oauth1').get_options() options = loading.get_plugin_loader('v3oauth1').get_options()
self.assertEqual(set([o.name for o in options]), self.assertEqual(
set(['auth-url', {o.name for o in options},
{
'auth-url',
'access-key', 'access-key',
'access-secret', 'access-secret',
'consumer-key', 'consumer-key',
'consumer-secret'])) 'consumer-secret',
},
)

View File

@ -23,22 +23,25 @@ def template(f, **kwargs):
def soap_response(**kwargs): def soap_response(**kwargs):
kwargs.setdefault('provider', 'https://idp.testshib.org/idp/shibboleth') kwargs.setdefault('provider', 'https://idp.testshib.org/idp/shibboleth')
kwargs.setdefault('consumer', kwargs.setdefault(
'https://openstack4.local/Shibboleth.sso/SAML2/ECP') 'consumer', 'https://openstack4.local/Shibboleth.sso/SAML2/ECP'
)
kwargs.setdefault('issuer', 'https://openstack4.local/shibboleth') kwargs.setdefault('issuer', 'https://openstack4.local/shibboleth')
return template('soap_response.xml', **kwargs).encode('utf-8') return template('soap_response.xml', **kwargs).encode('utf-8')
def saml_assertion(**kwargs): def saml_assertion(**kwargs):
kwargs.setdefault('issuer', 'https://idp.testshib.org/idp/shibboleth') kwargs.setdefault('issuer', 'https://idp.testshib.org/idp/shibboleth')
kwargs.setdefault('destination', kwargs.setdefault(
'https://openstack4.local/Shibboleth.sso/SAML2/ECP') 'destination', 'https://openstack4.local/Shibboleth.sso/SAML2/ECP'
)
return template('saml_assertion.xml', **kwargs).encode('utf-8') return template('saml_assertion.xml', **kwargs).encode('utf-8')
def authn_request(**kwargs): def authn_request(**kwargs):
kwargs.setdefault('issuer', kwargs.setdefault(
'https://openstack4.local/Shibboleth.sso/SAML2/ECP') 'issuer', 'https://openstack4.local/Shibboleth.sso/SAML2/ECP'
)
return template('authn_request.xml', **kwargs).encode('utf-8') return template('authn_request.xml', **kwargs).encode('utf-8')
@ -56,19 +59,13 @@ UNSCOPED_TOKEN = {
"expires_at": "2014-06-09T10:48:59.643375Z", "expires_at": "2014-06-09T10:48:59.643375Z",
"user": { "user": {
"OS-FEDERATION": { "OS-FEDERATION": {
"identity_provider": { "identity_provider": {"id": "testshib"},
"id": "testshib" "protocol": {"id": "saml2"},
}, "groups": [{"id": "1764fa5cf69a49a4918131de5ce4af9a"}],
"protocol": {
"id": "saml2"
},
"groups": [
{"id": "1764fa5cf69a49a4918131de5ce4af9a"}
]
}, },
"id": "testhib%20user", "id": "testhib%20user",
"name": "testhib user" "name": "testhib user",
} },
} }
} }
@ -78,26 +75,22 @@ PROJECTS = {
"domain_id": "37ef61", "domain_id": "37ef61",
"enabled": 'true', "enabled": 'true',
"id": "12d706", "id": "12d706",
"links": { "links": {"self": "http://identity:35357/v3/projects/12d706"},
"self": "http://identity:35357/v3/projects/12d706" "name": "a project name",
},
"name": "a project name"
}, },
{ {
"domain_id": "37ef61", "domain_id": "37ef61",
"enabled": 'true', "enabled": 'true',
"id": "9ca0eb", "id": "9ca0eb",
"links": { "links": {"self": "http://identity:35357/v3/projects/9ca0eb"},
"self": "http://identity:35357/v3/projects/9ca0eb" "name": "another project",
}, },
"name": "another project"
}
], ],
"links": { "links": {
"self": "http://identity:35357/v3/OS-FEDERATION/projects", "self": "http://identity:35357/v3/OS-FEDERATION/projects",
"previous": 'null', "previous": 'null',
"next": 'null' "next": 'null',
} },
} }
DOMAINS = { DOMAINS = {
@ -106,15 +99,13 @@ DOMAINS = {
"description": "desc of domain", "description": "desc of domain",
"enabled": 'true', "enabled": 'true',
"id": "37ef61", "id": "37ef61",
"links": { "links": {"self": "http://identity:35357/v3/domains/37ef61"},
"self": "http://identity:35357/v3/domains/37ef61" "name": "my domain",
},
"name": "my domain"
} }
], ],
"links": { "links": {
"self": "http://identity:35357/v3/OS-FEDERATION/domains", "self": "http://identity:35357/v3/OS-FEDERATION/domains",
"previous": 'null', "previous": 'null',
"next": 'null' "next": 'null',
} },
} }

View File

@ -24,7 +24,6 @@ from keystoneauth1.tests.unit import matchers
class AuthenticateviaADFSTests(utils.TestCase): class AuthenticateviaADFSTests(utils.TestCase):
GROUP = 'auth' GROUP = 'auth'
NAMESPACES = { NAMESPACES = {
@ -33,24 +32,23 @@ class AuthenticateviaADFSTests(utils.TestCase):
'wsa': 'http://www.w3.org/2005/08/addressing', 'wsa': 'http://www.w3.org/2005/08/addressing',
'wsp': 'http://schemas.xmlsoap.org/ws/2004/09/policy', 'wsp': 'http://schemas.xmlsoap.org/ws/2004/09/policy',
'a': 'http://www.w3.org/2005/08/addressing', 'a': 'http://www.w3.org/2005/08/addressing',
'o': ('http://docs.oasis-open.org/wss/2004/01/oasis' 'o': (
'-200401-wss-wssecurity-secext-1.0.xsd') 'http://docs.oasis-open.org/wss/2004/01/oasis'
'-200401-wss-wssecurity-secext-1.0.xsd'
),
} }
USER_XPATH = ('/s:Envelope/s:Header' USER_XPATH = '/s:Envelope/s:Header/o:Security/o:UsernameToken/o:Username'
'/o:Security' PASSWORD_XPATH = (
'/o:UsernameToken' '/s:Envelope/s:Header/o:Security/o:UsernameToken/o:Password'
'/o:Username') )
PASSWORD_XPATH = ('/s:Envelope/s:Header' ADDRESS_XPATH = (
'/o:Security' '/s:Envelope/s:Body'
'/o:UsernameToken'
'/o:Password')
ADDRESS_XPATH = ('/s:Envelope/s:Body'
'/trust:RequestSecurityToken' '/trust:RequestSecurityToken'
'/wsp:AppliesTo/wsa:EndpointReference' '/wsp:AppliesTo/wsa:EndpointReference'
'/wsa:Address') '/wsa:Address'
TO_XPATH = ('/s:Envelope/s:Header' )
'/a:To') TO_XPATH = '/s:Envelope/s:Header/a:To'
TEST_TOKEN = uuid.uuid4().hex TEST_TOKEN = uuid.uuid4().hex
@ -61,24 +59,32 @@ class AuthenticateviaADFSTests(utils.TestCase):
return '4b911420-4982-4009-8afc-5c596cd487f5' return '4b911420-4982-4009-8afc-5c596cd487f5'
def setUp(self): def setUp(self):
super(AuthenticateviaADFSTests, self).setUp() super().setUp()
self.IDENTITY_PROVIDER = 'adfs' self.IDENTITY_PROVIDER = 'adfs'
self.IDENTITY_PROVIDER_URL = ('http://adfs.local/adfs/service/trust/13' self.IDENTITY_PROVIDER_URL = (
'/usernamemixed') 'http://adfs.local/adfs/service/trust/13/usernamemixed'
self.FEDERATION_AUTH_URL = '%s/%s' % ( )
self.FEDERATION_AUTH_URL = '{}/{}'.format(
self.TEST_URL, self.TEST_URL,
'OS-FEDERATION/identity_providers/adfs/protocols/saml2/auth') 'OS-FEDERATION/identity_providers/adfs/protocols/saml2/auth',
)
self.SP_ENDPOINT = 'https://openstack4.local/Shibboleth.sso/ADFS' self.SP_ENDPOINT = 'https://openstack4.local/Shibboleth.sso/ADFS'
self.SP_ENTITYID = 'https://openstack4.local' self.SP_ENTITYID = 'https://openstack4.local'
self.adfsplugin = saml2.V3ADFSPassword( self.adfsplugin = saml2.V3ADFSPassword(
self.TEST_URL, self.IDENTITY_PROVIDER, self.TEST_URL,
self.IDENTITY_PROVIDER_URL, self.SP_ENDPOINT, self.IDENTITY_PROVIDER,
self.TEST_USER, self.TEST_TOKEN, self.PROTOCOL) self.IDENTITY_PROVIDER_URL,
self.SP_ENDPOINT,
self.TEST_USER,
self.TEST_TOKEN,
self.PROTOCOL,
)
self.ADFS_SECURITY_TOKEN_RESPONSE = utils._load_xml( self.ADFS_SECURITY_TOKEN_RESPONSE = utils._load_xml(
'ADFS_RequestSecurityTokenResponse.xml') 'ADFS_RequestSecurityTokenResponse.xml'
)
self.ADFS_FAULT = utils._load_xml('ADFS_fault.xml') self.ADFS_FAULT = utils._load_xml('ADFS_fault.xml')
def test_get_adfs_security_token(self): def test_get_adfs_security_token(self):
@ -86,7 +92,8 @@ class AuthenticateviaADFSTests(utils.TestCase):
self.requests_mock.post( self.requests_mock.post(
self.IDENTITY_PROVIDER_URL, self.IDENTITY_PROVIDER_URL,
content=utils.make_oneline(self.ADFS_SECURITY_TOKEN_RESPONSE), content=utils.make_oneline(self.ADFS_SECURITY_TOKEN_RESPONSE),
status_code=200) status_code=200,
)
self.adfsplugin._prepare_adfs_request() self.adfsplugin._prepare_adfs_request()
self.adfsplugin._get_adfs_security_token(self.session) self.adfsplugin._get_adfs_security_token(self.session)
@ -94,59 +101,72 @@ class AuthenticateviaADFSTests(utils.TestCase):
adfs_response = etree.tostring(self.adfsplugin.adfs_token) adfs_response = etree.tostring(self.adfsplugin.adfs_token)
fixture_response = self.ADFS_SECURITY_TOKEN_RESPONSE fixture_response = self.ADFS_SECURITY_TOKEN_RESPONSE
self.assertThat(fixture_response, self.assertThat(fixture_response, matchers.XMLEquals(adfs_response))
matchers.XMLEquals(adfs_response))
def test_adfs_request_user(self): def test_adfs_request_user(self):
self.adfsplugin._prepare_adfs_request() self.adfsplugin._prepare_adfs_request()
user = self.adfsplugin.prepared_request.xpath( user = self.adfsplugin.prepared_request.xpath(
self.USER_XPATH, namespaces=self.NAMESPACES)[0] self.USER_XPATH, namespaces=self.NAMESPACES
)[0]
self.assertEqual(self.TEST_USER, user.text) self.assertEqual(self.TEST_USER, user.text)
def test_adfs_request_password(self): def test_adfs_request_password(self):
self.adfsplugin._prepare_adfs_request() self.adfsplugin._prepare_adfs_request()
password = self.adfsplugin.prepared_request.xpath( password = self.adfsplugin.prepared_request.xpath(
self.PASSWORD_XPATH, namespaces=self.NAMESPACES)[0] self.PASSWORD_XPATH, namespaces=self.NAMESPACES
)[0]
self.assertEqual(self.TEST_TOKEN, password.text) self.assertEqual(self.TEST_TOKEN, password.text)
def test_adfs_request_to(self): def test_adfs_request_to(self):
self.adfsplugin._prepare_adfs_request() self.adfsplugin._prepare_adfs_request()
to = self.adfsplugin.prepared_request.xpath( to = self.adfsplugin.prepared_request.xpath(
self.TO_XPATH, namespaces=self.NAMESPACES)[0] self.TO_XPATH, namespaces=self.NAMESPACES
)[0]
self.assertEqual(self.IDENTITY_PROVIDER_URL, to.text) self.assertEqual(self.IDENTITY_PROVIDER_URL, to.text)
def test_prepare_adfs_request_address(self): def test_prepare_adfs_request_address(self):
self.adfsplugin._prepare_adfs_request() self.adfsplugin._prepare_adfs_request()
address = self.adfsplugin.prepared_request.xpath( address = self.adfsplugin.prepared_request.xpath(
self.ADDRESS_XPATH, namespaces=self.NAMESPACES)[0] self.ADDRESS_XPATH, namespaces=self.NAMESPACES
)[0]
self.assertEqual(self.SP_ENDPOINT, address.text) self.assertEqual(self.SP_ENDPOINT, address.text)
def test_prepare_adfs_request_custom_endpointreference(self): def test_prepare_adfs_request_custom_endpointreference(self):
self.adfsplugin = saml2.V3ADFSPassword( self.adfsplugin = saml2.V3ADFSPassword(
self.TEST_URL, self.IDENTITY_PROVIDER, self.TEST_URL,
self.IDENTITY_PROVIDER_URL, self.SP_ENDPOINT, self.IDENTITY_PROVIDER,
self.TEST_USER, self.TEST_TOKEN, self.PROTOCOL, self.SP_ENTITYID) self.IDENTITY_PROVIDER_URL,
self.SP_ENDPOINT,
self.TEST_USER,
self.TEST_TOKEN,
self.PROTOCOL,
self.SP_ENTITYID,
)
self.adfsplugin._prepare_adfs_request() self.adfsplugin._prepare_adfs_request()
address = self.adfsplugin.prepared_request.xpath( address = self.adfsplugin.prepared_request.xpath(
self.ADDRESS_XPATH, namespaces=self.NAMESPACES)[0] self.ADDRESS_XPATH, namespaces=self.NAMESPACES
)[0]
self.assertEqual(self.SP_ENTITYID, address.text) self.assertEqual(self.SP_ENTITYID, address.text)
def test_prepare_sp_request(self): def test_prepare_sp_request(self):
assertion = etree.XML(self.ADFS_SECURITY_TOKEN_RESPONSE) assertion = etree.XML(self.ADFS_SECURITY_TOKEN_RESPONSE)
assertion = assertion.xpath( assertion = assertion.xpath(
saml2.V3ADFSPassword.ADFS_ASSERTION_XPATH, saml2.V3ADFSPassword.ADFS_ASSERTION_XPATH,
namespaces=saml2.V3ADFSPassword.ADFS_TOKEN_NAMESPACES) namespaces=saml2.V3ADFSPassword.ADFS_TOKEN_NAMESPACES,
)
assertion = assertion[0] assertion = assertion[0]
assertion = etree.tostring(assertion) assertion = etree.tostring(assertion)
assertion = assertion.replace( assertion = assertion.replace(
b'http://docs.oasis-open.org/ws-sx/ws-trust/200512', b'http://docs.oasis-open.org/ws-sx/ws-trust/200512',
b'http://schemas.xmlsoap.org/ws/2005/02/trust') b'http://schemas.xmlsoap.org/ws/2005/02/trust',
)
assertion = urllib.parse.quote(assertion) assertion = urllib.parse.quote(assertion)
assertion = 'wa=wsignin1.0&wresult=' + assertion assertion = 'wa=wsignin1.0&wresult=' + assertion
self.adfsplugin.adfs_token = etree.XML( self.adfsplugin.adfs_token = etree.XML(
self.ADFS_SECURITY_TOKEN_RESPONSE) self.ADFS_SECURITY_TOKEN_RESPONSE
)
self.adfsplugin._prepare_sp_request() self.adfsplugin._prepare_sp_request()
self.assertEqual(assertion, self.adfsplugin.encoded_assertion) self.assertEqual(assertion, self.adfsplugin.encoded_assertion)
@ -158,15 +178,19 @@ class AuthenticateviaADFSTests(utils.TestCase):
error message from the XML message indicating where was the problem. error message from the XML message indicating where was the problem.
""" """
content = utils.make_oneline(self.ADFS_FAULT) content = utils.make_oneline(self.ADFS_FAULT)
self.requests_mock.register_uri('POST', self.requests_mock.register_uri(
'POST',
self.IDENTITY_PROVIDER_URL, self.IDENTITY_PROVIDER_URL,
content=content, content=content,
status_code=500) status_code=500,
)
self.adfsplugin._prepare_adfs_request() self.adfsplugin._prepare_adfs_request()
self.assertRaises(exceptions.AuthorizationFailure, self.assertRaises(
exceptions.AuthorizationFailure,
self.adfsplugin._get_adfs_security_token, self.adfsplugin._get_adfs_security_token,
self.session) self.session,
)
# TODO(marek-denis): Python3 tests complain about missing 'message' # TODO(marek-denis): Python3 tests complain about missing 'message'
# attributes # attributes
# self.assertEqual('a:FailedAuthentication', e.message) # self.assertEqual('a:FailedAuthentication', e.message)
@ -178,14 +202,18 @@ class AuthenticateviaADFSTests(utils.TestCase):
and correctly raise exceptions.InternalServerError once it cannot and correctly raise exceptions.InternalServerError once it cannot
parse XML fault message parse XML fault message
""" """
self.requests_mock.register_uri('POST', self.requests_mock.register_uri(
'POST',
self.IDENTITY_PROVIDER_URL, self.IDENTITY_PROVIDER_URL,
content=b'NOT XML', content=b'NOT XML',
status_code=500) status_code=500,
)
self.adfsplugin._prepare_adfs_request() self.adfsplugin._prepare_adfs_request()
self.assertRaises(exceptions.InternalServerError, self.assertRaises(
exceptions.InternalServerError,
self.adfsplugin._get_adfs_security_token, self.adfsplugin._get_adfs_security_token,
self.session) self.session,
)
# TODO(marek-denis): Need to figure out how to properly send cookies # TODO(marek-denis): Need to figure out how to properly send cookies
# from the request_mock methods. # from the request_mock methods.
@ -193,9 +221,9 @@ class AuthenticateviaADFSTests(utils.TestCase):
"""Test whether SP issues a cookie.""" """Test whether SP issues a cookie."""
cookie = uuid.uuid4().hex cookie = uuid.uuid4().hex
self.requests_mock.post(self.SP_ENDPOINT, self.requests_mock.post(
headers={"set-cookie": cookie}, self.SP_ENDPOINT, headers={"set-cookie": cookie}, status_code=302
status_code=302) )
self.adfsplugin.adfs_token = self._build_adfs_request() self.adfsplugin.adfs_token = self._build_adfs_request()
self.adfsplugin._prepare_sp_request() self.adfsplugin._prepare_sp_request()
@ -204,55 +232,70 @@ class AuthenticateviaADFSTests(utils.TestCase):
self.assertEqual(1, len(self.session.session.cookies)) self.assertEqual(1, len(self.session.session.cookies))
def test_send_assertion_to_service_provider_bad_status(self): def test_send_assertion_to_service_provider_bad_status(self):
self.requests_mock.register_uri('POST', self.SP_ENDPOINT, self.requests_mock.register_uri(
status_code=500) 'POST', self.SP_ENDPOINT, status_code=500
)
self.adfsplugin.adfs_token = etree.XML( self.adfsplugin.adfs_token = etree.XML(
self.ADFS_SECURITY_TOKEN_RESPONSE) self.ADFS_SECURITY_TOKEN_RESPONSE
)
self.adfsplugin._prepare_sp_request() self.adfsplugin._prepare_sp_request()
self.assertRaises( self.assertRaises(
exceptions.InternalServerError, exceptions.InternalServerError,
self.adfsplugin._send_assertion_to_service_provider, self.adfsplugin._send_assertion_to_service_provider,
self.session) self.session,
)
def test_access_sp_no_cookies_fail(self): def test_access_sp_no_cookies_fail(self):
# clean cookie jar # clean cookie jar
self.session.session.cookies = [] self.session.session.cookies = []
self.assertRaises(exceptions.AuthorizationFailure, self.assertRaises(
exceptions.AuthorizationFailure,
self.adfsplugin._access_service_provider, self.adfsplugin._access_service_provider,
self.session) self.session,
)
def test_check_valid_token_when_authenticated(self): def test_check_valid_token_when_authenticated(self):
self.requests_mock.register_uri( self.requests_mock.register_uri(
'GET', self.FEDERATION_AUTH_URL, 'GET',
self.FEDERATION_AUTH_URL,
json=saml2_fixtures.UNSCOPED_TOKEN, json=saml2_fixtures.UNSCOPED_TOKEN,
headers=client_fixtures.AUTH_RESPONSE_HEADERS) headers=client_fixtures.AUTH_RESPONSE_HEADERS,
)
self.session.session.cookies = [object()] self.session.session.cookies = [object()]
self.adfsplugin._access_service_provider(self.session) self.adfsplugin._access_service_provider(self.session)
response = self.adfsplugin.authenticated_response response = self.adfsplugin.authenticated_response
self.assertEqual(client_fixtures.AUTH_RESPONSE_HEADERS, self.assertEqual(
response.headers) client_fixtures.AUTH_RESPONSE_HEADERS, response.headers
)
self.assertEqual(saml2_fixtures.UNSCOPED_TOKEN['token'], self.assertEqual(
response.json()['token']) saml2_fixtures.UNSCOPED_TOKEN['token'], response.json()['token']
)
def test_end_to_end_workflow(self): def test_end_to_end_workflow(self):
self.requests_mock.register_uri( self.requests_mock.register_uri(
'POST', self.IDENTITY_PROVIDER_URL, 'POST',
self.IDENTITY_PROVIDER_URL,
content=self.ADFS_SECURITY_TOKEN_RESPONSE, content=self.ADFS_SECURITY_TOKEN_RESPONSE,
status_code=200) status_code=200,
)
self.requests_mock.register_uri( self.requests_mock.register_uri(
'POST', self.SP_ENDPOINT, 'POST',
self.SP_ENDPOINT,
headers={"set-cookie": 'x'}, headers={"set-cookie": 'x'},
status_code=302) status_code=302,
)
self.requests_mock.register_uri( self.requests_mock.register_uri(
'GET', self.FEDERATION_AUTH_URL, 'GET',
self.FEDERATION_AUTH_URL,
json=saml2_fixtures.UNSCOPED_TOKEN, json=saml2_fixtures.UNSCOPED_TOKEN,
headers=client_fixtures.AUTH_RESPONSE_HEADERS) headers=client_fixtures.AUTH_RESPONSE_HEADERS,
)
# NOTE(marek-denis): We need to mimic this until self.requests_mock can # NOTE(marek-denis): We need to mimic this until self.requests_mock can
# issue cookies properly. # issue cookies properly.

View File

@ -53,8 +53,8 @@ class SamlAuth2PluginTests(utils.TestCase):
return [r.url.strip('/') for r in self.requests_mock.request_history] return [r.url.strip('/') for r in self.requests_mock.request_history]
def basic_header(self, username=TEST_USER, password=TEST_PASS): def basic_header(self, username=TEST_USER, password=TEST_PASS):
user_pass = ('%s:%s' % (username, password)).encode('utf-8') user_pass = (f'{username}:{password}').encode()
return 'Basic %s' % base64.b64encode(user_pass).decode('utf-8') return 'Basic {}'.format(base64.b64encode(user_pass).decode('utf-8'))
def test_request_accept_headers(self): def test_request_accept_headers(self):
# Include some random Accept header # Include some random Accept header
@ -70,18 +70,23 @@ class SamlAuth2PluginTests(utils.TestCase):
# added the PAOS_HEADER to it using the correct media type separator # added the PAOS_HEADER to it using the correct media type separator
accept_header = plugin_headers['Accept'] accept_header = plugin_headers['Accept']
self.assertIn(self.HEADER_MEDIA_TYPE_SEPARATOR, accept_header) self.assertIn(self.HEADER_MEDIA_TYPE_SEPARATOR, accept_header)
self.assertIn(random_header, self.assertIn(
accept_header.split(self.HEADER_MEDIA_TYPE_SEPARATOR)) random_header,
self.assertIn(PAOS_HEADER, accept_header.split(self.HEADER_MEDIA_TYPE_SEPARATOR),
accept_header.split(self.HEADER_MEDIA_TYPE_SEPARATOR)) )
self.assertIn(
PAOS_HEADER, accept_header.split(self.HEADER_MEDIA_TYPE_SEPARATOR)
)
def test_passed_when_not_200(self): def test_passed_when_not_200(self):
text = uuid.uuid4().hex text = uuid.uuid4().hex
test_url = 'http://another.test' test_url = 'http://another.test'
self.requests_mock.get(test_url, self.requests_mock.get(
test_url,
status_code=201, status_code=201,
headers=CONTENT_TYPE_PAOS_HEADER, headers=CONTENT_TYPE_PAOS_HEADER,
text=text) text=text,
)
resp = requests.get(test_url, auth=self.get_plugin()) resp = requests.get(test_url, auth=self.get_plugin())
self.assertEqual(201, resp.status_code) self.assertEqual(201, resp.status_code)
@ -99,82 +104,115 @@ class SamlAuth2PluginTests(utils.TestCase):
def test_standard_workflow_302_redirect(self): def test_standard_workflow_302_redirect(self):
text = uuid.uuid4().hex text = uuid.uuid4().hex
self.requests_mock.get(self.TEST_SP_URL, response_list=[ self.requests_mock.get(
dict(headers=CONTENT_TYPE_PAOS_HEADER, self.TEST_SP_URL,
content=utils.make_oneline(saml2_fixtures.SP_SOAP_RESPONSE)), response_list=[
dict(text=text) {
]) 'headers': CONTENT_TYPE_PAOS_HEADER,
'content': utils.make_oneline(
saml2_fixtures.SP_SOAP_RESPONSE
),
},
{'text': text},
],
)
authm = self.requests_mock.post(self.TEST_IDP_URL, authm = self.requests_mock.post(
content=saml2_fixtures.SAML2_ASSERTION) self.TEST_IDP_URL, content=saml2_fixtures.SAML2_ASSERTION
)
self.requests_mock.post( self.requests_mock.post(
self.TEST_CONSUMER_URL, self.TEST_CONSUMER_URL,
status_code=302, status_code=302,
headers={'Location': self.TEST_SP_URL}) headers={'Location': self.TEST_SP_URL},
)
resp = requests.get(self.TEST_SP_URL, auth=self.get_plugin()) resp = requests.get(self.TEST_SP_URL, auth=self.get_plugin())
self.assertEqual(200, resp.status_code) self.assertEqual(200, resp.status_code)
self.assertEqual(text, resp.text) self.assertEqual(text, resp.text)
self.assertEqual(self.calls, [self.TEST_SP_URL, self.assertEqual(
self.calls,
[
self.TEST_SP_URL,
self.TEST_IDP_URL, self.TEST_IDP_URL,
self.TEST_CONSUMER_URL, self.TEST_CONSUMER_URL,
self.TEST_SP_URL]) self.TEST_SP_URL,
],
)
self.assertEqual(self.basic_header(), self.assertEqual(
authm.last_request.headers['Authorization']) self.basic_header(), authm.last_request.headers['Authorization']
)
authn_request = self.requests_mock.request_history[1].text authn_request = self.requests_mock.request_history[1].text
self.assertThat(saml2_fixtures.AUTHN_REQUEST, self.assertThat(
matchers.XMLEquals(authn_request)) saml2_fixtures.AUTHN_REQUEST, matchers.XMLEquals(authn_request)
)
def test_standard_workflow_303_redirect(self): def test_standard_workflow_303_redirect(self):
text = uuid.uuid4().hex text = uuid.uuid4().hex
self.requests_mock.get(self.TEST_SP_URL, response_list=[ self.requests_mock.get(
dict(headers=CONTENT_TYPE_PAOS_HEADER, self.TEST_SP_URL,
content=utils.make_oneline(saml2_fixtures.SP_SOAP_RESPONSE)), response_list=[
dict(text=text) {
]) 'headers': CONTENT_TYPE_PAOS_HEADER,
'content': utils.make_oneline(
saml2_fixtures.SP_SOAP_RESPONSE
),
},
{'text': text},
],
)
authm = self.requests_mock.post(self.TEST_IDP_URL, authm = self.requests_mock.post(
content=saml2_fixtures.SAML2_ASSERTION) self.TEST_IDP_URL, content=saml2_fixtures.SAML2_ASSERTION
)
self.requests_mock.post( self.requests_mock.post(
self.TEST_CONSUMER_URL, self.TEST_CONSUMER_URL,
status_code=303, status_code=303,
headers={'Location': self.TEST_SP_URL}) headers={'Location': self.TEST_SP_URL},
)
resp = requests.get(self.TEST_SP_URL, auth=self.get_plugin()) resp = requests.get(self.TEST_SP_URL, auth=self.get_plugin())
self.assertEqual(200, resp.status_code) self.assertEqual(200, resp.status_code)
self.assertEqual(text, resp.text) self.assertEqual(text, resp.text)
url_flow = [self.TEST_SP_URL, url_flow = [
self.TEST_SP_URL,
self.TEST_IDP_URL, self.TEST_IDP_URL,
self.TEST_CONSUMER_URL, self.TEST_CONSUMER_URL,
self.TEST_SP_URL] self.TEST_SP_URL,
]
self.assertEqual(url_flow, [r.url.rstrip('/') for r in resp.history]) self.assertEqual(url_flow, [r.url.rstrip('/') for r in resp.history])
self.assertEqual(url_flow, self.calls) self.assertEqual(url_flow, self.calls)
self.assertEqual(self.basic_header(), self.assertEqual(
authm.last_request.headers['Authorization']) self.basic_header(), authm.last_request.headers['Authorization']
)
authn_request = self.requests_mock.request_history[1].text authn_request = self.requests_mock.request_history[1].text
self.assertThat(saml2_fixtures.AUTHN_REQUEST, self.assertThat(
matchers.XMLEquals(authn_request)) saml2_fixtures.AUTHN_REQUEST, matchers.XMLEquals(authn_request)
)
def test_initial_sp_call_invalid_response(self): def test_initial_sp_call_invalid_response(self):
"""Send initial SP HTTP request and receive wrong server response.""" """Send initial SP HTTP request and receive wrong server response."""
self.requests_mock.get(self.TEST_SP_URL, self.requests_mock.get(
self.TEST_SP_URL,
headers=CONTENT_TYPE_PAOS_HEADER, headers=CONTENT_TYPE_PAOS_HEADER,
text='NON XML RESPONSE') text='NON XML RESPONSE',
)
self.assertRaises(InvalidResponse, self.assertRaises(
InvalidResponse,
requests.get, requests.get,
self.TEST_SP_URL, self.TEST_SP_URL,
auth=self.get_plugin()) auth=self.get_plugin(),
)
self.assertEqual(self.calls, [self.TEST_SP_URL]) self.assertEqual(self.calls, [self.TEST_SP_URL])
@ -184,25 +222,28 @@ class SamlAuth2PluginTests(utils.TestCase):
soap_response = saml2_fixtures.soap_response(consumer=consumer1) soap_response = saml2_fixtures.soap_response(consumer=consumer1)
saml_assertion = saml2_fixtures.saml_assertion(destination=consumer2) saml_assertion = saml2_fixtures.saml_assertion(destination=consumer2)
self.requests_mock.get(self.TEST_SP_URL, self.requests_mock.get(
self.TEST_SP_URL,
headers=CONTENT_TYPE_PAOS_HEADER, headers=CONTENT_TYPE_PAOS_HEADER,
content=soap_response) content=soap_response,
)
self.requests_mock.post(self.TEST_IDP_URL, content=saml_assertion) self.requests_mock.post(self.TEST_IDP_URL, content=saml_assertion)
# receive the SAML error, body unchecked # receive the SAML error, body unchecked
saml_error = self.requests_mock.post(consumer1) saml_error = self.requests_mock.post(consumer1)
self.assertRaises(saml2.v3.saml2.ConsumerMismatch, self.assertRaises(
saml2.v3.saml2.ConsumerMismatch,
requests.get, requests.get,
self.TEST_SP_URL, self.TEST_SP_URL,
auth=self.get_plugin()) auth=self.get_plugin(),
)
self.assertTrue(saml_error.called) self.assertTrue(saml_error.called)
class AuthenticateviaSAML2Tests(utils.TestCase): class AuthenticateviaSAML2Tests(utils.TestCase):
TEST_USER = 'user' TEST_USER = 'user'
TEST_PASS = 'pass' TEST_PASS = 'pass'
TEST_IDP = 'tester' TEST_IDP = 'tester'
@ -226,8 +267,10 @@ class AuthenticateviaSAML2Tests(utils.TestCase):
kwargs.setdefault('identity_provider', self.TEST_IDP) kwargs.setdefault('identity_provider', self.TEST_IDP)
kwargs.setdefault('protocol', self.TEST_PROTOCOL) kwargs.setdefault('protocol', self.TEST_PROTOCOL)
templ = ('%(base)s/OS-FEDERATION/identity_providers/' templ = (
'%(identity_provider)s/protocols/%(protocol)s/auth') '%(base)s/OS-FEDERATION/identity_providers/'
'%(identity_provider)s/protocols/%(protocol)s/auth'
)
return templ % kwargs return templ % kwargs
@property @property
@ -235,11 +278,11 @@ class AuthenticateviaSAML2Tests(utils.TestCase):
return [r.url.strip('/') for r in self.requests_mock.request_history] return [r.url.strip('/') for r in self.requests_mock.request_history]
def basic_header(self, username=TEST_USER, password=TEST_PASS): def basic_header(self, username=TEST_USER, password=TEST_PASS):
user_pass = ('%s:%s' % (username, password)).encode('utf-8') user_pass = (f'{username}:{password}').encode()
return 'Basic %s' % base64.b64encode(user_pass).decode('utf-8') return 'Basic {}'.format(base64.b64encode(user_pass).decode('utf-8'))
def setUp(self): def setUp(self):
super(AuthenticateviaSAML2Tests, self).setUp() super().setUp()
self.session = session.Session() self.session = session.Session()
self.default_sp_url = self.sp_url() self.default_sp_url = self.sp_url()
@ -247,35 +290,51 @@ class AuthenticateviaSAML2Tests(utils.TestCase):
token_id = uuid.uuid4().hex token_id = uuid.uuid4().hex
token = ksa_fixtures.V3Token() token = ksa_fixtures.V3Token()
self.requests_mock.get(self.default_sp_url, response_list=[ self.requests_mock.get(
dict(headers=CONTENT_TYPE_PAOS_HEADER, self.default_sp_url,
content=utils.make_oneline(saml2_fixtures.SP_SOAP_RESPONSE)), response_list=[
dict(headers={'X-Subject-Token': token_id}, json=token) {
]) 'headers': CONTENT_TYPE_PAOS_HEADER,
'content': utils.make_oneline(
saml2_fixtures.SP_SOAP_RESPONSE
),
},
{'headers': {'X-Subject-Token': token_id}, 'json': token},
],
)
authm = self.requests_mock.post(self.TEST_IDP_URL, authm = self.requests_mock.post(
content=saml2_fixtures.SAML2_ASSERTION) self.TEST_IDP_URL, content=saml2_fixtures.SAML2_ASSERTION
)
self.requests_mock.post( self.requests_mock.post(
self.TEST_CONSUMER_URL, self.TEST_CONSUMER_URL,
status_code=302, status_code=302,
headers={'Location': self.sp_url()}) headers={'Location': self.sp_url()},
)
auth_ref = self.get_plugin().get_auth_ref(self.session) auth_ref = self.get_plugin().get_auth_ref(self.session)
self.assertEqual(token_id, auth_ref.auth_token) self.assertEqual(token_id, auth_ref.auth_token)
self.assertEqual(self.calls, [self.default_sp_url, self.assertEqual(
self.calls,
[
self.default_sp_url,
self.TEST_IDP_URL, self.TEST_IDP_URL,
self.TEST_CONSUMER_URL, self.TEST_CONSUMER_URL,
self.default_sp_url]) self.default_sp_url,
],
)
self.assertEqual(self.basic_header(), self.assertEqual(
authm.last_request.headers['Authorization']) self.basic_header(), authm.last_request.headers['Authorization']
)
authn_request = self.requests_mock.request_history[1].text authn_request = self.requests_mock.request_history[1].text
self.assertThat(saml2_fixtures.AUTHN_REQUEST, self.assertThat(
matchers.XMLEquals(authn_request)) saml2_fixtures.AUTHN_REQUEST, matchers.XMLEquals(authn_request)
)
def test_consumer_mismatch_error_workflow(self): def test_consumer_mismatch_error_workflow(self):
consumer1 = 'http://keystone.test/Shibboleth.sso/SAML2/ECP' consumer1 = 'http://keystone.test/Shibboleth.sso/SAML2/ECP'
@ -284,29 +343,37 @@ class AuthenticateviaSAML2Tests(utils.TestCase):
soap_response = saml2_fixtures.soap_response(consumer=consumer1) soap_response = saml2_fixtures.soap_response(consumer=consumer1)
saml_assertion = saml2_fixtures.saml_assertion(destination=consumer2) saml_assertion = saml2_fixtures.saml_assertion(destination=consumer2)
self.requests_mock.get(self.default_sp_url, self.requests_mock.get(
self.default_sp_url,
headers=CONTENT_TYPE_PAOS_HEADER, headers=CONTENT_TYPE_PAOS_HEADER,
content=soap_response) content=soap_response,
)
self.requests_mock.post(self.TEST_IDP_URL, content=saml_assertion) self.requests_mock.post(self.TEST_IDP_URL, content=saml_assertion)
# receive the SAML error, body unchecked # receive the SAML error, body unchecked
saml_error = self.requests_mock.post(consumer1) saml_error = self.requests_mock.post(consumer1)
self.assertRaises(exceptions.AuthorizationFailure, self.assertRaises(
exceptions.AuthorizationFailure,
self.get_plugin().get_auth_ref, self.get_plugin().get_auth_ref,
self.session) self.session,
)
self.assertTrue(saml_error.called) self.assertTrue(saml_error.called)
def test_initial_sp_call_invalid_response(self): def test_initial_sp_call_invalid_response(self):
"""Send initial SP HTTP request and receive wrong server response.""" """Send initial SP HTTP request and receive wrong server response."""
self.requests_mock.get(self.default_sp_url, self.requests_mock.get(
self.default_sp_url,
headers=CONTENT_TYPE_PAOS_HEADER, headers=CONTENT_TYPE_PAOS_HEADER,
text='NON XML RESPONSE') text='NON XML RESPONSE',
)
self.assertRaises(exceptions.AuthorizationFailure, self.assertRaises(
exceptions.AuthorizationFailure,
self.get_plugin().get_auth_ref, self.get_plugin().get_auth_ref,
self.session) self.session,
)
self.assertEqual(self.calls, [self.default_sp_url]) self.assertEqual(self.calls, [self.default_sp_url])

View File

@ -31,9 +31,8 @@ def _load_xml(filename):
class TestCase(utils.TestCase): class TestCase(utils.TestCase):
TEST_URL = 'https://keystone:5000/v3' TEST_URL = 'https://keystone:5000/v3'
def setUp(self): def setUp(self):
super(TestCase, self).setUp() super().setUp()
self.session = session.Session() self.session = session.Session()

View File

@ -21,9 +21,8 @@ from keystoneauth1.tests.unit import utils
class AccessInfoPluginTests(utils.TestCase): class AccessInfoPluginTests(utils.TestCase):
def setUp(self): def setUp(self):
super(AccessInfoPluginTests, self).setUp() super().setUp()
self.session = session.Session() self.session = session.Session()
self.auth_token = uuid.uuid4().hex self.auth_token = uuid.uuid4().hex
@ -37,19 +36,22 @@ class AccessInfoPluginTests(utils.TestCase):
def test_auth_ref(self): def test_auth_ref(self):
plugin_obj = self._plugin() plugin_obj = self._plugin()
self.assertEqual(self.TEST_ROOT_URL, self.assertEqual(
plugin_obj.get_endpoint(self.session, self.TEST_ROOT_URL,
service_type='identity', plugin_obj.get_endpoint(
interface='public')) self.session, service_type='identity', interface='public'
),
)
self.assertEqual(self.auth_token, plugin_obj.get_token(session)) self.assertEqual(self.auth_token, plugin_obj.get_token(session))
def test_auth_url(self): def test_auth_url(self):
auth_url = 'http://keystone.test.url' auth_url = 'http://keystone.test.url'
obj = self._plugin(auth_url=auth_url) obj = self._plugin(auth_url=auth_url)
self.assertEqual(auth_url, self.assertEqual(
obj.get_endpoint(self.session, auth_url,
interface=plugin.AUTH_INTERFACE)) obj.get_endpoint(self.session, interface=plugin.AUTH_INTERFACE),
)
def test_invalidate(self): def test_invalidate(self):
plugin = self._plugin() plugin = self._plugin()

File diff suppressed because it is too large Load Diff

View File

@ -25,78 +25,89 @@ from keystoneauth1.tests.unit import utils
class V2IdentityPlugin(utils.TestCase): class V2IdentityPlugin(utils.TestCase):
TEST_ROOT_URL = 'http://127.0.0.1:5000/' TEST_ROOT_URL = 'http://127.0.0.1:5000/'
TEST_URL = '%s%s' % (TEST_ROOT_URL, 'v2.0') TEST_URL = '{}{}'.format(TEST_ROOT_URL, 'v2.0')
TEST_ROOT_ADMIN_URL = 'http://127.0.0.1:35357/' TEST_ROOT_ADMIN_URL = 'http://127.0.0.1:35357/'
TEST_ADMIN_URL = '%s%s' % (TEST_ROOT_ADMIN_URL, 'v2.0') TEST_ADMIN_URL = '{}{}'.format(TEST_ROOT_ADMIN_URL, 'v2.0')
TEST_PASS = 'password' TEST_PASS = 'password'
TEST_SERVICE_CATALOG = [{ TEST_SERVICE_CATALOG = [
"endpoints": [{ {
"endpoints": [
{
"adminURL": "http://cdn.admin-nets.local:8774/v1.0", "adminURL": "http://cdn.admin-nets.local:8774/v1.0",
"region": "RegionOne", "region": "RegionOne",
"internalURL": "http://127.0.0.1:8774/v1.0", "internalURL": "http://127.0.0.1:8774/v1.0",
"publicURL": "http://cdn.admin-nets.local:8774/v1.0/" "publicURL": "http://cdn.admin-nets.local:8774/v1.0/",
}], }
],
"type": "nova_compat", "type": "nova_compat",
"name": "nova_compat" "name": "nova_compat",
}, { },
"endpoints": [{ {
"endpoints": [
{
"adminURL": "http://nova/novapi/admin", "adminURL": "http://nova/novapi/admin",
"region": "RegionOne", "region": "RegionOne",
"internalURL": "http://nova/novapi/internal", "internalURL": "http://nova/novapi/internal",
"publicURL": "http://nova/novapi/public" "publicURL": "http://nova/novapi/public",
}], }
],
"type": "compute", "type": "compute",
"name": "nova" "name": "nova",
}, { },
"endpoints": [{ {
"endpoints": [
{
"adminURL": "http://glance/glanceapi/admin", "adminURL": "http://glance/glanceapi/admin",
"region": "RegionOne", "region": "RegionOne",
"internalURL": "http://glance/glanceapi/internal", "internalURL": "http://glance/glanceapi/internal",
"publicURL": "http://glance/glanceapi/public" "publicURL": "http://glance/glanceapi/public",
}], }
],
"type": "image", "type": "image",
"name": "glance" "name": "glance",
}, { },
"endpoints": [{ {
"endpoints": [
{
"adminURL": TEST_ADMIN_URL, "adminURL": TEST_ADMIN_URL,
"region": "RegionOne", "region": "RegionOne",
"internalURL": "http://127.0.0.1:5000/v2.0", "internalURL": "http://127.0.0.1:5000/v2.0",
"publicURL": "http://127.0.0.1:5000/v2.0" "publicURL": "http://127.0.0.1:5000/v2.0",
}], }
],
"type": "identity", "type": "identity",
"name": "keystone" "name": "keystone",
}, { },
"endpoints": [{ {
"endpoints": [
{
"adminURL": "http://swift/swiftapi/admin", "adminURL": "http://swift/swiftapi/admin",
"region": "RegionOne", "region": "RegionOne",
"internalURL": "http://swift/swiftapi/internal", "internalURL": "http://swift/swiftapi/internal",
"publicURL": "http://swift/swiftapi/public" "publicURL": "http://swift/swiftapi/public",
}], }
],
"type": "object-store", "type": "object-store",
"name": "swift" "name": "swift",
}] },
]
def setUp(self): def setUp(self):
super(V2IdentityPlugin, self).setUp() super().setUp()
self.TEST_RESPONSE_DICT = { self.TEST_RESPONSE_DICT = {
"access": { "access": {
"token": { "token": {
"expires": "%i-02-01T00:00:10.000123Z" % "expires": "%i-02-01T00:00:10.000123Z"
(1 + time.gmtime().tm_year), % (1 + time.gmtime().tm_year),
"id": self.TEST_TOKEN, "id": self.TEST_TOKEN,
"tenant": { "tenant": {"id": self.TEST_TENANT_ID},
"id": self.TEST_TENANT_ID
},
},
"user": {
"id": self.TEST_USER
}, },
"user": {"id": self.TEST_USER},
"serviceCatalog": self.TEST_SERVICE_CATALOG, "serviceCatalog": self.TEST_SERVICE_CATALOG,
}, }
} }
def stub_auth(self, **kwargs): def stub_auth(self, **kwargs):
@ -104,16 +115,24 @@ class V2IdentityPlugin(utils.TestCase):
def test_authenticate_with_username_password(self): def test_authenticate_with_username_password(self):
self.stub_auth(json=self.TEST_RESPONSE_DICT) self.stub_auth(json=self.TEST_RESPONSE_DICT)
a = v2.Password(self.TEST_URL, username=self.TEST_USER, a = v2.Password(
password=self.TEST_PASS) self.TEST_URL, username=self.TEST_USER, password=self.TEST_PASS
)
self.assertIsNone(a.user_id) self.assertIsNone(a.user_id)
self.assertFalse(a.has_scope_parameters) self.assertFalse(a.has_scope_parameters)
s = session.Session(a) s = session.Session(a)
self.assertEqual({'X-Auth-Token': self.TEST_TOKEN}, self.assertEqual(
s.get_auth_headers()) {'X-Auth-Token': self.TEST_TOKEN}, s.get_auth_headers()
)
req = {'auth': {'passwordCredentials': {'username': self.TEST_USER, req = {
'password': self.TEST_PASS}}} 'auth': {
'passwordCredentials': {
'username': self.TEST_USER,
'password': self.TEST_PASS,
}
}
}
self.assertRequestBodyIs(json=req) self.assertRequestBodyIs(json=req)
self.assertRequestHeaderEqual('Content-Type', 'application/json') self.assertRequestHeaderEqual('Content-Type', 'application/json')
self.assertRequestHeaderEqual('Accept', 'application/json') self.assertRequestHeaderEqual('Accept', 'application/json')
@ -121,16 +140,24 @@ class V2IdentityPlugin(utils.TestCase):
def test_authenticate_with_user_id_password(self): def test_authenticate_with_user_id_password(self):
self.stub_auth(json=self.TEST_RESPONSE_DICT) self.stub_auth(json=self.TEST_RESPONSE_DICT)
a = v2.Password(self.TEST_URL, user_id=self.TEST_USER, a = v2.Password(
password=self.TEST_PASS) self.TEST_URL, user_id=self.TEST_USER, password=self.TEST_PASS
)
self.assertIsNone(a.username) self.assertIsNone(a.username)
self.assertFalse(a.has_scope_parameters) self.assertFalse(a.has_scope_parameters)
s = session.Session(a) s = session.Session(a)
self.assertEqual({'X-Auth-Token': self.TEST_TOKEN}, self.assertEqual(
s.get_auth_headers()) {'X-Auth-Token': self.TEST_TOKEN}, s.get_auth_headers()
)
req = {'auth': {'passwordCredentials': {'userId': self.TEST_USER, req = {
'password': self.TEST_PASS}}} 'auth': {
'passwordCredentials': {
'userId': self.TEST_USER,
'password': self.TEST_PASS,
}
}
}
self.assertRequestBodyIs(json=req) self.assertRequestBodyIs(json=req)
self.assertRequestHeaderEqual('Content-Type', 'application/json') self.assertRequestHeaderEqual('Content-Type', 'application/json')
self.assertRequestHeaderEqual('Accept', 'application/json') self.assertRequestHeaderEqual('Accept', 'application/json')
@ -138,33 +165,55 @@ class V2IdentityPlugin(utils.TestCase):
def test_authenticate_with_username_password_scoped(self): def test_authenticate_with_username_password_scoped(self):
self.stub_auth(json=self.TEST_RESPONSE_DICT) self.stub_auth(json=self.TEST_RESPONSE_DICT)
a = v2.Password(self.TEST_URL, username=self.TEST_USER, a = v2.Password(
password=self.TEST_PASS, tenant_id=self.TEST_TENANT_ID) self.TEST_URL,
username=self.TEST_USER,
password=self.TEST_PASS,
tenant_id=self.TEST_TENANT_ID,
)
self.assertTrue(a.has_scope_parameters) self.assertTrue(a.has_scope_parameters)
self.assertIsNone(a.user_id) self.assertIsNone(a.user_id)
s = session.Session(a) s = session.Session(a)
self.assertEqual({'X-Auth-Token': self.TEST_TOKEN}, self.assertEqual(
s.get_auth_headers()) {'X-Auth-Token': self.TEST_TOKEN}, s.get_auth_headers()
)
req = {'auth': {'passwordCredentials': {'username': self.TEST_USER, req = {
'password': self.TEST_PASS}, 'auth': {
'tenantId': self.TEST_TENANT_ID}} 'passwordCredentials': {
'username': self.TEST_USER,
'password': self.TEST_PASS,
},
'tenantId': self.TEST_TENANT_ID,
}
}
self.assertRequestBodyIs(json=req) self.assertRequestBodyIs(json=req)
self.assertEqual(s.auth.auth_ref.auth_token, self.TEST_TOKEN) self.assertEqual(s.auth.auth_ref.auth_token, self.TEST_TOKEN)
def test_authenticate_with_user_id_password_scoped(self): def test_authenticate_with_user_id_password_scoped(self):
self.stub_auth(json=self.TEST_RESPONSE_DICT) self.stub_auth(json=self.TEST_RESPONSE_DICT)
a = v2.Password(self.TEST_URL, user_id=self.TEST_USER, a = v2.Password(
password=self.TEST_PASS, tenant_id=self.TEST_TENANT_ID) self.TEST_URL,
user_id=self.TEST_USER,
password=self.TEST_PASS,
tenant_id=self.TEST_TENANT_ID,
)
self.assertIsNone(a.username) self.assertIsNone(a.username)
self.assertTrue(a.has_scope_parameters) self.assertTrue(a.has_scope_parameters)
s = session.Session(a) s = session.Session(a)
self.assertEqual({'X-Auth-Token': self.TEST_TOKEN}, self.assertEqual(
s.get_auth_headers()) {'X-Auth-Token': self.TEST_TOKEN}, s.get_auth_headers()
)
req = {'auth': {'passwordCredentials': {'userId': self.TEST_USER, req = {
'password': self.TEST_PASS}, 'auth': {
'tenantId': self.TEST_TENANT_ID}} 'passwordCredentials': {
'userId': self.TEST_USER,
'password': self.TEST_PASS,
},
'tenantId': self.TEST_TENANT_ID,
}
}
self.assertRequestBodyIs(json=req) self.assertRequestBodyIs(json=req)
self.assertEqual(s.auth.auth_ref.auth_token, self.TEST_TOKEN) self.assertEqual(s.auth.auth_ref.auth_token, self.TEST_TOKEN)
@ -172,8 +221,9 @@ class V2IdentityPlugin(utils.TestCase):
self.stub_auth(json=self.TEST_RESPONSE_DICT) self.stub_auth(json=self.TEST_RESPONSE_DICT)
a = v2.Token(self.TEST_URL, 'foo') a = v2.Token(self.TEST_URL, 'foo')
s = session.Session(a) s = session.Session(a)
self.assertEqual({'X-Auth-Token': self.TEST_TOKEN}, self.assertEqual(
s.get_auth_headers()) {'X-Auth-Token': self.TEST_TOKEN}, s.get_auth_headers()
)
req = {'auth': {'token': {'id': 'foo'}}} req = {'auth': {'token': {'id': 'foo'}}}
self.assertRequestBodyIs(json=req) self.assertRequestBodyIs(json=req)
@ -184,40 +234,55 @@ class V2IdentityPlugin(utils.TestCase):
def test_with_trust_id(self): def test_with_trust_id(self):
self.stub_auth(json=self.TEST_RESPONSE_DICT) self.stub_auth(json=self.TEST_RESPONSE_DICT)
a = v2.Password(self.TEST_URL, username=self.TEST_USER, a = v2.Password(
password=self.TEST_PASS, trust_id='trust') self.TEST_URL,
username=self.TEST_USER,
password=self.TEST_PASS,
trust_id='trust',
)
self.assertTrue(a.has_scope_parameters) self.assertTrue(a.has_scope_parameters)
s = session.Session(a) s = session.Session(a)
self.assertEqual({'X-Auth-Token': self.TEST_TOKEN}, self.assertEqual(
s.get_auth_headers()) {'X-Auth-Token': self.TEST_TOKEN}, s.get_auth_headers()
)
req = {'auth': {'passwordCredentials': {'username': self.TEST_USER, req = {
'password': self.TEST_PASS}, 'auth': {
'trust_id': 'trust'}} 'passwordCredentials': {
'username': self.TEST_USER,
'password': self.TEST_PASS,
},
'trust_id': 'trust',
}
}
self.assertRequestBodyIs(json=req) self.assertRequestBodyIs(json=req)
self.assertEqual(s.auth.auth_ref.auth_token, self.TEST_TOKEN) self.assertEqual(s.auth.auth_ref.auth_token, self.TEST_TOKEN)
def _do_service_url_test(self, base_url, endpoint_filter): def _do_service_url_test(self, base_url, endpoint_filter):
self.stub_auth(json=self.TEST_RESPONSE_DICT) self.stub_auth(json=self.TEST_RESPONSE_DICT)
self.stub_url('GET', ['path'], self.stub_url(
base_url=base_url, 'GET', ['path'], base_url=base_url, text='SUCCESS', status_code=200
text='SUCCESS', status_code=200) )
a = v2.Password(self.TEST_URL, username=self.TEST_USER, a = v2.Password(
password=self.TEST_PASS) self.TEST_URL, username=self.TEST_USER, password=self.TEST_PASS
)
s = session.Session(auth=a) s = session.Session(auth=a)
resp = s.get('/path', endpoint_filter=endpoint_filter) resp = s.get('/path', endpoint_filter=endpoint_filter)
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
self.assertEqual(self.requests_mock.last_request.url, self.assertEqual(
base_url + '/path') self.requests_mock.last_request.url, base_url + '/path'
)
def test_service_url(self): def test_service_url(self):
endpoint_filter = {'service_type': 'compute', endpoint_filter = {
'service_type': 'compute',
'interface': 'admin', 'interface': 'admin',
'service_name': 'nova'} 'service_name': 'nova',
}
self._do_service_url_test('http://nova/novapi/admin', endpoint_filter) self._do_service_url_test('http://nova/novapi/admin', endpoint_filter)
def test_service_url_defaults_to_public(self): def test_service_url_defaults_to_public(self):
@ -227,47 +292,62 @@ class V2IdentityPlugin(utils.TestCase):
def test_endpoint_filter_without_service_type_fails(self): def test_endpoint_filter_without_service_type_fails(self):
self.stub_auth(json=self.TEST_RESPONSE_DICT) self.stub_auth(json=self.TEST_RESPONSE_DICT)
a = v2.Password(self.TEST_URL, username=self.TEST_USER, a = v2.Password(
password=self.TEST_PASS) self.TEST_URL, username=self.TEST_USER, password=self.TEST_PASS
)
s = session.Session(auth=a) s = session.Session(auth=a)
self.assertRaises(exceptions.EndpointNotFound, s.get, '/path', self.assertRaises(
endpoint_filter={'interface': 'admin'}) exceptions.EndpointNotFound,
s.get,
'/path',
endpoint_filter={'interface': 'admin'},
)
def test_full_url_overrides_endpoint_filter(self): def test_full_url_overrides_endpoint_filter(self):
self.stub_auth(json=self.TEST_RESPONSE_DICT) self.stub_auth(json=self.TEST_RESPONSE_DICT)
self.stub_url('GET', [], self.stub_url(
'GET',
[],
base_url='http://testurl/', base_url='http://testurl/',
text='SUCCESS', status_code=200) text='SUCCESS',
status_code=200,
)
a = v2.Password(self.TEST_URL, username=self.TEST_USER, a = v2.Password(
password=self.TEST_PASS) self.TEST_URL, username=self.TEST_USER, password=self.TEST_PASS
)
s = session.Session(auth=a) s = session.Session(auth=a)
resp = s.get('http://testurl/', resp = s.get(
endpoint_filter={'service_type': 'compute'}) 'http://testurl/', endpoint_filter={'service_type': 'compute'}
)
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.text, 'SUCCESS') self.assertEqual(resp.text, 'SUCCESS')
def test_invalid_auth_response_dict(self): def test_invalid_auth_response_dict(self):
self.stub_auth(json={'hello': 'world'}) self.stub_auth(json={'hello': 'world'})
a = v2.Password(self.TEST_URL, username=self.TEST_USER, a = v2.Password(
password=self.TEST_PASS) self.TEST_URL, username=self.TEST_USER, password=self.TEST_PASS
)
s = session.Session(auth=a) s = session.Session(auth=a)
self.assertRaises(exceptions.InvalidResponse, s.get, 'http://any', self.assertRaises(
authenticated=True) exceptions.InvalidResponse, s.get, 'http://any', authenticated=True
)
def test_invalid_auth_response_type(self): def test_invalid_auth_response_type(self):
self.stub_url('POST', ['tokens'], text='testdata') self.stub_url('POST', ['tokens'], text='testdata')
a = v2.Password(self.TEST_URL, username=self.TEST_USER, a = v2.Password(
password=self.TEST_PASS) self.TEST_URL, username=self.TEST_USER, password=self.TEST_PASS
)
s = session.Session(auth=a) s = session.Session(auth=a)
self.assertRaises(exceptions.InvalidResponse, s.get, 'http://any', self.assertRaises(
authenticated=True) exceptions.InvalidResponse, s.get, 'http://any', authenticated=True
)
def test_invalidate_response(self): def test_invalidate_response(self):
resp_data1 = copy.deepcopy(self.TEST_RESPONSE_DICT) resp_data1 = copy.deepcopy(self.TEST_RESPONSE_DICT)
@ -279,8 +359,9 @@ class V2IdentityPlugin(utils.TestCase):
auth_responses = [{'json': resp_data1}, {'json': resp_data2}] auth_responses = [{'json': resp_data1}, {'json': resp_data2}]
self.stub_auth(response_list=auth_responses) self.stub_auth(response_list=auth_responses)
a = v2.Password(self.TEST_URL, username=self.TEST_USER, a = v2.Password(
password=self.TEST_PASS) self.TEST_URL, username=self.TEST_USER, password=self.TEST_PASS
)
s = session.Session(auth=a) s = session.Session(auth=a)
self.assertEqual('token1', s.get_token()) self.assertEqual('token1', s.get_token())
@ -294,41 +375,50 @@ class V2IdentityPlugin(utils.TestCase):
self.stub_auth(json=self.TEST_RESPONSE_DICT) self.stub_auth(json=self.TEST_RESPONSE_DICT)
password = uuid.uuid4().hex password = uuid.uuid4().hex
a = v2.Password(self.TEST_URL, username=self.TEST_USER, a = v2.Password(
password=password) self.TEST_URL, username=self.TEST_USER, password=password
)
s = session.Session(auth=a) s = session.Session(auth=a)
self.assertEqual(self.TEST_TOKEN, s.get_token()) self.assertEqual(self.TEST_TOKEN, s.get_token())
self.assertEqual({'X-Auth-Token': self.TEST_TOKEN}, self.assertEqual(
s.get_auth_headers()) {'X-Auth-Token': self.TEST_TOKEN}, s.get_auth_headers()
)
self.assertNotIn(password, self.logger.output) self.assertNotIn(password, self.logger.output)
def test_password_with_no_user_id_or_name(self): def test_password_with_no_user_id_or_name(self):
self.assertRaises(TypeError, self.assertRaises(
v2.Password, self.TEST_URL, password=self.TEST_PASS) TypeError, v2.Password, self.TEST_URL, password=self.TEST_PASS
)
def test_password_cache_id(self): def test_password_cache_id(self):
self.stub_auth(json=self.TEST_RESPONSE_DICT) self.stub_auth(json=self.TEST_RESPONSE_DICT)
trust_id = uuid.uuid4().hex trust_id = uuid.uuid4().hex
a = v2.Password(self.TEST_URL, a = v2.Password(
self.TEST_URL,
username=self.TEST_USER, username=self.TEST_USER,
password=self.TEST_PASS, password=self.TEST_PASS,
trust_id=trust_id) trust_id=trust_id,
)
b = v2.Password(self.TEST_URL, b = v2.Password(
self.TEST_URL,
username=self.TEST_USER, username=self.TEST_USER,
password=self.TEST_PASS, password=self.TEST_PASS,
trust_id=trust_id) trust_id=trust_id,
)
a_id = a.get_cache_id() a_id = a.get_cache_id()
b_id = b.get_cache_id() b_id = b.get_cache_id()
self.assertEqual(a_id, b_id) self.assertEqual(a_id, b_id)
c = v2.Password(self.TEST_URL, c = v2.Password(
self.TEST_URL,
username=self.TEST_USER, username=self.TEST_USER,
password=self.TEST_PASS, password=self.TEST_PASS,
tenant_id=trust_id) # same value different param tenant_id=trust_id,
) # same value different param
c_id = c.get_cache_id() c_id = c.get_cache_id()
@ -350,18 +440,21 @@ class V2IdentityPlugin(utils.TestCase):
auth_ref = access.create(body=token) auth_ref = access.create(body=token)
a = v2.Password(self.TEST_URL, a = v2.Password(
self.TEST_URL,
username=self.TEST_USER, username=self.TEST_USER,
password=self.TEST_PASS, password=self.TEST_PASS,
tenant_id=uuid.uuid4().hex) tenant_id=uuid.uuid4().hex,
)
initial_cache_id = a.get_cache_id() initial_cache_id = a.get_cache_id()
state = a.get_auth_state() state = a.get_auth_state()
self.assertIsNone(state) self.assertIsNone(state)
state = json.dumps({'auth_token': auth_ref.auth_token, state = json.dumps(
'body': auth_ref._data}) {'auth_token': auth_ref.auth_token, 'body': auth_ref._data}
)
a.set_auth_state(state) a.set_auth_state(state)
self.assertEqual(token.token_id, a.auth_ref.auth_token) self.assertEqual(token.token_id, a.auth_ref.auth_token)

File diff suppressed because it is too large Load Diff

View File

@ -24,7 +24,6 @@ from keystoneauth1.tests.unit import utils
class TesterFederationPlugin(v3.FederationBaseAuth): class TesterFederationPlugin(v3.FederationBaseAuth):
def get_unscoped_auth_ref(self, sess, **kwargs): def get_unscoped_auth_ref(self, sess, **kwargs):
# This would go and talk to an idp or something # This would go and talk to an idp or something
resp = sess.post(self.federated_token_url, authenticated=False) resp = sess.post(self.federated_token_url, authenticated=False)
@ -32,11 +31,10 @@ class TesterFederationPlugin(v3.FederationBaseAuth):
class V3FederatedPlugin(utils.TestCase): class V3FederatedPlugin(utils.TestCase):
AUTH_URL = 'http://keystone/v3' AUTH_URL = 'http://keystone/v3'
def setUp(self): def setUp(self):
super(V3FederatedPlugin, self).setUp() super().setUp()
self.unscoped_token = fixture.V3Token() self.unscoped_token = fixture.V3Token()
self.unscoped_token_id = uuid.uuid4().hex self.unscoped_token_id = uuid.uuid4().hex
@ -46,26 +44,30 @@ class V3FederatedPlugin(utils.TestCase):
self.scoped_token_id = uuid.uuid4().hex self.scoped_token_id = uuid.uuid4().hex
s = self.scoped_token.add_service('compute', name='nova') s = self.scoped_token.add_service('compute', name='nova')
s.add_standard_endpoints(public='http://nova/public', s.add_standard_endpoints(
public='http://nova/public',
admin='http://nova/admin', admin='http://nova/admin',
internal='http://nova/internal') internal='http://nova/internal',
)
self.idp = uuid.uuid4().hex self.idp = uuid.uuid4().hex
self.protocol = uuid.uuid4().hex self.protocol = uuid.uuid4().hex
self.token_url = ('%s/OS-FEDERATION/identity_providers/%s/protocols/%s' self.token_url = (
'/auth' % (self.AUTH_URL, self.idp, self.protocol)) f'{self.AUTH_URL}/OS-FEDERATION/identity_providers/{self.idp}/protocols/{self.protocol}'
'/auth'
)
headers = {'X-Subject-Token': self.unscoped_token_id} headers = {'X-Subject-Token': self.unscoped_token_id}
self.unscoped_mock = self.requests_mock.post(self.token_url, self.unscoped_mock = self.requests_mock.post(
json=self.unscoped_token, self.token_url, json=self.unscoped_token, headers=headers
headers=headers) )
headers = {'X-Subject-Token': self.scoped_token_id} headers = {'X-Subject-Token': self.scoped_token_id}
auth_url = self.AUTH_URL + '/auth/tokens' auth_url = self.AUTH_URL + '/auth/tokens'
self.scoped_mock = self.requests_mock.post(auth_url, self.scoped_mock = self.requests_mock.post(
json=self.scoped_token, auth_url, json=self.scoped_token, headers=headers
headers=headers) )
def get_plugin(self, **kwargs): def get_plugin(self, **kwargs):
kwargs.setdefault('auth_url', self.AUTH_URL) kwargs.setdefault('auth_url', self.AUTH_URL)
@ -98,9 +100,8 @@ class V3FederatedPlugin(utils.TestCase):
class K2KAuthPluginTest(utils.TestCase): class K2KAuthPluginTest(utils.TestCase):
TEST_ROOT_URL = 'http://127.0.0.1:5000/' TEST_ROOT_URL = 'http://127.0.0.1:5000/'
TEST_URL = '%s%s' % (TEST_ROOT_URL, 'v3') TEST_URL = '{}{}'.format(TEST_ROOT_URL, 'v3')
TEST_PASS = 'password' TEST_PASS = 'password'
REQUEST_ECP_URL = TEST_URL + '/auth/OS-FEDERATION/saml2/ecp' REQUEST_ECP_URL = TEST_URL + '/auth/OS-FEDERATION/saml2/ecp'
@ -108,39 +109,45 @@ class K2KAuthPluginTest(utils.TestCase):
SP_ROOT_URL = 'https://sp1.com/v3' SP_ROOT_URL = 'https://sp1.com/v3'
SP_ID = 'sp1' SP_ID = 'sp1'
SP_URL = 'https://sp1.com/Shibboleth.sso/SAML2/ECP' SP_URL = 'https://sp1.com/Shibboleth.sso/SAML2/ECP'
SP_AUTH_URL = (SP_ROOT_URL + SP_AUTH_URL = (
'/OS-FEDERATION/identity_providers' SP_ROOT_URL + '/OS-FEDERATION/identity_providers'
'/testidp/protocols/saml2/auth') '/testidp/protocols/saml2/auth'
)
SERVICE_PROVIDER_DICT = { SERVICE_PROVIDER_DICT = {
'id': SP_ID, 'id': SP_ID,
'auth_url': SP_AUTH_URL, 'auth_url': SP_AUTH_URL,
'sp_url': SP_URL 'sp_url': SP_URL,
} }
def setUp(self): def setUp(self):
super(K2KAuthPluginTest, self).setUp() super().setUp()
self.token_v3 = fixture.V3Token() self.token_v3 = fixture.V3Token()
self.token_v3.add_service_provider( self.token_v3.add_service_provider(
self.SP_ID, self.SP_AUTH_URL, self.SP_URL) self.SP_ID, self.SP_AUTH_URL, self.SP_URL
)
self.session = session.Session() self.session = session.Session()
self.k2kplugin = self.get_plugin() self.k2kplugin = self.get_plugin()
def _get_base_plugin(self): def _get_base_plugin(self):
self.stub_url('POST', ['auth', 'tokens'], self.stub_url(
'POST',
['auth', 'tokens'],
headers={'X-Subject-Token': uuid.uuid4().hex}, headers={'X-Subject-Token': uuid.uuid4().hex},
json=self.token_v3) json=self.token_v3,
return v3.Password(self.TEST_URL, )
username=self.TEST_USER, return v3.Password(
password=self.TEST_PASS) self.TEST_URL, username=self.TEST_USER, password=self.TEST_PASS
)
def _mock_k2k_flow_urls(self, redirect_code=302): def _mock_k2k_flow_urls(self, redirect_code=302):
# List versions available for auth # List versions available for auth
self.requests_mock.get( self.requests_mock.get(
self.TEST_URL, self.TEST_URL,
json={'version': fixture.V3Discovery(self.TEST_URL)}, json={'version': fixture.V3Discovery(self.TEST_URL)},
headers={'Content-Type': 'application/json'}) headers={'Content-Type': 'application/json'},
)
# The IdP should return a ECP wrapped assertion when requested # The IdP should return a ECP wrapped assertion when requested
self.requests_mock.register_uri( self.requests_mock.register_uri(
@ -148,7 +155,8 @@ class K2KAuthPluginTest(utils.TestCase):
self.REQUEST_ECP_URL, self.REQUEST_ECP_URL,
content=bytes(k2k_fixtures.ECP_ENVELOPE, 'latin-1'), content=bytes(k2k_fixtures.ECP_ENVELOPE, 'latin-1'),
headers={'Content-Type': 'application/vnd.paos+xml'}, headers={'Content-Type': 'application/vnd.paos+xml'},
status_code=200) status_code=200,
)
# The SP should respond with a redirect (302 or 303) # The SP should respond with a redirect (302 or 303)
self.requests_mock.register_uri( self.requests_mock.register_uri(
@ -156,14 +164,16 @@ class K2KAuthPluginTest(utils.TestCase):
self.SP_URL, self.SP_URL,
content=bytes(k2k_fixtures.TOKEN_BASED_ECP, 'latin-1'), content=bytes(k2k_fixtures.TOKEN_BASED_ECP, 'latin-1'),
headers={'Content-Type': 'application/vnd.paos+xml'}, headers={'Content-Type': 'application/vnd.paos+xml'},
status_code=redirect_code) status_code=redirect_code,
)
# Should not follow the redirect URL, but use the auth_url attribute # Should not follow the redirect URL, but use the auth_url attribute
self.requests_mock.register_uri( self.requests_mock.register_uri(
'GET', 'GET',
self.SP_AUTH_URL, self.SP_AUTH_URL,
json=k2k_fixtures.UNSCOPED_TOKEN, json=k2k_fixtures.UNSCOPED_TOKEN,
headers={'X-Subject-Token': k2k_fixtures.UNSCOPED_TOKEN_HEADER}) headers={'X-Subject-Token': k2k_fixtures.UNSCOPED_TOKEN_HEADER},
)
def get_plugin(self, **kwargs): def get_plugin(self, **kwargs):
kwargs.setdefault('base_plugin', self._get_base_plugin()) kwargs.setdefault('base_plugin', self._get_base_plugin())
@ -178,84 +188,108 @@ class K2KAuthPluginTest(utils.TestCase):
self.requests_mock.get( self.requests_mock.get(
self.TEST_URL, self.TEST_URL,
json={'version': fixture.V3Discovery(self.TEST_URL)}, json={'version': fixture.V3Discovery(self.TEST_URL)},
headers={'Content-Type': 'application/json'}) headers={'Content-Type': 'application/json'},
)
self.requests_mock.register_uri( self.requests_mock.register_uri(
'POST', self.REQUEST_ECP_URL, 'POST', self.REQUEST_ECP_URL, status_code=401
status_code=401) )
self.assertRaises(exceptions.AuthorizationFailure, self.assertRaises(
exceptions.AuthorizationFailure,
self.k2kplugin._get_ecp_assertion, self.k2kplugin._get_ecp_assertion,
self.session) self.session,
)
def test_get_ecp_assertion_empty_response(self): def test_get_ecp_assertion_empty_response(self):
self.requests_mock.get( self.requests_mock.get(
self.TEST_URL, self.TEST_URL,
json={'version': fixture.V3Discovery(self.TEST_URL)}, json={'version': fixture.V3Discovery(self.TEST_URL)},
headers={'Content-Type': 'application/json'}) headers={'Content-Type': 'application/json'},
)
self.requests_mock.register_uri( self.requests_mock.register_uri(
'POST', self.REQUEST_ECP_URL, 'POST',
self.REQUEST_ECP_URL,
headers={'Content-Type': 'application/vnd.paos+xml'}, headers={'Content-Type': 'application/vnd.paos+xml'},
content=b'', status_code=200) content=b'',
status_code=200,
)
self.assertRaises(exceptions.InvalidResponse, self.assertRaises(
exceptions.InvalidResponse,
self.k2kplugin._get_ecp_assertion, self.k2kplugin._get_ecp_assertion,
self.session) self.session,
)
def test_get_ecp_assertion_wrong_headers(self): def test_get_ecp_assertion_wrong_headers(self):
self.requests_mock.get( self.requests_mock.get(
self.TEST_URL, self.TEST_URL,
json={'version': fixture.V3Discovery(self.TEST_URL)}, json={'version': fixture.V3Discovery(self.TEST_URL)},
headers={'Content-Type': 'application/json'}) headers={'Content-Type': 'application/json'},
)
self.requests_mock.register_uri( self.requests_mock.register_uri(
'POST', self.REQUEST_ECP_URL, 'POST',
self.REQUEST_ECP_URL,
headers={'Content-Type': uuid.uuid4().hex}, headers={'Content-Type': uuid.uuid4().hex},
content=b'', status_code=200) content=b'',
status_code=200,
)
self.assertRaises(exceptions.InvalidResponse, self.assertRaises(
exceptions.InvalidResponse,
self.k2kplugin._get_ecp_assertion, self.k2kplugin._get_ecp_assertion,
self.session) self.session,
)
def test_send_ecp_authn_response(self): def test_send_ecp_authn_response(self):
self._mock_k2k_flow_urls() self._mock_k2k_flow_urls()
# Perform the request # Perform the request
response = self.k2kplugin._send_service_provider_ecp_authn_response( response = self.k2kplugin._send_service_provider_ecp_authn_response(
self.session, self.SP_URL, self.SP_AUTH_URL) self.session, self.SP_URL, self.SP_AUTH_URL
)
# Check the response # Check the response
self.assertEqual(k2k_fixtures.UNSCOPED_TOKEN_HEADER, self.assertEqual(
response.headers['X-Subject-Token']) k2k_fixtures.UNSCOPED_TOKEN_HEADER,
response.headers['X-Subject-Token'],
)
def test_send_ecp_authn_response_303_redirect(self): def test_send_ecp_authn_response_303_redirect(self):
self._mock_k2k_flow_urls(redirect_code=303) self._mock_k2k_flow_urls(redirect_code=303)
# Perform the request # Perform the request
response = self.k2kplugin._send_service_provider_ecp_authn_response( response = self.k2kplugin._send_service_provider_ecp_authn_response(
self.session, self.SP_URL, self.SP_AUTH_URL) self.session, self.SP_URL, self.SP_AUTH_URL
)
# Check the response # Check the response
self.assertEqual(k2k_fixtures.UNSCOPED_TOKEN_HEADER, self.assertEqual(
response.headers['X-Subject-Token']) k2k_fixtures.UNSCOPED_TOKEN_HEADER,
response.headers['X-Subject-Token'],
)
def test_end_to_end_workflow(self): def test_end_to_end_workflow(self):
self._mock_k2k_flow_urls() self._mock_k2k_flow_urls()
auth_ref = self.k2kplugin.get_auth_ref(self.session) auth_ref = self.k2kplugin.get_auth_ref(self.session)
self.assertEqual(k2k_fixtures.UNSCOPED_TOKEN_HEADER, self.assertEqual(
auth_ref.auth_token) k2k_fixtures.UNSCOPED_TOKEN_HEADER, auth_ref.auth_token
)
def test_end_to_end_workflow_303_redirect(self): def test_end_to_end_workflow_303_redirect(self):
self._mock_k2k_flow_urls(redirect_code=303) self._mock_k2k_flow_urls(redirect_code=303)
auth_ref = self.k2kplugin.get_auth_ref(self.session) auth_ref = self.k2kplugin.get_auth_ref(self.session)
self.assertEqual(k2k_fixtures.UNSCOPED_TOKEN_HEADER, self.assertEqual(
auth_ref.auth_token) k2k_fixtures.UNSCOPED_TOKEN_HEADER, auth_ref.auth_token
)
def test_end_to_end_with_generic_password(self): def test_end_to_end_with_generic_password(self):
# List versions available for auth # List versions available for auth
self.requests_mock.get( self.requests_mock.get(
self.TEST_ROOT_URL, self.TEST_ROOT_URL,
json=fixture.DiscoveryList(self.TEST_ROOT_URL), json=fixture.DiscoveryList(self.TEST_ROOT_URL),
headers={'Content-Type': 'application/json'}) headers={'Content-Type': 'application/json'},
)
# The IdP should return a ECP wrapped assertion when requested # The IdP should return a ECP wrapped assertion when requested
self.requests_mock.register_uri( self.requests_mock.register_uri(
@ -263,7 +297,8 @@ class K2KAuthPluginTest(utils.TestCase):
self.REQUEST_ECP_URL, self.REQUEST_ECP_URL,
content=bytes(k2k_fixtures.ECP_ENVELOPE, 'latin-1'), content=bytes(k2k_fixtures.ECP_ENVELOPE, 'latin-1'),
headers={'Content-Type': 'application/vnd.paos+xml'}, headers={'Content-Type': 'application/vnd.paos+xml'},
status_code=200) status_code=200,
)
# The SP should respond with a redirect (302 or 303) # The SP should respond with a redirect (302 or 303)
self.requests_mock.register_uri( self.requests_mock.register_uri(
@ -271,24 +306,33 @@ class K2KAuthPluginTest(utils.TestCase):
self.SP_URL, self.SP_URL,
content=bytes(k2k_fixtures.TOKEN_BASED_ECP, 'latin-1'), content=bytes(k2k_fixtures.TOKEN_BASED_ECP, 'latin-1'),
headers={'Content-Type': 'application/vnd.paos+xml'}, headers={'Content-Type': 'application/vnd.paos+xml'},
status_code=302) status_code=302,
)
# Should not follow the redirect URL, but use the auth_url attribute # Should not follow the redirect URL, but use the auth_url attribute
self.requests_mock.register_uri( self.requests_mock.register_uri(
'GET', 'GET',
self.SP_AUTH_URL, self.SP_AUTH_URL,
json=k2k_fixtures.UNSCOPED_TOKEN, json=k2k_fixtures.UNSCOPED_TOKEN,
headers={'X-Subject-Token': k2k_fixtures.UNSCOPED_TOKEN_HEADER}) headers={'X-Subject-Token': k2k_fixtures.UNSCOPED_TOKEN_HEADER},
)
self.stub_url('POST', ['auth', 'tokens'], self.stub_url(
'POST',
['auth', 'tokens'],
headers={'X-Subject-Token': uuid.uuid4().hex}, headers={'X-Subject-Token': uuid.uuid4().hex},
json=self.token_v3) json=self.token_v3,
)
plugin = identity.Password(self.TEST_ROOT_URL, plugin = identity.Password(
self.TEST_ROOT_URL,
username=self.TEST_USER, username=self.TEST_USER,
password=self.TEST_PASS, password=self.TEST_PASS,
user_domain_id='default') user_domain_id='default',
)
k2kplugin = self.get_plugin(base_plugin=plugin) k2kplugin = self.get_plugin(base_plugin=plugin)
self.assertEqual(k2k_fixtures.UNSCOPED_TOKEN_HEADER, self.assertEqual(
k2kplugin.get_token(self.session)) k2k_fixtures.UNSCOPED_TOKEN_HEADER,
k2kplugin.get_token(self.session),
)

Some files were not shown because too many files have changed in this diff Show More