Apply ruff, ruff-format

We disable the E203 (whitespace before ':') and E501 (line too long)
linter rules since these conflict with ruff-format. We also rework a
statement in 'keystoneauth1/tests/unit/test_session.py' since it's
triggering a bug in flake8 [1] that is currently (at time of authoring)
unresolved.

[1] https://github.com/PyCQA/flake8/issues/1948

Signed-off-by: Stephen Finucane <stephenfin@redhat.com>
Change-Id: Ief5c1c57d1e72db9fc881063d4c7e1030e76da43
This commit is contained in:
Stephen Finucane 2024-08-02 11:33:43 +01:00
parent c05e237a8a
commit 127d7be2b7
137 changed files with 8953 additions and 6346 deletions

View File

@ -35,7 +35,7 @@ class ListAuthPluginsDirective(rst.Directive):
has_content = True
def report_load_failure(mgr, ep, err):
LOG.warning(u'Failed to load %s: %s' % (ep.module_name, err))
LOG.warning(f'Failed to load {ep.module_name}: {err}')
def display_plugin(self, ext):
overline_style = self.options.get('overline-style', '')
@ -60,7 +60,7 @@ class ListAuthPluginsDirective(rst.Directive):
yield "\n"
for opt in ext.obj.get_options():
yield ":%s: %s" % (opt.name, opt.help)
yield f":{opt.name}: {opt.help}"
yield "\n"
@ -68,7 +68,7 @@ class ListAuthPluginsDirective(rst.Directive):
mgr = extension.ExtensionManager(
'keystoneauth1.plugin',
on_load_failure_callback=self.report_load_failure,
invoke_on_load=True,
invoke_on_load=True,
)
result = ViewList()

View File

@ -81,7 +81,7 @@ latex_documents = [
'Openstack Developers',
'manual',
True,
),
)
]
# Disable usage of xindy https://bugzilla.redhat.com/show_bug.cgi?id=1643664

View File

@ -15,7 +15,7 @@ import threading
import time
class FairSemaphore(object):
class FairSemaphore:
"""Semaphore class that notifies in order of request.
We cannot use a normal Semaphore because it doesn't give any ordering,

View File

@ -78,7 +78,7 @@ def before_utcnow(**timedelta_kwargs):
# Detect if running on the Windows Subsystem for Linux
try:
with open('/proc/version', 'r') as f:
with open('/proc/version') as f:
is_windows_linux_subsystem = 'microsoft' in f.read().lower()
except IOError:
except OSError:
is_windows_linux_subsystem = False

View File

@ -13,7 +13,9 @@
from keystoneauth1.access.access import * # noqa
__all__ = ('AccessInfo', # noqa: F405
'AccessInfoV2', # noqa: F405
'AccessInfoV3', # noqa: F405
'create') # noqa: F405
__all__ = ( # noqa: F405
'AccessInfo',
'AccessInfoV2',
'AccessInfoV3',
'create',
)

View File

@ -25,10 +25,7 @@ from keystoneauth1.access import service_providers
STALE_TOKEN_DURATION = 30
__all__ = ('AccessInfo',
'AccessInfoV2',
'AccessInfoV3',
'create')
__all__ = ('AccessInfo', 'AccessInfoV2', 'AccessInfoV3', 'create')
def create(resp=None, body=None, auth_token=None):
@ -47,7 +44,6 @@ def create(resp=None, body=None, auth_token=None):
def _missingproperty(f):
@functools.wraps(f)
def inner(self):
try:
@ -58,7 +54,7 @@ def _missingproperty(f):
return property(inner)
class AccessInfo(object):
class AccessInfo:
"""Encapsulates a raw authentication token from keystone.
Provides helper methods for extracting useful values from that token.
@ -77,7 +73,8 @@ class AccessInfo(object):
def service_catalog(self):
if not self._service_catalog:
self._service_catalog = self._service_catalog_class.from_token(
self._data)
self._data
)
return self._service_catalog
@ -422,7 +419,7 @@ class AccessInfoV2(AccessInfo):
@_missingproperty
def auth_token(self):
set_token = super(AccessInfoV2, self).auth_token
set_token = super().auth_token
return set_token or self._data['access']['token']['id']
@property
@ -775,7 +772,8 @@ class AccessInfoV3(AccessInfo):
def service_providers(self):
if not self._service_providers:
self._service_providers = (
service_providers.ServiceProviders.from_token(self._data))
service_providers.ServiceProviders.from_token(self._data)
)
return self._service_providers

View File

@ -114,7 +114,8 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
service.setdefault('id', None)
service['endpoints'] = self._normalize_endpoints(
service.get('endpoints', []))
service.get('endpoints', [])
)
for endpoint in service['endpoints']:
endpoint['region_name'] = self._get_endpoint_region(endpoint)
@ -129,9 +130,15 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
interface = [interface]
return [self.normalize_interface(i) for i in interface]
def get_endpoints_data(self, service_type=None, interface=None,
region_name=None, service_name=None,
service_id=None, endpoint_id=None):
def get_endpoints_data(
self,
service_type=None,
interface=None,
region_name=None,
service_name=None,
service_id=None,
endpoint_id=None,
):
"""Fetch and filter endpoint data for the specified service(s).
Returns endpoints for the specified service (or all) containing
@ -164,17 +171,19 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
matching_endpoints = {}
for service in self.normalize_catalog():
if service_type and not discover._SERVICE_TYPES.is_match(
service_type, service['type']):
service_type, service['type']
):
continue
if (service_name and service['name'] and
service_name != service['name']):
if (
service_name
and service['name']
and service_name != service['name']
):
continue
if (service_id and service['id'] and
service_id != service['id']):
if service_id and service['id'] and service_id != service['id']:
continue
matching_endpoints.setdefault(service['type'], [])
@ -198,7 +207,9 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
interface=endpoint['interface'],
region_name=endpoint['region_name'],
endpoint_id=endpoint['id'],
raw_endpoint=endpoint['raw_endpoint']))
raw_endpoint=endpoint['raw_endpoint'],
)
)
if not interfaces:
return self._endpoints_by_type(service_type, matching_endpoints)
@ -212,8 +223,9 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
for endpoint in endpoints:
matches_by_interface.setdefault(endpoint.interface, [])
matches_by_interface[endpoint.interface].append(endpoint)
best_interface = [i for i in interfaces
if i in matches_by_interface.keys()][0]
best_interface = [
i for i in interfaces if i in matches_by_interface.keys()
][0]
ret[matched_service_type] = matches_by_interface[best_interface]
return self._endpoints_by_type(service_type, ret)
@ -279,9 +291,15 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
# part if we do. Raise this so that we can panic in unit tests.
raise ValueError("Programming error choosing an endpoint.")
def get_endpoints(self, service_type=None, interface=None,
region_name=None, service_name=None,
service_id=None, endpoint_id=None):
def get_endpoints(
self,
service_type=None,
interface=None,
region_name=None,
service_name=None,
service_id=None,
endpoint_id=None,
):
"""Fetch and filter endpoint data for the specified service(s).
Returns endpoints for the specified service (or all) containing
@ -294,17 +312,27 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
Returns a dict keyed by service_type with a list of endpoint dicts
"""
endpoints_data = self.get_endpoints_data(
service_type=service_type, interface=interface,
region_name=region_name, service_name=service_name,
service_id=service_id, endpoint_id=endpoint_id)
service_type=service_type,
interface=interface,
region_name=region_name,
service_name=service_name,
service_id=service_id,
endpoint_id=endpoint_id,
)
endpoints = {}
for service_type, data in endpoints_data.items():
endpoints[service_type] = self._denormalize_endpoints(data)
return endpoints
def get_endpoint_data_list(self, service_type=None, interface='public',
region_name=None, service_name=None,
service_id=None, endpoint_id=None):
def get_endpoint_data_list(
self,
service_type=None,
interface='public',
region_name=None,
service_name=None,
service_id=None,
endpoint_id=None,
):
"""Fetch a flat list of matching EndpointData objects.
Fetch the endpoints from the service catalog for a particular
@ -327,17 +355,25 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
:returns: a list of matching EndpointData objects
:rtype: list(`keystoneauth1.discover.EndpointData`)
"""
endpoints = self.get_endpoints_data(service_type=service_type,
interface=interface,
region_name=region_name,
service_name=service_name,
service_id=service_id,
endpoint_id=endpoint_id)
endpoints = self.get_endpoints_data(
service_type=service_type,
interface=interface,
region_name=region_name,
service_name=service_name,
service_id=service_id,
endpoint_id=endpoint_id,
)
return [endpoint for data in endpoints.values() for endpoint in data]
def get_urls(self, service_type=None, interface='public',
region_name=None, service_name=None,
service_id=None, endpoint_id=None):
def get_urls(
self,
service_type=None,
interface='public',
region_name=None,
service_name=None,
service_id=None,
endpoint_id=None,
):
"""Fetch endpoint urls from the service catalog.
Fetch the urls of endpoints from the service catalog for a particular
@ -359,17 +395,25 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
:returns: tuple of urls
"""
endpoints = self.get_endpoint_data_list(service_type=service_type,
interface=interface,
region_name=region_name,
service_name=service_name,
service_id=service_id,
endpoint_id=endpoint_id)
endpoints = self.get_endpoint_data_list(
service_type=service_type,
interface=interface,
region_name=region_name,
service_name=service_name,
service_id=service_id,
endpoint_id=endpoint_id,
)
return tuple([endpoint.url for endpoint in endpoints])
def url_for(self, service_type=None, interface='public',
region_name=None, service_name=None,
service_id=None, endpoint_id=None):
def url_for(
self,
service_type=None,
interface='public',
region_name=None,
service_name=None,
service_id=None,
endpoint_id=None,
):
"""Fetch an endpoint from the service catalog.
Fetch the specified endpoint from the service catalog for
@ -389,16 +433,24 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
:param string service_id: The identifier of a service.
:param string endpoint_id: The identifier of an endpoint.
"""
return self.endpoint_data_for(service_type=service_type,
interface=interface,
region_name=region_name,
service_name=service_name,
service_id=service_id,
endpoint_id=endpoint_id).url
return self.endpoint_data_for(
service_type=service_type,
interface=interface,
region_name=region_name,
service_name=service_name,
service_id=service_id,
endpoint_id=endpoint_id,
).url
def endpoint_data_for(self, service_type=None, interface='public',
region_name=None, service_name=None,
service_id=None, endpoint_id=None):
def endpoint_data_for(
self,
service_type=None,
interface='public',
region_name=None,
service_name=None,
service_id=None,
endpoint_id=None,
):
"""Fetch endpoint data from the service catalog.
Fetch the specified endpoint data from the service catalog for
@ -427,34 +479,30 @@ class ServiceCatalog(metaclass=abc.ABCMeta):
region_name=region_name,
service_name=service_name,
service_id=service_id,
endpoint_id=endpoint_id)
endpoint_id=endpoint_id,
)
if endpoint_data_list:
return endpoint_data_list[0]
if service_name and region_name:
msg = ('%(interface)s endpoint for %(service_type)s service '
'named %(service_name)s in %(region_name)s region not '
'found' %
{'interface': interface,
'service_type': service_type, 'service_name': service_name,
'region_name': region_name})
msg = (
f'{interface} endpoint for {service_type} service '
f'named {service_name} in {region_name} region not '
'found'
)
elif service_name:
msg = ('%(interface)s endpoint for %(service_type)s service '
'named %(service_name)s not found' %
{'interface': interface,
'service_type': service_type,
'service_name': service_name})
msg = (
f'{interface} endpoint for {service_type} service '
f'named {service_name} not found'
)
elif region_name:
msg = ('%(interface)s endpoint for %(service_type)s service '
'in %(region_name)s region not found' %
{'interface': interface,
'service_type': service_type, 'region_name': region_name})
msg = (
f'{interface} endpoint for {service_type} service '
f'in {region_name} region not found'
)
else:
msg = ('%(interface)s endpoint for %(service_type)s service '
'not found' %
{'interface': interface,
'service_type': service_type})
msg = f'{interface} endpoint for {service_type} service not found'
raise exceptions.EndpointNotFound(msg)
@ -498,8 +546,9 @@ class ServiceCatalogV2(ServiceCatalog):
for endpoint in endpoints:
raw_endpoint = endpoint.copy()
interface_urls = {}
interface_keys = [key for key in endpoint.keys()
if key.endswith('URL')]
interface_keys = [
key for key in endpoint.keys() if key.endswith('URL')
]
for key in interface_keys:
interface = self.normalize_interface(key)
interface_urls[interface] = endpoint.pop(key)
@ -522,8 +571,7 @@ class ServiceCatalogV2(ServiceCatalog):
:returns: List of endpoint description dicts in original catalog format
"""
raw_endpoints = super(ServiceCatalogV2, self)._denormalize_endpoints(
endpoints)
raw_endpoints = super()._denormalize_endpoints(endpoints)
# The same raw endpoint content will be in the list once for each
# v2 endpoint_type entry. We only need one of them in the resulting
# list. So keep a list of the string versions.

View File

@ -13,22 +13,24 @@
from keystoneauth1 import exceptions
class ServiceProviders(object):
class ServiceProviders:
"""Helper methods for dealing with Service Providers."""
@classmethod
def from_token(cls, token):
if 'token' not in token:
raise ValueError('Token format does not support service'
'providers.')
raise ValueError(
'Token format does not support service providers.'
)
return cls(token['token'].get('service_providers', []))
def __init__(self, service_providers):
def normalize(service_providers_list):
return dict((sp['id'], sp) for sp in service_providers_list
if 'id' in sp)
return {
sp['id']: sp for sp in service_providers_list if 'id' in sp
}
self._service_providers = normalize(service_providers)
def _get_service_provider(self, sp_id):

View File

@ -19,7 +19,7 @@ from keystoneauth1 import _fair_semaphore
from keystoneauth1 import session
class Adapter(object):
class Adapter:
"""An instance of a session with local variables.
A session is a global object that is shared around amongst many clients. It
@ -118,23 +118,41 @@ class Adapter(object):
client_name = None
client_version = None
def __init__(self, session, service_type=None, service_name=None,
interface=None, region_name=None, endpoint_override=None,
version=None, auth=None, user_agent=None,
connect_retries=None, logger=None, allow=None,
additional_headers=None, client_name=None,
client_version=None, allow_version_hack=None,
global_request_id=None,
min_version=None, max_version=None,
default_microversion=None, status_code_retries=None,
retriable_status_codes=None, raise_exc=None,
rate_limit=None, concurrency=None,
connect_retry_delay=None, status_code_retry_delay=None,
):
def __init__(
self,
session,
service_type=None,
service_name=None,
interface=None,
region_name=None,
endpoint_override=None,
version=None,
auth=None,
user_agent=None,
connect_retries=None,
logger=None,
allow=None,
additional_headers=None,
client_name=None,
client_version=None,
allow_version_hack=None,
global_request_id=None,
min_version=None,
max_version=None,
default_microversion=None,
status_code_retries=None,
retriable_status_codes=None,
raise_exc=None,
rate_limit=None,
concurrency=None,
connect_retry_delay=None,
status_code_retry_delay=None,
):
if version and (min_version or max_version):
raise TypeError(
"version is mutually exclusive with min_version and"
" max_version")
" max_version"
)
# NOTE(jamielennox): when adding new parameters to adapter please also
# add them to the adapter call in httpclient.HTTPClient.__init__ as
# well as to load_adapter_from_argparse below if the argument is
@ -177,7 +195,8 @@ class Adapter(object):
rate_delay = 1.0 / rate_limit
self._rate_semaphore = _fair_semaphore.FairSemaphore(
concurrency, rate_delay)
concurrency, rate_delay
)
def _set_endpoint_filter_kwargs(self, kwargs):
if self.service_type:
@ -205,7 +224,8 @@ class Adapter(object):
# case insensitive.
if kwargs.get('headers'):
kwargs['headers'] = requests.structures.CaseInsensitiveDict(
kwargs['headers'])
kwargs['headers']
)
else:
kwargs['headers'] = requests.structures.CaseInsensitiveDict()
if self.endpoint_override:
@ -215,9 +235,13 @@ class Adapter(object):
kwargs.setdefault('auth', self.auth)
if self.user_agent:
kwargs.setdefault('user_agent', self.user_agent)
for arg in ('connect_retries', 'status_code_retries',
'connect_retry_delay', 'status_code_retry_delay',
'retriable_status_codes'):
for arg in (
'connect_retries',
'status_code_retries',
'connect_retry_delay',
'status_code_retry_delay',
'retriable_status_codes',
):
if getattr(self, arg) is not None:
kwargs.setdefault(arg, getattr(self, arg))
if self.logger:
@ -239,15 +263,18 @@ class Adapter(object):
kwargs.setdefault('rate_semaphore', self._rate_semaphore)
else:
warnings.warn('Using keystoneclient sessions has been deprecated. '
'Please update your software to use keystoneauth1.')
warnings.warn(
'Using keystoneclient sessions has been deprecated. '
'Please update your software to use keystoneauth1.'
)
for k, v in self.additional_headers.items():
kwargs.setdefault('headers', {}).setdefault(k, v)
if self.global_request_id is not None:
kwargs.setdefault('headers', {}).setdefault(
"X-OpenStack-Request-ID", self.global_request_id)
"X-OpenStack-Request-ID", self.global_request_id
)
if self.raise_exc is not None:
kwargs.setdefault('raise_exc', self.raise_exc)
@ -309,10 +336,7 @@ class Adapter(object):
return self.session.get_endpoint_data(auth or self.auth, **kwargs)
def get_all_version_data(
self,
interface='public',
region_name=None):
def get_all_version_data(self, interface='public', region_name=None):
"""Get data about all versions of a service.
:param interface:
@ -330,7 +354,8 @@ class Adapter(object):
return self.session.get_all_version_data(
interface=interface,
region_name=region_name,
service_type=self.service_type)
service_type=self.service_type,
)
def get_api_major_version(self, auth=None, **kwargs):
"""Get the major API version as provided by the auth plugin.
@ -419,43 +444,50 @@ class Adapter(object):
adapter_group = parser.add_argument_group(
'Service Options',
'Options controlling the specialization of the API'
' Connection from information found in the catalog')
' Connection from information found in the catalog',
)
adapter_group.add_argument(
'--os-service-type',
metavar='<name>',
default=os.environ.get('OS_SERVICE_TYPE', service_type),
help='Service type to request from the catalog')
help='Service type to request from the catalog',
)
adapter_group.add_argument(
'--os-service-name',
metavar='<name>',
default=os.environ.get('OS_SERVICE_NAME', None),
help='Service name to request from the catalog')
help='Service name to request from the catalog',
)
adapter_group.add_argument(
'--os-interface',
metavar='<name>',
default=os.environ.get('OS_INTERFACE', 'public'),
help='API Interface to use [public, internal, admin]')
help='API Interface to use [public, internal, admin]',
)
adapter_group.add_argument(
'--os-region-name',
metavar='<name>',
default=os.environ.get('OS_REGION_NAME', None),
help='Region of the cloud to use')
help='Region of the cloud to use',
)
adapter_group.add_argument(
'--os-endpoint-override',
metavar='<name>',
default=os.environ.get('OS_ENDPOINT_OVERRIDE', None),
help='Endpoint to use instead of the endpoint in the catalog')
help='Endpoint to use instead of the endpoint in the catalog',
)
adapter_group.add_argument(
'--os-api-version',
metavar='<name>',
default=os.environ.get('OS_API_VERSION', None),
help='Which version of the service API to use')
help='Which version of the service API to use',
)
# TODO(efried): Move this to loading.adapter.Adapter
@classmethod
@ -469,66 +501,62 @@ class Adapter(object):
"""
service_env = service_type.upper().replace('-', '_')
adapter_group = parser.add_argument_group(
'{service_type} Service Options'.format(
service_type=service_type.title()),
'Options controlling the specialization of the {service_type}'
' API Connection from information found in the catalog'.format(
service_type=service_type.title()))
f'{service_type.title()} Service Options',
f'Options controlling the specialization of the {service_type.title()}'
' API Connection from information found in the catalog',
)
adapter_group.add_argument(
'--os-{service_type}-service-type'.format(
service_type=service_type),
f'--os-{service_type}-service-type',
metavar='<name>',
default=os.environ.get(
'OS_{service_type}_SERVICE_TYPE'.format(
service_type=service_env), None),
help=('Service type to request from the catalog for the'
' {service_type} service'.format(
service_type=service_type)))
default=os.environ.get(f'OS_{service_env}_SERVICE_TYPE', None),
help=(
'Service type to request from the catalog for the'
f' {service_type} service'
),
)
adapter_group.add_argument(
'--os-{service_type}-service-name'.format(
service_type=service_type),
f'--os-{service_type}-service-name',
metavar='<name>',
default=os.environ.get(
'OS_{service_type}_SERVICE_NAME'.format(
service_type=service_env), None),
help=('Service name to request from the catalog for the'
' {service_type} service'.format(
service_type=service_type)))
default=os.environ.get(f'OS_{service_env}_SERVICE_NAME', None),
help=(
'Service name to request from the catalog for the'
f' {service_type} service'
),
)
adapter_group.add_argument(
'--os-{service_type}-interface'.format(
service_type=service_type),
f'--os-{service_type}-interface',
metavar='<name>',
default=os.environ.get(
'OS_{service_type}_INTERFACE'.format(
service_type=service_env), None),
help=('API Interface to use for the {service_type} service'
' [public, internal, admin]'.format(
service_type=service_type)))
default=os.environ.get(f'OS_{service_env}_INTERFACE', None),
help=(
f'API Interface to use for the {service_type} service'
' [public, internal, admin]'
),
)
adapter_group.add_argument(
'--os-{service_type}-api-version'.format(
service_type=service_type),
f'--os-{service_type}-api-version',
metavar='<name>',
default=os.environ.get(
'OS_{service_type}_API_VERSION'.format(
service_type=service_env), None),
help=('Which version of the service API to use for'
' the {service_type} service'.format(
service_type=service_type)))
default=os.environ.get(f'OS_{service_env}_API_VERSION', None),
help=(
'Which version of the service API to use for'
f' the {service_type} service'
),
)
adapter_group.add_argument(
'--os-{service_type}-endpoint-override'.format(
service_type=service_type),
f'--os-{service_type}-endpoint-override',
metavar='<name>',
default=os.environ.get(
'OS_{service_type}_ENDPOINT_OVERRIDE'.format(
service_type=service_env), None),
help=('Endpoint to use for the {service_type} service'
' instead of the endpoint in the catalog'.format(
service_type=service_type)))
f'OS_{service_env}_ENDPOINT_OVERRIDE', None
),
help=(
f'Endpoint to use for the {service_type} service'
' instead of the endpoint in the catalog'
),
)
class LegacyJsonAdapter(Adapter):
@ -549,7 +577,7 @@ class LegacyJsonAdapter(Adapter):
except KeyError:
pass
resp = super(LegacyJsonAdapter, self).request(*args, **kwargs)
resp = super().request(*args, **kwargs)
try:
body = resp.json()

View File

@ -62,23 +62,18 @@ def get_version_data(session, url, authenticated=None, version_header=None):
The return is a list of dicts of the form::
[{
'status': 'STABLE',
'id': 'v2.3',
'links': [
{
'href': 'http://network.example.com/v2.3',
'rel': 'self',
},
{
'href': 'http://network.example.com/',
'rel': 'collection',
},
],
'min_version': '2.0',
'max_version': '2.7',
},
...,
[
{
'status': 'STABLE',
'id': 'v2.3',
'links': [
{'href': 'http://network.example.com/v2.3', 'rel': 'self'},
{'href': 'http://network.example.com/', 'rel': 'collection'},
],
'min_version': '2.0',
'max_version': '2.7',
},
...,
]
Note:
@ -117,7 +112,8 @@ def get_version_data(session, url, authenticated=None, version_header=None):
# it's the only thing returning a [] here - and that's ok.
if isinstance(body_resp, list):
raise exceptions.DiscoveryFailure(
'Invalid Response - List returned instead of dict')
'Invalid Response - List returned instead of dict'
)
# In the event of querying a root URL we will get back a list of
# available versions.
@ -163,8 +159,9 @@ def get_version_data(session, url, authenticated=None, version_header=None):
return [body_resp]
err_text = resp.text[:50] + '...' if len(resp.text) > 50 else resp.text
raise exceptions.DiscoveryFailure('Invalid Response - Bad version data '
'returned: %s' % err_text)
raise exceptions.DiscoveryFailure(
'Invalid Response - Bad version data ' f'returned: {err_text}'
)
def normalize_version_number(version):
@ -253,11 +250,12 @@ def normalize_version_number(version):
except (TypeError, ValueError):
pass
raise TypeError('Invalid version specified: %s' % version)
raise TypeError(f'Invalid version specified: {version}')
def _normalize_version_args(
version, min_version, max_version, service_type=None):
version, min_version, max_version, service_type=None
):
# The sins of our fathers become the blood on our hands.
# If a user requests an old-style service type such as volumev2, then they
# are inherently requesting the major API version 2. It's not a good
@ -270,17 +268,20 @@ def _normalize_version_args(
# as this, but in order to move forward without breaking people, we have
# to just cry in the corner while striking ourselves with thorned branches.
# That said, for sure only do this hack for officially known service_types.
if (service_type and
_SERVICE_TYPES.is_known(service_type) and
service_type[-1].isdigit() and
service_type[-2] == 'v'):
if (
service_type
and _SERVICE_TYPES.is_known(service_type)
and service_type[-1].isdigit()
and service_type[-2] == 'v'
):
implied_version = normalize_version_number(service_type[-1])
else:
implied_version = None
if version and (min_version or max_version):
raise ValueError(
"version is mutually exclusive with min_version and max_version")
"version is mutually exclusive with min_version and max_version"
)
if version:
# Explode this into min_version and max_version
@ -291,15 +292,16 @@ def _normalize_version_args(
raise exceptions.ImpliedVersionMismatch(
service_type=service_type,
implied=implied_version,
given=version_to_string(version))
given=version_to_string(version),
)
return min_version, max_version
if min_version == 'latest':
if max_version not in (None, 'latest'):
raise ValueError(
"min_version is 'latest' and max_version is {max_version}"
" but is only allowed to be 'latest' or None".format(
max_version=max_version))
f"min_version is 'latest' and max_version is {max_version}"
" but is only allowed to be 'latest' or None"
)
max_version = 'latest'
# Normalize e.g. empty string to None
@ -326,7 +328,8 @@ def _normalize_version_args(
raise exceptions.ImpliedMinVersionMismatch(
service_type=service_type,
implied=implied_version,
given=version_to_string(min_version))
given=version_to_string(min_version),
)
else:
min_version = implied_version
@ -338,7 +341,8 @@ def _normalize_version_args(
raise exceptions.ImpliedMaxVersionMismatch(
service_type=service_type,
implied=implied_version,
given=version_to_string(max_version))
given=version_to_string(max_version),
)
else:
max_version = (implied_version[0], LATEST)
return min_version, max_version
@ -477,7 +481,8 @@ def _combine_relative_url(discovery_url, version_url):
path,
parsed_version_url.params,
parsed_version_url.query,
parsed_version_url.fragment).geturl()
parsed_version_url.fragment,
).geturl()
def _version_from_url(url):
@ -499,7 +504,7 @@ def _version_from_url(url):
return None
class Status(object):
class Status:
CURRENT = 'CURRENT'
SUPPORTED = 'SUPPORTED'
DEPRECATED = 'DEPRECATED'
@ -528,19 +533,23 @@ class Status(object):
return status
class Discover(object):
class Discover:
CURRENT_STATUSES = ('stable', 'current', 'supported')
DEPRECATED_STATUSES = ('deprecated',)
EXPERIMENTAL_STATUSES = ('experimental',)
def __init__(self, session, url, authenticated=None):
self._url = url
self._data = get_version_data(session, url,
authenticated=authenticated)
self._data = get_version_data(
session, url, authenticated=authenticated
)
def raw_version_data(self, allow_experimental=False,
allow_deprecated=True, allow_unknown=False):
def raw_version_data(
self,
allow_experimental=False,
allow_deprecated=True,
allow_unknown=False,
):
"""Get raw version information from URL.
Raw data indicates that only minimal validation processing is performed
@ -560,8 +569,10 @@ class Discover(object):
try:
status = v['status']
except KeyError:
_LOGGER.warning('Skipping over invalid version data. '
'No stability status in version.')
_LOGGER.warning(
'Skipping over invalid version data. '
'No stability status in version.'
)
continue
status = status.lower()
@ -633,8 +644,10 @@ class Discover(object):
rel = link['rel']
url = _combine_relative_url(self._url, link['href'])
except (KeyError, TypeError):
_LOGGER.info('Skipping invalid version link. '
'Missing link URL or relationship.')
_LOGGER.info(
'Skipping invalid version link. '
'Missing link URL or relationship.'
)
continue
if rel.lower() == 'self':
@ -642,20 +655,25 @@ class Discover(object):
elif rel.lower() == 'collection':
collection_url = url
if not self_url:
_LOGGER.info('Skipping invalid version data. '
'Missing link to endpoint.')
_LOGGER.info(
'Skipping invalid version data. '
'Missing link to endpoint.'
)
continue
versions.append(
VersionData(version=version_number,
url=self_url,
collection=collection_url,
min_microversion=min_microversion,
max_microversion=max_microversion,
next_min_version=next_min_version,
not_before=not_before,
status=Status.normalize(v['status']),
raw_status=v['status']))
VersionData(
version=version_number,
url=self_url,
collection=collection_url,
min_microversion=min_microversion,
max_microversion=max_microversion,
next_min_version=next_min_version,
not_before=not_before,
status=Status.normalize(v['status']),
raw_status=v['status'],
)
)
versions.sort(key=lambda v: v['version'], reverse=reverse)
return versions
@ -723,9 +741,9 @@ class Discover(object):
data = self.data_for(version, **kwargs)
return data['url'] if data else None
def versioned_data_for(self, url=None,
min_version=None, max_version=None,
**kwargs):
def versioned_data_for(
self, url=None, min_version=None, max_version=None, **kwargs
):
"""Return endpoint data for the service at a url.
min_version and max_version can be given either as strings or tuples.
@ -747,15 +765,17 @@ class Discover(object):
:rtype: dict
"""
min_version, max_version = _normalize_version_args(
None, min_version, max_version)
None, min_version, max_version
)
no_version = not max_version and not min_version
version_data = self.version_data(reverse=True, **kwargs)
# If we don't have to check a min_version, we can short
# circuit anything else
if (max_version == (LATEST, LATEST) and
(not min_version or min_version == (LATEST, LATEST))):
if max_version == (LATEST, LATEST) and (
not min_version or min_version == (LATEST, LATEST)
):
# because we reverse we can just take the first entry
return version_data[0]
@ -774,8 +794,11 @@ class Discover(object):
if _latest_soft_match(min_version, data['version']):
return data
# Only validate version bounds if versions were specified
if min_version and max_version and version_between(
min_version, max_version, data['version']):
if (
min_version
and max_version
and version_between(min_version, max_version, data['version'])
):
return data
# If there is no version requested and we could not find a matching
@ -805,8 +828,9 @@ class Discover(object):
:returns: The url for the specified version or None if no match.
:rtype: str
"""
data = self.versioned_data_for(min_version=min_version,
max_version=max_version, **kwargs)
data = self.versioned_data_for(
min_version=min_version, max_version=max_version, **kwargs
)
return data['url'] if data else None
@ -814,17 +838,18 @@ class VersionData(dict):
"""Normalized Version Data about an endpoint."""
def __init__(
self,
version,
url,
collection=None,
max_microversion=None,
min_microversion=None,
next_min_version=None,
not_before=None,
status='CURRENT',
raw_status=None):
super(VersionData, self).__init__()
self,
version,
url,
collection=None,
max_microversion=None,
min_microversion=None,
next_min_version=None,
not_before=None,
status='CURRENT',
raw_status=None,
):
super().__init__()
self['version'] = version
self['url'] = url
self['collection'] = collection
@ -883,7 +908,7 @@ class VersionData(dict):
return self.get('raw_status')
class EndpointData(object):
class EndpointData:
"""Normalized information about a discovered endpoint.
Contains url, version, microversion, interface and region information.
@ -894,23 +919,25 @@ class EndpointData(object):
possibilities.
"""
def __init__(self,
catalog_url=None,
service_url=None,
service_type=None,
service_name=None,
service_id=None,
region_name=None,
interface=None,
endpoint_id=None,
raw_endpoint=None,
api_version=None,
major_version=None,
min_microversion=None,
max_microversion=None,
next_min_version=None,
not_before=None,
status=None):
def __init__(
self,
catalog_url=None,
service_url=None,
service_type=None,
service_name=None,
service_id=None,
region_name=None,
interface=None,
endpoint_id=None,
raw_endpoint=None,
api_version=None,
major_version=None,
min_microversion=None,
max_microversion=None,
next_min_version=None,
not_before=None,
status=None,
):
self.catalog_url = catalog_url
self.service_url = service_url
self.service_type = service_type
@ -962,19 +989,35 @@ class EndpointData(object):
def __str__(self):
"""Produce a string like EndpointData{key=val, ...}, for debugging."""
str_attrs = (
'api_version', 'catalog_url', 'endpoint_id', 'interface',
'major_version', 'max_microversion', 'min_microversion',
'next_min_version', 'not_before', 'raw_endpoint', 'region_name',
'service_id', 'service_name', 'service_type', 'service_url', 'url')
return "%s{%s}" % (self.__class__.__name__, ', '.join(
["%s=%s" % (attr, getattr(self, attr)) for attr in str_attrs]))
'api_version',
'catalog_url',
'endpoint_id',
'interface',
'major_version',
'max_microversion',
'min_microversion',
'next_min_version',
'not_before',
'raw_endpoint',
'region_name',
'service_id',
'service_name',
'service_type',
'service_url',
'url',
)
return "{}{{{}}}".format(
self.__class__.__name__,
', '.join([f"{attr}={getattr(self, attr)}" for attr in str_attrs]),
)
@property
def url(self):
return self.service_url or self.catalog_url
def get_current_versioned_data(self, session, allow=None, cache=None,
project_id=None):
def get_current_versioned_data(
self, session, allow=None, cache=None, project_id=None
):
"""Run version discovery on the current endpoint.
A simplified version of get_versioned_data, get_current_versioned_data
@ -1001,16 +1044,29 @@ class EndpointData(object):
could not be discovered.
"""
min_version, max_version = _normalize_version_args(
self.api_version, None, None)
self.api_version, None, None
)
return self.get_versioned_data(
session=session, allow=allow, cache=cache, allow_version_hack=True,
session=session,
allow=allow,
cache=cache,
allow_version_hack=True,
discover_versions=True,
min_version=min_version, max_version=max_version)
min_version=min_version,
max_version=max_version,
)
def get_versioned_data(self, session, allow=None, cache=None,
allow_version_hack=True, project_id=None,
discover_versions=True,
min_version=None, max_version=None):
def get_versioned_data(
self,
session,
allow=None,
cache=None,
allow_version_hack=True,
project_id=None,
discover_versions=True,
min_version=None,
max_version=None,
):
"""Run version discovery for the service described.
Performs Version Discovery and returns a new EndpointData object with
@ -1050,7 +1106,8 @@ class EndpointData(object):
could not be discovered.
"""
min_version, max_version = _normalize_version_args(
None, min_version, max_version)
None, min_version, max_version
)
if not allow:
allow = {}
@ -1059,10 +1116,15 @@ class EndpointData(object):
new_data = copy.copy(self)
new_data._set_version_info(
session=session, allow=allow, cache=cache,
allow_version_hack=allow_version_hack, project_id=project_id,
discover_versions=discover_versions, min_version=min_version,
max_version=max_version)
session=session,
allow=allow,
cache=cache,
allow_version_hack=allow_version_hack,
project_id=project_id,
discover_versions=discover_versions,
min_version=min_version,
max_version=max_version,
)
return new_data
def get_all_version_string_data(self, session, project_id=None):
@ -1082,7 +1144,8 @@ class EndpointData(object):
# Ignore errors here - we're just searching for one of the
# URLs that will give us data.
_LOGGER.debug(
"Failed attempt at discovery on %s: %s", vers_url, str(e))
"Failed attempt at discovery on %s: %s", vers_url, str(e)
)
continue
for version in d.version_string_data():
versions.append(version)
@ -1109,10 +1172,17 @@ class EndpointData(object):
return [VersionData(url=url, version=version)]
def _set_version_info(self, session, allow=None, cache=None,
allow_version_hack=True, project_id=None,
discover_versions=False,
min_version=None, max_version=None):
def _set_version_info(
self,
session,
allow=None,
cache=None,
allow_version_hack=True,
project_id=None,
discover_versions=False,
min_version=None,
max_version=None,
):
match_url = None
no_version = not max_version and not min_version
@ -1134,38 +1204,48 @@ class EndpointData(object):
# satisfy the request without further work
if self._disc:
discovered_data = self._disc.versioned_data_for(
min_version=min_version, max_version=max_version,
url=match_url, **allow)
min_version=min_version,
max_version=max_version,
url=match_url,
**allow,
)
if not discovered_data:
self._run_discovery(
session=session, cache=cache,
min_version=min_version, max_version=max_version,
project_id=project_id, allow_version_hack=allow_version_hack,
discover_versions=discover_versions)
session=session,
cache=cache,
min_version=min_version,
max_version=max_version,
project_id=project_id,
allow_version_hack=allow_version_hack,
discover_versions=discover_versions,
)
if not self._disc:
return
discovered_data = self._disc.versioned_data_for(
min_version=min_version, max_version=max_version,
url=match_url, **allow)
min_version=min_version,
max_version=max_version,
url=match_url,
**allow,
)
if not discovered_data:
if min_version and not max_version:
raise exceptions.DiscoveryFailure(
"Minimum version {min_version} was not found".format(
min_version=version_to_string(min_version)))
f"Minimum version {version_to_string(min_version)} was not found"
)
elif max_version and not min_version:
raise exceptions.DiscoveryFailure(
"Maximum version {max_version} was not found".format(
max_version=version_to_string(max_version)))
f"Maximum version {version_to_string(max_version)} was not found"
)
elif min_version and max_version:
raise exceptions.DiscoveryFailure(
"No version found between {min_version}"
" and {max_version}".format(
min_version=version_to_string(min_version),
max_version=version_to_string(max_version)))
f"No version found between {version_to_string(min_version)}"
f" and {version_to_string(max_version)}"
)
else:
raise exceptions.DiscoveryFailure(
"No version data found remotely at all")
"No version data found remotely at all"
)
self.min_microversion = discovered_data['min_microversion']
self.max_microversion = discovered_data['max_microversion']
@ -1184,25 +1264,35 @@ class EndpointData(object):
# for example a "v2" path from http://host/admin should resolve as
# http://host/admin/v2 where it would otherwise be host/v2.
# This has no effect on absolute urls returned from url_for.
url = urllib.parse.urljoin(self._disc._url.rstrip('/') + '/',
discovered_url)
url = urllib.parse.urljoin(
self._disc._url.rstrip('/') + '/', discovered_url
)
# If we had to pop a project_id from the catalog_url, put it back on
if self._saved_project_id:
url = urllib.parse.urljoin(url.rstrip('/') + '/',
self._saved_project_id)
url = urllib.parse.urljoin(
url.rstrip('/') + '/', self._saved_project_id
)
self.service_url = url
def _run_discovery(self, session, cache, min_version, max_version,
project_id, allow_version_hack, discover_versions):
def _run_discovery(
self,
session,
cache,
min_version,
max_version,
project_id,
allow_version_hack,
discover_versions,
):
tried = set()
for vers_url in self._get_discovery_url_choices(
project_id=project_id,
allow_version_hack=allow_version_hack,
min_version=min_version,
max_version=max_version):
project_id=project_id,
allow_version_hack=allow_version_hack,
min_version=min_version,
max_version=max_version,
):
if self._catalog_matches_exactly and not discover_versions:
# The version we started with is correct, and we don't want
# new data
@ -1214,13 +1304,14 @@ class EndpointData(object):
try:
self._disc = get_discovery(
session, vers_url,
cache=cache,
authenticated=False)
session, vers_url, cache=cache, authenticated=False
)
break
except (exceptions.DiscoveryFailure,
exceptions.HttpError,
exceptions.ConnectionError) as exc:
except (
exceptions.DiscoveryFailure,
exceptions.HttpError,
exceptions.ConnectionError,
) as exc:
_LOGGER.debug('No version document at %s: %s', vers_url, exc)
continue
if not self._disc:
@ -1242,7 +1333,9 @@ class EndpointData(object):
_LOGGER.warning(
'Failed to contact the endpoint at %s for '
'discovery. Fallback to using that endpoint as '
'the base url.', self.url)
'the base url.',
self.url,
)
return
else:
@ -1252,14 +1345,20 @@ class EndpointData(object):
# date enough to properly specify a version and keystoneauth
# can't deliver.
raise exceptions.DiscoveryFailure(
"Unable to find a version discovery document at %s, "
"Unable to find a version discovery document at {}, "
"the service is unavailable or misconfigured. "
"Required version range (%s - %s), version hack disabled."
% (self.url, min_version or "any", max_version or "any"))
"Required version range ({} - {}), version hack disabled.".format(
self.url, min_version or "any", max_version or "any"
)
)
def _get_discovery_url_choices(
self, project_id=None, allow_version_hack=True,
min_version=None, max_version=None):
self,
project_id=None,
allow_version_hack=True,
min_version=None,
max_version=None,
):
"""Find potential locations for version discovery URLs.
min_version and max_version are already normalized, so will either be
@ -1295,19 +1394,27 @@ class EndpointData(object):
'/'.join(url_parts),
url.params,
url.query,
url.fragment).geturl()
url.fragment,
).geturl()
except TypeError:
pass
else:
# `is_between` means version bounds were specified *and* the URL
# version is between them.
is_between = min_version and max_version and version_between(
min_version, max_version, url_version)
exact_match = (is_between and max_version and
max_version[0] == url_version[0])
high_match = (is_between and max_version and
max_version[1] != LATEST and
version_match(max_version, url_version))
is_between = (
min_version
and max_version
and version_between(min_version, max_version, url_version)
)
exact_match = (
is_between and max_version and max_version[0] == url_version[0]
)
high_match = (
is_between
and max_version
and max_version[1] != LATEST
and version_match(max_version, url_version)
)
if exact_match or is_between:
self._catalog_matches_version = True
self._catalog_matches_exactly = exact_match
@ -1316,13 +1423,19 @@ class EndpointData(object):
# return it just yet. It's a good option, but unless we
# have an exact match or match the max requested, we want
# to try for an unversioned endpoint first.
catalog_discovery = urllib.parse.ParseResult(
url.scheme,
url.netloc,
'/'.join(url_parts),
url.params,
url.query,
url.fragment).geturl().rstrip('/') + '/'
catalog_discovery = (
urllib.parse.ParseResult(
url.scheme,
url.netloc,
'/'.join(url_parts),
url.params,
url.query,
url.fragment,
)
.geturl()
.rstrip('/')
+ '/'
)
# If we found a viable catalog endpoint and it's
# an exact match or matches the max, go ahead and give
@ -1342,7 +1455,8 @@ class EndpointData(object):
'/'.join(url_parts),
url.params,
url.query,
url.fragment).geturl()
url.fragment,
).geturl()
# Since this is potentially us constructing a base URL from the
# versioned URL - we need to make sure it has a trailing /. But
# we only want to do that if we have built a new URL - not if
@ -1448,7 +1562,8 @@ def get_discovery(session, url, cache=None, authenticated=False):
'',
parsed_url.params,
parsed_url.query,
parsed_url.fragment).geturl()
parsed_url.fragment,
).geturl()
for cache in caches:
disc = cache.get(url)
@ -1468,7 +1583,7 @@ def get_discovery(session, url, cache=None, authenticated=False):
return disc
class _VersionHacks(object):
class _VersionHacks:
"""A container to abstract the list of version hacks.
This could be done as simply a dictionary but is abstracted like this to

View File

@ -28,5 +28,5 @@ class MissingAuthMethods(base.ClientException):
self.methods = body['receipt']['methods']
self.required_auth_methods = body['required_auth_methods']
self.expires_at = utils.parse_isotime(body['receipt']['expires_at'])
message = "%s: %s" % (self.message, self.required_auth_methods)
super(MissingAuthMethods, self).__init__(message)
message = f"{self.message}: {self.required_auth_methods}"
super().__init__(message)

View File

@ -13,12 +13,14 @@
from keystoneauth1.exceptions import base
__all__ = ('AuthPluginException',
'MissingAuthPlugin',
'NoMatchingPlugin',
'UnsupportedParameters',
'OptionError',
'MissingRequiredOptions')
__all__ = (
'AuthPluginException',
'MissingAuthPlugin',
'NoMatchingPlugin',
'UnsupportedParameters',
'OptionError',
'MissingRequiredOptions',
)
class AuthPluginException(base.ClientException):
@ -41,8 +43,8 @@ class NoMatchingPlugin(AuthPluginException):
def __init__(self, name):
self.name = name
msg = 'The plugin %s could not be found' % name
super(NoMatchingPlugin, self).__init__(msg)
msg = f'The plugin {name} could not be found'
super().__init__(msg)
class UnsupportedParameters(AuthPluginException):
@ -59,7 +61,7 @@ class UnsupportedParameters(AuthPluginException):
self.names = names
m = 'The following parameters were given that are unsupported: %s'
super(UnsupportedParameters, self).__init__(m % ', '.join(self.names))
super().__init__(m % ', '.join(self.names))
class OptionError(AuthPluginException):
@ -90,4 +92,4 @@ class MissingRequiredOptions(OptionError):
names = ", ".join(o.dest for o in options)
m = 'Auth plugin requires parameters which were not given: %s'
super(MissingRequiredOptions, self).__init__(m % names)
super().__init__(m % names)

View File

@ -21,4 +21,4 @@ class ClientException(Exception):
def __init__(self, message=None):
self.message = message or self.message
super(ClientException, self).__init__(self.message)
super().__init__(self.message)

View File

@ -13,9 +13,7 @@
from keystoneauth1.exceptions import base
__all__ = ('CatalogException',
'EmptyCatalog',
'EndpointNotFound')
__all__ = ('CatalogException', 'EmptyCatalog', 'EndpointNotFound')
class CatalogException(base.ClientException):

View File

@ -13,12 +13,14 @@
from keystoneauth1.exceptions import base
__all__ = ('ConnectionError',
'ConnectTimeout',
'ConnectFailure',
'SSLError',
'RetriableConnectionFailure',
'UnknownConnectionError')
__all__ = (
'ConnectionError',
'ConnectTimeout',
'ConnectFailure',
'SSLError',
'RetriableConnectionFailure',
'UnknownConnectionError',
)
class RetriableConnectionFailure(Exception):
@ -47,5 +49,5 @@ class UnknownConnectionError(ConnectionError):
"""An error was encountered but we don't know what it is."""
def __init__(self, msg, original):
super(UnknownConnectionError, self).__init__(msg)
super().__init__(msg)
self.original = original

View File

@ -17,11 +17,13 @@ from keystoneauth1.exceptions import base
_SERVICE_TYPES = os_service_types.ServiceTypes()
__all__ = ('DiscoveryFailure',
'ImpliedVersionMismatch',
'ImpliedMinVersionMismatch',
'ImpliedMaxVersionMismatch',
'VersionNotAvailable')
__all__ = (
'DiscoveryFailure',
'ImpliedVersionMismatch',
'ImpliedMinVersionMismatch',
'ImpliedMaxVersionMismatch',
'VersionNotAvailable',
)
class DiscoveryFailure(base.ClientException):
@ -36,17 +38,12 @@ class ImpliedVersionMismatch(ValueError):
label = 'version'
def __init__(self, service_type, implied, given):
super(ImpliedVersionMismatch, self).__init__(
"service_type {service_type} was given which implies"
" major API version {implied} but {label} of"
" {given} was also given. Please update your code"
" to use the official service_type {official_type}.".format(
service_type=service_type,
implied=str(implied[0]),
given=given,
label=self.label,
official_type=_SERVICE_TYPES.get_service_type(service_type),
))
super().__init__(
f"service_type {service_type} was given which implies"
f" major API version {str(implied[0])} but {self.label} of"
f" {given} was also given. Please update your code"
f" to use the official service_type {_SERVICE_TYPES.get_service_type(service_type)}."
)
class ImpliedMinVersionMismatch(ImpliedVersionMismatch):

View File

@ -25,38 +25,37 @@ from keystoneauth1.exceptions import auth
from keystoneauth1.exceptions import base
__all__ = ('HttpError',
'HTTPClientError',
'BadRequest',
'Unauthorized',
'PaymentRequired',
'Forbidden',
'NotFound',
'MethodNotAllowed',
'NotAcceptable',
'ProxyAuthenticationRequired',
'RequestTimeout',
'Conflict',
'Gone',
'LengthRequired',
'PreconditionFailed',
'RequestEntityTooLarge',
'RequestUriTooLong',
'UnsupportedMediaType',
'RequestedRangeNotSatisfiable',
'ExpectationFailed',
'UnprocessableEntity',
'HttpServerError',
'InternalServerError',
'HttpNotImplemented',
'BadGateway',
'ServiceUnavailable',
'GatewayTimeout',
'HttpVersionNotSupported',
'from_response')
__all__ = (
'HttpError',
'HTTPClientError',
'BadRequest',
'Unauthorized',
'PaymentRequired',
'Forbidden',
'NotFound',
'MethodNotAllowed',
'NotAcceptable',
'ProxyAuthenticationRequired',
'RequestTimeout',
'Conflict',
'Gone',
'LengthRequired',
'PreconditionFailed',
'RequestEntityTooLarge',
'RequestUriTooLong',
'UnsupportedMediaType',
'RequestedRangeNotSatisfiable',
'ExpectationFailed',
'UnprocessableEntity',
'HttpServerError',
'InternalServerError',
'HttpNotImplemented',
'BadGateway',
'ServiceUnavailable',
'GatewayTimeout',
'HttpVersionNotSupported',
'from_response',
)
class HttpError(base.ClientException):
@ -65,10 +64,17 @@ class HttpError(base.ClientException):
http_status = 0
message = "HTTP Error"
def __init__(self, message=None, details=None,
response=None, request_id=None,
url=None, method=None, http_status=None,
retry_after=0):
def __init__(
self,
message=None,
details=None,
response=None,
request_id=None,
url=None,
method=None,
http_status=None,
retry_after=0,
):
self.http_status = http_status or self.http_status
self.message = message or self.message
self.details = details
@ -76,11 +82,11 @@ class HttpError(base.ClientException):
self.response = response
self.url = url
self.method = method
formatted_string = "%s (HTTP %s)" % (self.message, self.http_status)
formatted_string = f"{self.message} (HTTP {self.http_status})"
self.retry_after = retry_after
if request_id:
formatted_string += " (Request-ID: %s)" % request_id
super(HttpError, self).__init__(formatted_string)
formatted_string += f" (Request-ID: {request_id})"
super().__init__(formatted_string)
class HTTPClientError(HttpError):
@ -256,7 +262,7 @@ class RequestEntityTooLarge(HTTPClientError):
except (KeyError, ValueError):
self.retry_after = 0
super(RequestEntityTooLarge, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
class RequestUriTooLong(HTTPClientError):
@ -377,11 +383,11 @@ class HttpVersionNotSupported(HttpServerError):
# _code_map contains all the classes that have http_status attribute.
_code_map = dict(
(getattr(obj, 'http_status', None), obj)
_code_map = {
getattr(obj, 'http_status', None): obj
for name, obj in vars(sys.modules[__name__]).items()
if inspect.isclass(obj) and getattr(obj, 'http_status', False)
)
}
def from_response(response, method, url):
@ -414,8 +420,9 @@ def from_response(response, method, url):
error = body["error"]
kwargs["message"] = error.get("message")
kwargs["details"] = error.get("details")
elif (isinstance(body, dict) and
isinstance(body.get("errors"), list)):
elif isinstance(body, dict) and isinstance(
body.get("errors"), list
):
# if the error response follows the API SIG guidelines, it
# will return a list of errors. in this case, only the first
# error is shown, but if there are multiple the user will be
@ -429,13 +436,15 @@ def from_response(response, method, url):
if len(errors) > 1:
# if there is more than one error, let the user know
# that multiple were seen.
msg_hdr = ("Multiple error responses, "
"showing first only: ")
msg_hdr = (
"Multiple error responses, showing first only: "
)
else:
msg_hdr = ""
kwargs["message"] = "{}{}".format(msg_hdr,
errors[0].get("title"))
kwargs["message"] = "{}{}".format(
msg_hdr, errors[0].get("title")
)
kwargs["details"] = errors[0].get("detail")
else:
kwargs["message"] = "Unrecognized schema in response body."
@ -444,8 +453,10 @@ def from_response(response, method, url):
kwargs["details"] = response.text
# we check explicity for 401 in case of auth receipts
if (response.status_code == 401
and "Openstack-Auth-Receipt" in response.headers):
if (
response.status_code == 401
and "Openstack-Auth-Receipt" in response.headers
):
return auth.MissingAuthMethods(response)
try:

View File

@ -14,18 +14,21 @@
from keystoneauth1.exceptions import auth_plugins
__all__ = (
'InvalidDiscoveryEndpoint', 'InvalidOidcDiscoveryDocument',
'OidcAccessTokenEndpointNotFound', 'OidcAuthorizationEndpointNotFound',
'OidcGrantTypeMissmatch', 'OidcPluginNotSupported',
'InvalidDiscoveryEndpoint',
'InvalidOidcDiscoveryDocument',
'OidcAccessTokenEndpointNotFound',
'OidcAuthorizationEndpointNotFound',
'OidcGrantTypeMissmatch',
'OidcPluginNotSupported',
)
class InvalidDiscoveryEndpoint(auth_plugins.AuthPluginException):
message = "OpenID Connect Discovery Document endpoint not set."""
message = "OpenID Connect Discovery Document endpoint not set."
class InvalidOidcDiscoveryDocument(auth_plugins.AuthPluginException):
message = "OpenID Connect Discovery Document is not valid JSON."""
message = "OpenID Connect Discovery Document is not valid JSON."
class OidcAccessTokenEndpointNotFound(auth_plugins.AuthPluginException):
@ -37,7 +40,8 @@ class OidcAuthorizationEndpointNotFound(auth_plugins.AuthPluginException):
class OidcDeviceAuthorizationEndpointNotFound(
auth_plugins.AuthPluginException):
auth_plugins.AuthPluginException
):
message = "OpenID Connect device authorization endpoint not provided."

View File

@ -21,5 +21,5 @@ class InvalidResponse(base.ClientException):
message = "Invalid response from server."
def __init__(self, response):
super(InvalidResponse, self).__init__()
super().__init__()
self.response = response

View File

@ -20,5 +20,5 @@ class ServiceProviderNotFound(base.ClientException):
def __init__(self, sp_id):
self.sp_id = sp_id
msg = 'The Service Provider %(sp)s could not be found' % {'sp': sp_id}
super(ServiceProviderNotFound, self).__init__(msg)
msg = f'The Service Provider {sp_id} could not be found'
super().__init__(msg)

View File

@ -15,7 +15,6 @@ from keystoneauth1 import loading
class Saml2Password(loading.BaseFederationLoader):
@property
def plugin_class(self):
return _saml2.V3Saml2Password
@ -25,25 +24,29 @@ class Saml2Password(loading.BaseFederationLoader):
return _saml2._V3_SAML2_AVAILABLE
def get_options(self):
options = super(Saml2Password, self).get_options()
options = super().get_options()
options.extend([
loading.Opt('identity-provider-url',
required=True,
help=('An Identity Provider URL, where the SAML2 '
'authentication request will be sent.')),
loading.Opt('username', help='Username', required=True),
loading.Opt('password',
secret=True,
help='Password',
required=True)
])
options.extend(
[
loading.Opt(
'identity-provider-url',
required=True,
help=(
'An Identity Provider URL, where the SAML2 '
'authentication request will be sent.'
),
),
loading.Opt('username', help='Username', required=True),
loading.Opt(
'password', secret=True, help='Password', required=True
),
]
)
return options
class ADFSPassword(loading.BaseFederationLoader):
@property
def plugin_class(self):
return _saml2.V3ADFSPassword
@ -53,24 +56,33 @@ class ADFSPassword(loading.BaseFederationLoader):
return _saml2._V3_ADFS_AVAILABLE
def get_options(self):
options = super(ADFSPassword, self).get_options()
options = super().get_options()
options.extend([
loading.Opt('identity-provider-url',
required=True,
help=('An Identity Provider URL, where the SAML '
'authentication request will be sent.')),
loading.Opt('service-provider-endpoint',
required=True,
help="Service Provider's Endpoint"),
loading.Opt('service-provider-entity-id',
required=True,
help="Service Provider's SAML Entity ID"),
loading.Opt('username', help='Username', required=True),
loading.Opt('password',
secret=True,
required=True,
help='Password')
])
options.extend(
[
loading.Opt(
'identity-provider-url',
required=True,
help=(
'An Identity Provider URL, where the SAML '
'authentication request will be sent.'
),
),
loading.Opt(
'service-provider-endpoint',
required=True,
help="Service Provider's Endpoint",
),
loading.Opt(
'service-provider-entity-id',
required=True,
help="Service Provider's SAML Entity ID",
),
loading.Opt('username', help='Username', required=True),
loading.Opt(
'password', secret=True, required=True, help='Password'
),
]
)
return options

View File

@ -35,21 +35,34 @@ class Password(base.BaseSAMLPlugin):
NAMESPACES = {
's': 'http://www.w3.org/2003/05/soap-envelope',
'a': 'http://www.w3.org/2005/08/addressing',
'u': ('http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-utility-1.0.xsd')
'u': (
'http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-utility-1.0.xsd'
),
}
ADFS_TOKEN_NAMESPACES = {
's': 'http://www.w3.org/2003/05/soap-envelope',
't': 'http://docs.oasis-open.org/ws-sx/ws-trust/200512'
't': 'http://docs.oasis-open.org/ws-sx/ws-trust/200512',
}
ADFS_ASSERTION_XPATH = ('/s:Envelope/s:Body'
'/t:RequestSecurityTokenResponseCollection'
'/t:RequestSecurityTokenResponse')
ADFS_ASSERTION_XPATH = (
'/s:Envelope/s:Body'
'/t:RequestSecurityTokenResponseCollection'
'/t:RequestSecurityTokenResponse'
)
def __init__(self, auth_url, identity_provider, identity_provider_url,
service_provider_endpoint, username, password,
protocol, service_provider_entity_id=None, **kwargs):
def __init__(
self,
auth_url,
identity_provider,
identity_provider_url,
service_provider_endpoint,
username,
password,
protocol,
service_provider_entity_id=None,
**kwargs,
):
"""Constructor for ``ADFSPassword``.
:param auth_url: URL of the Identity Service
@ -78,10 +91,15 @@ class Password(base.BaseSAMLPlugin):
:type password: string
"""
super(Password, self).__init__(
auth_url=auth_url, identity_provider=identity_provider,
super().__init__(
auth_url=auth_url,
identity_provider=identity_provider,
identity_provider_url=identity_provider_url,
username=username, password=password, protocol=protocol, **kwargs)
username=username,
password=password,
protocol=protocol,
**kwargs,
)
self.service_provider_endpoint = service_provider_endpoint
self.service_provider_entity_id = service_provider_entity_id
@ -123,9 +141,11 @@ class Password(base.BaseSAMLPlugin):
"""
date_created = datetime.datetime.now(datetime.timezone.utc).replace(
tzinfo=None)
tzinfo=None
)
date_expires = date_created + datetime.timedelta(
seconds=self.DEFAULT_ADFS_TOKEN_EXPIRATION)
seconds=self.DEFAULT_ADFS_TOKEN_EXPIRATION
)
return [_time.strftime(fmt) for _time in (date_created, date_expires)]
def _prepare_adfs_request(self):
@ -135,132 +155,188 @@ class Password(base.BaseSAMLPlugin):
"""
WSS_SECURITY_NAMESPACE = {
'o': ('http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-secext-1.0.xsd')
'o': (
'http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-secext-1.0.xsd'
)
}
TRUST_NAMESPACE = {
'trust': 'http://docs.oasis-open.org/ws-sx/ws-trust/200512'
}
WSP_NAMESPACE = {
'wsp': 'http://schemas.xmlsoap.org/ws/2004/09/policy'
}
WSP_NAMESPACE = {'wsp': 'http://schemas.xmlsoap.org/ws/2004/09/policy'}
WSA_NAMESPACE = {
'wsa': 'http://www.w3.org/2005/08/addressing'
}
WSA_NAMESPACE = {'wsa': 'http://www.w3.org/2005/08/addressing'}
root = etree.Element(
'{http://www.w3.org/2003/05/soap-envelope}Envelope',
nsmap=self.NAMESPACES)
nsmap=self.NAMESPACES,
)
header = etree.SubElement(
root, '{http://www.w3.org/2003/05/soap-envelope}Header')
root, '{http://www.w3.org/2003/05/soap-envelope}Header'
)
action = etree.SubElement(
header, "{http://www.w3.org/2005/08/addressing}Action")
header, "{http://www.w3.org/2005/08/addressing}Action"
)
action.set(
"{http://www.w3.org/2003/05/soap-envelope}mustUnderstand", "1")
action.text = ('http://docs.oasis-open.org/ws-sx/ws-trust/200512'
'/RST/Issue')
"{http://www.w3.org/2003/05/soap-envelope}mustUnderstand", "1"
)
action.text = (
'http://docs.oasis-open.org/ws-sx/ws-trust/200512/RST/Issue'
)
messageID = etree.SubElement(
header, '{http://www.w3.org/2005/08/addressing}MessageID')
header, '{http://www.w3.org/2005/08/addressing}MessageID'
)
messageID.text = 'urn:uuid:' + uuid.uuid4().hex
replyID = etree.SubElement(
header, '{http://www.w3.org/2005/08/addressing}ReplyTo')
header, '{http://www.w3.org/2005/08/addressing}ReplyTo'
)
address = etree.SubElement(
replyID, '{http://www.w3.org/2005/08/addressing}Address')
replyID, '{http://www.w3.org/2005/08/addressing}Address'
)
address.text = 'http://www.w3.org/2005/08/addressing/anonymous'
to = etree.SubElement(
header, '{http://www.w3.org/2005/08/addressing}To')
header, '{http://www.w3.org/2005/08/addressing}To'
)
to.set("{http://www.w3.org/2003/05/soap-envelope}mustUnderstand", "1")
security = etree.SubElement(
header, '{http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
header,
'{http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-secext-1.0.xsd}Security',
nsmap=WSS_SECURITY_NAMESPACE)
nsmap=WSS_SECURITY_NAMESPACE,
)
security.set(
"{http://www.w3.org/2003/05/soap-envelope}mustUnderstand", "1")
"{http://www.w3.org/2003/05/soap-envelope}mustUnderstand", "1"
)
timestamp = etree.SubElement(
security, ('{http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-utility-1.0.xsd}Timestamp'))
security,
(
'{http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-utility-1.0.xsd}Timestamp'
),
)
timestamp.set(
('{http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-utility-1.0.xsd}Id'), '_0')
(
'{http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-utility-1.0.xsd}Id'
),
'_0',
)
created = etree.SubElement(
timestamp, ('{http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-utility-1.0.xsd}Created'))
timestamp,
(
'{http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-utility-1.0.xsd}Created'
),
)
expires = etree.SubElement(
timestamp, ('{http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-utility-1.0.xsd}Expires'))
timestamp,
(
'{http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-utility-1.0.xsd}Expires'
),
)
created.text, expires.text = self._token_dates()
usernametoken = etree.SubElement(
security, '{http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-secext-1.0.xsd}UsernameToken')
security,
'{http://docs.oasis-open.org/wss/2004/01/oasis-200401-'
'wss-wssecurity-secext-1.0.xsd}UsernameToken',
)
usernametoken.set(
('{http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-'
'wssecurity-utility-1.0.xsd}u'), "uuid-%s-1" % uuid.uuid4().hex)
(
'{http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-'
'wssecurity-utility-1.0.xsd}u'
),
f"uuid-{uuid.uuid4().hex}-1",
)
username = etree.SubElement(
usernametoken, ('{http://docs.oasis-open.org/wss/2004/01/oasis-'
'200401-wss-wssecurity-secext-1.0.xsd}Username'))
usernametoken,
(
'{http://docs.oasis-open.org/wss/2004/01/oasis-'
'200401-wss-wssecurity-secext-1.0.xsd}Username'
),
)
password = etree.SubElement(
usernametoken, ('{http://docs.oasis-open.org/wss/2004/01/oasis-'
'200401-wss-wssecurity-secext-1.0.xsd}Password'),
Type=('http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-'
'username-token-profile-1.0#PasswordText'))
usernametoken,
(
'{http://docs.oasis-open.org/wss/2004/01/oasis-'
'200401-wss-wssecurity-secext-1.0.xsd}Password'
),
Type=(
'http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-'
'username-token-profile-1.0#PasswordText'
),
)
body = etree.SubElement(
root, "{http://www.w3.org/2003/05/soap-envelope}Body")
root, "{http://www.w3.org/2003/05/soap-envelope}Body"
)
request_security_token = etree.SubElement(
body, ('{http://docs.oasis-open.org/ws-sx/ws-trust/200512}'
'RequestSecurityToken'), nsmap=TRUST_NAMESPACE)
body,
(
'{http://docs.oasis-open.org/ws-sx/ws-trust/200512}'
'RequestSecurityToken'
),
nsmap=TRUST_NAMESPACE,
)
applies_to = etree.SubElement(
request_security_token,
'{http://schemas.xmlsoap.org/ws/2004/09/policy}AppliesTo',
nsmap=WSP_NAMESPACE)
nsmap=WSP_NAMESPACE,
)
endpoint_reference = etree.SubElement(
applies_to,
'{http://www.w3.org/2005/08/addressing}EndpointReference',
nsmap=WSA_NAMESPACE)
nsmap=WSA_NAMESPACE,
)
wsa_address = etree.SubElement(
endpoint_reference,
'{http://www.w3.org/2005/08/addressing}Address')
endpoint_reference, '{http://www.w3.org/2005/08/addressing}Address'
)
keytype = etree.SubElement(
request_security_token,
'{http://docs.oasis-open.org/ws-sx/ws-trust/200512}KeyType')
keytype.text = ('http://docs.oasis-open.org/ws-sx/'
'ws-trust/200512/Bearer')
'{http://docs.oasis-open.org/ws-sx/ws-trust/200512}KeyType',
)
keytype.text = (
'http://docs.oasis-open.org/ws-sx/ws-trust/200512/Bearer'
)
request_type = etree.SubElement(
request_security_token,
'{http://docs.oasis-open.org/ws-sx/ws-trust/200512}RequestType')
request_type.text = ('http://docs.oasis-open.org/ws-sx/'
'ws-trust/200512/Issue')
'{http://docs.oasis-open.org/ws-sx/ws-trust/200512}RequestType',
)
request_type.text = (
'http://docs.oasis-open.org/ws-sx/ws-trust/200512/Issue'
)
token_type = etree.SubElement(
request_security_token,
'{http://docs.oasis-open.org/ws-sx/ws-trust/200512}TokenType')
'{http://docs.oasis-open.org/ws-sx/ws-trust/200512}TokenType',
)
token_type.text = 'urn:oasis:names:tc:SAML:1.0:assertion'
# After constructing the request, let's plug in some values
username.text = self.username
password.text = self.password
to.text = self.identity_provider_url
wsa_address.text = (self.service_provider_entity_id or
self.service_provider_endpoint)
wsa_address.text = (
self.service_provider_entity_id or self.service_provider_endpoint
)
self.prepared_request = root
@ -289,12 +365,14 @@ class Password(base.BaseSAMLPlugin):
recognized.
"""
def _get_failure(e):
xpath = '/s:Envelope/s:Body/s:Fault/s:Code/s:Subcode/s:Value'
content = e.response.content
try:
obj = self.str_to_xml(content).xpath(
xpath, namespaces=self.NAMESPACES)
xpath, namespaces=self.NAMESPACES
)
obj = self._first(obj)
return obj.text
# NOTE(marek-denis): etree.Element.xpath() doesn't raise an
@ -309,13 +387,18 @@ class Password(base.BaseSAMLPlugin):
request_security_token = self.xml_to_str(self.prepared_request)
try:
response = session.post(
url=self.identity_provider_url, headers=self.HEADER_SOAP,
data=request_security_token, authenticated=False)
url=self.identity_provider_url,
headers=self.HEADER_SOAP,
data=request_security_token,
authenticated=False,
)
except exceptions.InternalServerError as e:
reason = _get_failure(e)
raise exceptions.AuthorizationFailure(reason)
msg = ('Error parsing XML returned from '
'the ADFS Identity Provider, reason: %s')
msg = (
'Error parsing XML returned from '
'the ADFS Identity Provider, reason: %s'
)
self.adfs_token = self.str_to_xml(response.content, msg)
def _prepare_sp_request(self):
@ -329,7 +412,8 @@ class Password(base.BaseSAMLPlugin):
"""
assertion = self.adfs_token.xpath(
self.ADFS_ASSERTION_XPATH, namespaces=self.ADFS_TOKEN_NAMESPACES)
self.ADFS_ASSERTION_XPATH, namespaces=self.ADFS_TOKEN_NAMESPACES
)
assertion = self._first(assertion)
assertion = self.xml_to_str(assertion)
# TODO(marek-denis): Ideally no string replacement should occur.
@ -338,7 +422,8 @@ class Password(base.BaseSAMLPlugin):
# from scratch and reuse values from the adfs security token.
assertion = assertion.replace(
b'http://docs.oasis-open.org/ws-sx/ws-trust/200512',
b'http://schemas.xmlsoap.org/ws/2005/02/trust')
b'http://schemas.xmlsoap.org/ws/2005/02/trust',
)
encoded_assertion = urllib.parse.quote(assertion)
self.encoded_assertion = 'wa=wsignin1.0&wresult=' + encoded_assertion
@ -358,8 +443,12 @@ class Password(base.BaseSAMLPlugin):
"""
session.post(
url=self.service_provider_endpoint, data=self.encoded_assertion,
headers=self.HEADER_X_FORM, redirect=False, authenticated=False)
url=self.service_provider_endpoint,
data=self.encoded_assertion,
headers=self.HEADER_X_FORM,
redirect=False,
authenticated=False,
)
def _access_service_provider(self, session):
"""Access protected endpoint and fetch unscoped token.
@ -382,9 +471,11 @@ class Password(base.BaseSAMLPlugin):
if self._cookies(session) is False:
raise exceptions.AuthorizationFailure(
"Session object doesn't contain a cookie, therefore you are "
"not allowed to enter the Identity Provider's protected area.")
self.authenticated_response = session.get(self.federated_token_url,
authenticated=False)
"not allowed to enter the Identity Provider's protected area."
)
self.authenticated_response = session.get(
self.federated_token_url, authenticated=False
)
def get_unscoped_auth_ref(self, session, *kwargs):
"""Retrieve unscoped token after authentcation with ADFS server.

View File

@ -23,21 +23,27 @@ class _Saml2TokenAuthMethod(v3.AuthMethod):
_method_parameters = []
def get_auth_data(self, session, auth, headers, **kwargs):
raise exceptions.MethodNotImplemented('This method should never '
'be called')
raise exceptions.MethodNotImplemented(
'This method should never be called'
)
class BaseSAMLPlugin(v3.FederationBaseAuth):
HTTP_MOVED_TEMPORARILY = 302
HTTP_SEE_OTHER = 303
_auth_method_class = _Saml2TokenAuthMethod
def __init__(self, auth_url,
identity_provider, identity_provider_url,
username, password, protocol,
**kwargs):
def __init__(
self,
auth_url,
identity_provider,
identity_provider_url,
username,
password,
protocol,
**kwargs,
):
"""Class constructor accepting following parameters.
:param auth_url: URL of the Identity Service
@ -68,10 +74,12 @@ class BaseSAMLPlugin(v3.FederationBaseAuth):
:type protocol: string
"""
super(BaseSAMLPlugin, self).__init__(
auth_url=auth_url, identity_provider=identity_provider,
super().__init__(
auth_url=auth_url,
identity_provider=identity_provider,
protocol=protocol,
**kwargs)
**kwargs,
)
self.identity_provider_url = identity_provider_url
self.username = username
self.password = password

View File

@ -28,7 +28,7 @@ _PAOS_NAMESPACE = 'urn:liberty:paos:2003-08'
_ECP_NAMESPACE = 'urn:oasis:names:tc:SAML:2.0:profiles:SSO:ecp'
_PAOS_HEADER = 'application/vnd.paos+xml'
_PAOS_VER = 'ver="%s";"%s"' % (_PAOS_NAMESPACE, _ECP_NAMESPACE)
_PAOS_VER = f'ver="{_PAOS_NAMESPACE}";"{_ECP_NAMESPACE}"'
_XML_NAMESPACES = {
'ecp': _ECP_NAMESPACE,
@ -72,14 +72,14 @@ def _response_xml(response, name):
try:
return etree.XML(response.content)
except etree.XMLSyntaxError as e:
msg = 'SAML2: Error parsing XML returned from %s: %s' % (name, e)
msg = f'SAML2: Error parsing XML returned from {name}: {e}'
raise InvalidResponse(msg)
def _str_from_xml(xml, path):
li = xml.xpath(path, namespaces=_XML_NAMESPACES)
if len(li) != 1:
raise IndexError('%s should provide a single element list' % path)
raise IndexError(f'{path} should provide a single element list')
return li[0]
@ -115,7 +115,7 @@ class _SamlAuth(requests.auth.AuthBase):
"""
def __init__(self, identity_provider_url, requests_auth):
super(_SamlAuth, self).__init__()
super().__init__()
self.identity_provider_url = identity_provider_url
self.requests_auth = requests_auth
@ -132,8 +132,10 @@ class _SamlAuth(requests.auth.AuthBase):
return request
def _handle_response(self, response, **kwargs):
if (response.status_code == 200 and
response.headers.get('Content-Type') == _PAOS_HEADER):
if (
response.status_code == 200
and response.headers.get('Content-Type') == _PAOS_HEADER
):
response = self._ecp_retry(response, **kwargs)
return response
@ -151,33 +153,40 @@ class _SamlAuth(requests.auth.AuthBase):
authn_request.remove(authn_request[0])
idp_response = send('POST',
self.identity_provider_url,
headers={'Content-type': 'text/xml'},
data=etree.tostring(authn_request),
auth=self.requests_auth)
idp_response = send(
'POST',
self.identity_provider_url,
headers={'Content-type': 'text/xml'},
data=etree.tostring(authn_request),
auth=self.requests_auth,
)
history.append(idp_response)
authn_response = _response_xml(idp_response, 'Identity Provider')
idp_consumer_url = _str_from_xml(authn_response,
_XPATH_IDP_CONSUMER_URL)
idp_consumer_url = _str_from_xml(
authn_response, _XPATH_IDP_CONSUMER_URL
)
if sp_consumer_url != idp_consumer_url:
# send fault message to the SP, discard the response
send('POST',
sp_consumer_url,
data=_SOAP_FAULT,
headers={'Content-Type': _PAOS_HEADER})
send(
'POST',
sp_consumer_url,
data=_SOAP_FAULT,
headers={'Content-Type': _PAOS_HEADER},
)
# prepare error message and raise an exception.
msg = ('Consumer URLs from Service Provider %(service_provider)s '
'%(sp_consumer_url)s and Identity Provider '
'%(identity_provider)s %(idp_consumer_url)s are not equal')
msg = (
'Consumer URLs from Service Provider %(service_provider)s '
'%(sp_consumer_url)s and Identity Provider '
'%(identity_provider)s %(idp_consumer_url)s are not equal'
)
msg = msg % {
'service_provider': sp_response.request.url,
'sp_consumer_url': sp_consumer_url,
'identity_provider': self.identity_provider_url,
'idp_consumer_url': idp_consumer_url
'idp_consumer_url': idp_consumer_url,
}
raise ConsumerMismatch(msg)
@ -186,19 +195,22 @@ class _SamlAuth(requests.auth.AuthBase):
# idp_consumer_url is the URL on the SP that handles the ECP body
# returned and creates an authenticated session.
final_resp = send('POST',
idp_consumer_url,
headers={'Content-Type': _PAOS_HEADER},
cookies=idp_response.cookies,
data=etree.tostring(authn_response))
final_resp = send(
'POST',
idp_consumer_url,
headers={'Content-Type': _PAOS_HEADER},
cookies=idp_response.cookies,
data=etree.tostring(authn_response),
)
history.append(final_resp)
# the SP should then redirect us back to the original URL to retry the
# original request.
if final_resp.status_code in (requests.codes.found,
requests.codes.other):
if final_resp.status_code in (
requests.codes.found,
requests.codes.other,
):
# Consume content and release the original connection
# to allow our new request to reuse the same one.
sp_response.content
@ -216,13 +228,15 @@ class _SamlAuth(requests.auth.AuthBase):
class _FederatedSaml(v3.FederationBaseAuth):
def __init__(self, auth_url, identity_provider, protocol,
identity_provider_url, **kwargs):
super(_FederatedSaml, self).__init__(auth_url,
identity_provider,
protocol,
**kwargs)
def __init__(
self,
auth_url,
identity_provider,
protocol,
identity_provider_url,
**kwargs,
):
super().__init__(auth_url, identity_provider, protocol, **kwargs)
self.identity_provider_url = identity_provider_url
@abc.abstractmethod
@ -234,9 +248,11 @@ class _FederatedSaml(v3.FederationBaseAuth):
auth = _SamlAuth(self.identity_provider_url, method)
try:
resp = session.get(self.federated_token_url,
requests_auth=auth,
authenticated=False)
resp = session.get(
self.federated_token_url,
requests_auth=auth,
authenticated=False,
)
except SamlException as e:
raise exceptions.AuthorizationFailure(str(e))
@ -287,13 +303,23 @@ class Password(_FederatedSaml):
"""
def __init__(self, auth_url, identity_provider, protocol,
identity_provider_url, username, password, **kwargs):
super(Password, self).__init__(auth_url,
identity_provider,
protocol,
identity_provider_url,
**kwargs)
def __init__(
self,
auth_url,
identity_provider,
protocol,
identity_provider_url,
username,
password,
**kwargs,
):
super().__init__(
auth_url,
identity_provider,
protocol,
identity_provider_url,
**kwargs,
)
self.username = username
self.password = password

View File

@ -43,7 +43,8 @@ def _mutual_auth(value):
def _requests_auth(mutual_authentication):
return requests_kerberos.HTTPKerberosAuth(
mutual_authentication=_mutual_auth(mutual_authentication))
mutual_authentication=_mutual_auth(mutual_authentication)
)
def _dependency_check():
@ -57,12 +58,11 @@ packages. These can be installed with::
class KerberosMethod(v3.AuthMethod):
_method_parameters = ['mutual_auth']
def __init__(self, *args, **kwargs):
_dependency_check()
super(KerberosMethod, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
def get_auth_data(self, session, auth, headers, request_kwargs, **kwargs):
# NOTE(jamielennox): request_kwargs is passed as a kwarg however it is
@ -82,16 +82,18 @@ class MappedKerberos(federation.FederationBaseAuth):
use the standard keystone auth process to scope that to any given project.
"""
def __init__(self, auth_url, identity_provider, protocol,
mutual_auth=None, **kwargs):
def __init__(
self, auth_url, identity_provider, protocol, mutual_auth=None, **kwargs
):
_dependency_check()
self.mutual_auth = mutual_auth
super(MappedKerberos, self).__init__(auth_url, identity_provider,
protocol, **kwargs)
super().__init__(auth_url, identity_provider, protocol, **kwargs)
def get_unscoped_auth_ref(self, session, **kwargs):
resp = session.get(self.federated_token_url,
requests_auth=_requests_auth(self.mutual_auth),
authenticated=False)
resp = session.get(
self.federated_token_url,
requests_auth=_requests_auth(self.mutual_auth),
authenticated=False,
)
return access.create(body=resp.json(), resp=resp)

View File

@ -16,7 +16,6 @@ from keystoneauth1 import loading
class Kerberos(loading.BaseV3Loader):
@property
def plugin_class(self):
return kerberos.Kerberos
@ -26,14 +25,18 @@ class Kerberos(loading.BaseV3Loader):
return kerberos.requests_kerberos is not None
def get_options(self):
options = super(Kerberos, self).get_options()
options = super().get_options()
options.extend([
loading.Opt('mutual-auth',
required=False,
default='optional',
help='Configures Kerberos Mutual Authentication'),
])
options.extend(
[
loading.Opt(
'mutual-auth',
required=False,
default='optional',
help='Configures Kerberos Mutual Authentication',
)
]
)
return options
@ -41,16 +44,17 @@ class Kerberos(loading.BaseV3Loader):
if kwargs.get('mutual_auth'):
value = kwargs.get('mutual_auth')
if not (value.lower() in ['required', 'optional', 'disabled']):
m = ('You need to provide a valid value for kerberos mutual '
'authentication. It can be one of the following: '
'(required, optional, disabled)')
m = (
'You need to provide a valid value for kerberos mutual '
'authentication. It can be one of the following: '
'(required, optional, disabled)'
)
raise exceptions.OptionError(m)
return super(Kerberos, self).load_from_options(**kwargs)
return super().load_from_options(**kwargs)
class MappedKerberos(loading.BaseFederationLoader):
@property
def plugin_class(self):
return kerberos.MappedKerberos
@ -60,14 +64,18 @@ class MappedKerberos(loading.BaseFederationLoader):
return kerberos.requests_kerberos is not None
def get_options(self):
options = super(MappedKerberos, self).get_options()
options = super().get_options()
options.extend([
loading.Opt('mutual-auth',
required=False,
default='optional',
help='Configures Kerberos Mutual Authentication'),
])
options.extend(
[
loading.Opt(
'mutual-auth',
required=False,
default='optional',
help='Configures Kerberos Mutual Authentication',
)
]
)
return options
@ -75,9 +83,11 @@ class MappedKerberos(loading.BaseFederationLoader):
if kwargs.get('mutual_auth'):
value = kwargs.get('mutual_auth')
if not (value.lower() in ['required', 'optional', 'disabled']):
m = ('You need to provide a valid value for kerberos mutual '
'authentication. It can be one of the following: '
'(required, optional, disabled)')
m = (
'You need to provide a valid value for kerberos mutual '
'authentication. It can be one of the following: '
'(required, optional, disabled)'
)
raise exceptions.OptionError(m)
return super(MappedKerberos, self).load_from_options(**kwargs)
return super().load_from_options(**kwargs)

View File

@ -17,7 +17,6 @@ from keystoneauth1 import loading
# NOTE(jamielennox): This is not a BaseV3Loader because we don't want to
# include the scoping options like project-id in the option list
class V3OAuth1(loading.BaseIdentityLoader):
@property
def plugin_class(self):
return v3.OAuth1
@ -27,21 +26,25 @@ class V3OAuth1(loading.BaseIdentityLoader):
return v3.oauth1 is not None
def get_options(self):
options = super(V3OAuth1, self).get_options()
options = super().get_options()
options.extend([
loading.Opt('consumer-key',
required=True,
help='OAuth Consumer ID/Key'),
loading.Opt('consumer-secret',
required=True,
help='OAuth Consumer Secret'),
loading.Opt('access-key',
required=True,
help='OAuth Access Key'),
loading.Opt('access-secret',
required=True,
help='OAuth Access Secret'),
])
options.extend(
[
loading.Opt(
'consumer-key', required=True, help='OAuth Consumer ID/Key'
),
loading.Opt(
'consumer-secret',
required=True,
help='OAuth Consumer Secret',
),
loading.Opt(
'access-key', required=True, help='OAuth Access Key'
),
loading.Opt(
'access-secret', required=True, help='OAuth Access Secret'
),
]
)
return options

View File

@ -44,37 +44,46 @@ class OAuth1Method(v3.AuthMethod):
:param string access_secret: Access token secret.
"""
_method_parameters = ['consumer_key', 'consumer_secret',
'access_key', 'access_secret']
_method_parameters = [
'consumer_key',
'consumer_secret',
'access_key',
'access_secret',
]
def get_auth_data(self, session, auth, headers, **kwargs):
# Add the oauth specific content into the headers
oauth_client = oauth1.Client(self.consumer_key,
client_secret=self.consumer_secret,
resource_owner_key=self.access_key,
resource_owner_secret=self.access_secret,
signature_method=oauth1.SIGNATURE_HMAC)
oauth_client = oauth1.Client(
self.consumer_key,
client_secret=self.consumer_secret,
resource_owner_key=self.access_key,
resource_owner_secret=self.access_secret,
signature_method=oauth1.SIGNATURE_HMAC,
)
o_url, o_headers, o_body = oauth_client.sign(auth.token_url,
http_method='POST')
o_url, o_headers, o_body = oauth_client.sign(
auth.token_url, http_method='POST'
)
headers.update(o_headers)
return 'oauth1', {}
def get_cache_id_elements(self):
return dict(('oauth1_%s' % p, getattr(self, p))
for p in self._method_parameters)
return {
f'oauth1_{p}': getattr(self, p) for p in self._method_parameters
}
class OAuth1(v3.AuthConstructor):
_auth_method_class = OAuth1Method
def __init__(self, *args, **kwargs):
super(OAuth1, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
if self.has_scope_parameters:
LOG.warning('Scoping parameters such as a project were provided '
'to the OAuth1 plugin. Because OAuth1 access is '
'always scoped to a project these will be ignored by '
'the identity server')
LOG.warning(
'Scoping parameters such as a project were provided '
'to the OAuth1 plugin. Because OAuth1 access is '
'always scoped to a project these will be ignored by '
'the identity server'
)

View File

@ -33,14 +33,15 @@ V2Token = v2.Token
V3Token = v3.Token
V3FederationToken = v3.V3FederationToken
__all__ = ('DiscoveryList',
'FixtureValidationError',
'LoadingFixture',
'TestPlugin',
'V2Discovery',
'V3Discovery',
'V2Token',
'V3Token',
'V3FederationToken',
'VersionDiscovery',
)
__all__ = (
'DiscoveryList',
'FixtureValidationError',
'LoadingFixture',
'TestPlugin',
'V2Discovery',
'V3Discovery',
'V2Token',
'V3Token',
'V3FederationToken',
'VersionDiscovery',
)

View File

@ -12,11 +12,7 @@
from keystoneauth1 import _utils as utils
__all__ = ('DiscoveryList',
'V2Discovery',
'V3Discovery',
'VersionDiscovery',
)
__all__ = ('DiscoveryList', 'V2Discovery', 'V3Discovery', 'VersionDiscovery')
_DEFAULT_DAYS_AGO = 30
@ -32,7 +28,7 @@ class DiscoveryBase(dict):
"""
def __init__(self, id, status=None, updated=None):
super(DiscoveryBase, self).__init__()
super().__init__()
self.id = id
self.status = status or 'stable'
@ -103,7 +99,7 @@ class VersionDiscovery(DiscoveryBase):
"""
def __init__(self, href, id, **kwargs):
super(VersionDiscovery, self).__init__(id, **kwargs)
super().__init__(id, **kwargs)
self.add_link(href)
@ -122,7 +118,7 @@ class MicroversionDiscovery(DiscoveryBase):
"""
def __init__(self, href, id, min_version='', max_version='', **kwargs):
super(MicroversionDiscovery, self).__init__(id, **kwargs)
super().__init__(id, **kwargs)
self.add_link(href)
@ -160,7 +156,7 @@ class NovaMicroversionDiscovery(DiscoveryBase):
"""
def __init__(self, href, id, min_version=None, version=None, **kwargs):
super(NovaMicroversionDiscovery, self).__init__(id, **kwargs)
super().__init__(id, **kwargs)
self.add_link(href)
@ -204,7 +200,7 @@ class V2Discovery(DiscoveryBase):
_DESC_URL = 'https://developer.openstack.org/api-ref/identity/v2/'
def __init__(self, href, id=None, html=True, pdf=True, **kwargs):
super(V2Discovery, self).__init__(id or 'v2.0', **kwargs)
super().__init__(id or 'v2.0', **kwargs)
self.add_link(href)
@ -219,9 +215,11 @@ class V2Discovery(DiscoveryBase):
The standard structure includes a link to a HTML document with the
API specification. Add it to this entry.
"""
self.add_link(href=self._DESC_URL + 'content',
rel='describedby',
type='text/html')
self.add_link(
href=self._DESC_URL + 'content',
rel='describedby',
type='text/html',
)
def add_pdf_description(self):
"""Add the PDF described by links.
@ -229,9 +227,11 @@ class V2Discovery(DiscoveryBase):
The standard structure includes a link to a PDF document with the
API specification. Add it to this entry.
"""
self.add_link(href=self._DESC_URL + 'identity-dev-guide-2.0.pdf',
rel='describedby',
type='application/pdf')
self.add_link(
href=self._DESC_URL + 'identity-dev-guide-2.0.pdf',
rel='describedby',
type='application/pdf',
)
class V3Discovery(DiscoveryBase):
@ -249,7 +249,7 @@ class V3Discovery(DiscoveryBase):
"""
def __init__(self, href, id=None, json=True, xml=True, **kwargs):
super(V3Discovery, self).__init__(id or 'v3.0', **kwargs)
super().__init__(id or 'v3.0', **kwargs)
self.add_link(href)
@ -264,8 +264,10 @@ class V3Discovery(DiscoveryBase):
The standard structure includes a list of media-types that the endpoint
supports. Add JSON to the list.
"""
self.add_media_type(base='application/json',
type='application/vnd.openstack.identity-v3+json')
self.add_media_type(
base='application/json',
type='application/vnd.openstack.identity-v3+json',
)
def add_xml_media_type(self):
"""Add the XML media-type links.
@ -273,8 +275,10 @@ class V3Discovery(DiscoveryBase):
The standard structure includes a list of media-types that the endpoint
supports. Add XML to the list.
"""
self.add_media_type(base='application/xml',
type='application/vnd.openstack.identity-v3+xml')
self.add_media_type(
base='application/xml',
type='application/vnd.openstack.identity-v3+xml',
)
class DiscoveryList(dict):
@ -298,22 +302,47 @@ class DiscoveryList(dict):
TEST_URL = 'http://keystone.host:5000/'
def __init__(self, href=None, v2=True, v3=True, v2_id=None, v3_id=None,
v2_status=None, v2_updated=None, v2_html=True, v2_pdf=True,
v3_status=None, v3_updated=None, v3_json=True, v3_xml=True):
super(DiscoveryList, self).__init__(versions={'values': []})
def __init__(
self,
href=None,
v2=True,
v3=True,
v2_id=None,
v3_id=None,
v2_status=None,
v2_updated=None,
v2_html=True,
v2_pdf=True,
v3_status=None,
v3_updated=None,
v3_json=True,
v3_xml=True,
):
super().__init__(versions={'values': []})
href = href or self.TEST_URL
if v2:
v2_href = href.rstrip('/') + '/v2.0'
self.add_v2(v2_href, id=v2_id, status=v2_status,
updated=v2_updated, html=v2_html, pdf=v2_pdf)
self.add_v2(
v2_href,
id=v2_id,
status=v2_status,
updated=v2_updated,
html=v2_html,
pdf=v2_pdf,
)
if v3:
v3_href = href.rstrip('/') + '/v3'
self.add_v3(v3_href, id=v3_id, status=v3_status,
updated=v3_updated, json=v3_json, xml=v3_xml)
self.add_v3(
v3_href,
id=v3_id,
status=v3_status,
updated=v3_updated,
json=v3_json,
xml=v3_xml,
)
@property
def versions(self):

View File

@ -25,11 +25,16 @@ from keystoneauth1 import session
class BetamaxFixture(fixtures.Fixture):
def __init__(self, cassette_name, cassette_library_dir=None,
serializer=None, record=False,
pre_record_hook=hooks.pre_record_hook,
serializer_name=None, request_matchers=None):
def __init__(
self,
cassette_name,
cassette_library_dir=None,
serializer=None,
record=False,
pre_record_hook=hooks.pre_record_hook,
serializer_name=None,
request_matchers=None,
):
"""Configure Betamax for the test suite.
:param str cassette_name:
@ -93,10 +98,12 @@ class BetamaxFixture(fixtures.Fixture):
return self._serializer_name
def setUp(self):
super(BetamaxFixture, self).setUp()
super().setUp()
self.mockpatch = mock.patch.object(
session, '_construct_session',
partial(_construct_session_with_betamax, self))
session,
'_construct_session',
partial(_construct_session_with_betamax, self),
)
self.mockpatch.start()
# Unpatch during cleanup
self.addCleanup(self.mockpatch.stop)
@ -116,7 +123,8 @@ def _construct_session_with_betamax(fixture, session_obj=None):
with betamax.Betamax.configure() as config:
config.before_record(callback=fixture.pre_record_hook)
fixture.recorder = betamax.Betamax(
session_obj, cassette_library_dir=fixture.cassette_library_dir)
session_obj, cassette_library_dir=fixture.cassette_library_dir
)
record = 'none'
serializer = None
@ -126,10 +134,12 @@ def _construct_session_with_betamax(fixture, session_obj=None):
serializer = fixture.serializer_name
fixture.recorder.use_cassette(fixture.cassette_name,
serialize_with=serializer,
record=record,
**fixture.use_cassette_kwargs)
fixture.recorder.use_cassette(
fixture.cassette_name,
serialize_with=serializer,
record=record,
**fixture.use_cassette_kwargs,
)
fixture.recorder.start()
fixture.addCleanup(fixture.recorder.stop)

View File

@ -18,10 +18,7 @@ from keystoneauth1 import discover
from keystoneauth1 import loading
from keystoneauth1 import plugin
__all__ = (
'LoadingFixture',
'TestPlugin',
)
__all__ = ('LoadingFixture', 'TestPlugin')
DEFAULT_TEST_ENDPOINT = 'https://openstack.example.com/%(service_type)s'
@ -62,12 +59,10 @@ class TestPlugin(plugin.BaseAuthPlugin):
auth_type = 'test_plugin'
def __init__(self,
token=None,
endpoint=None,
user_id=None,
project_id=None):
super(TestPlugin, self).__init__()
def __init__(
self, token=None, endpoint=None, user_id=None, project_id=None
):
super().__init__()
self.token = token or uuid.uuid4().hex
self.endpoint = endpoint or DEFAULT_TEST_ENDPOINT
@ -98,9 +93,8 @@ class TestPlugin(plugin.BaseAuthPlugin):
class _TestPluginLoader(loading.BaseLoader):
def __init__(self, plugin):
super(_TestPluginLoader, self).__init__()
super().__init__()
self._plugin = plugin
def create_plugin(self, **kwargs):
@ -129,12 +123,10 @@ class LoadingFixture(fixtures.Fixture):
MOCK_POINT = 'keystoneauth1.loading.base.get_plugin_loader'
def __init__(self,
token=None,
endpoint=None,
user_id=None,
project_id=None):
super(LoadingFixture, self).__init__()
def __init__(
self, token=None, endpoint=None, user_id=None, project_id=None
):
super().__init__()
# these are created and saved here so that a test could use them
self.token = token or uuid.uuid4().hex
@ -143,16 +135,19 @@ class LoadingFixture(fixtures.Fixture):
self.project_id = project_id or uuid.uuid4().hex
def setUp(self):
super(LoadingFixture, self).setUp()
super().setUp()
self.useFixture(fixtures.MonkeyPatch(self.MOCK_POINT,
self.get_plugin_loader))
self.useFixture(
fixtures.MonkeyPatch(self.MOCK_POINT, self.get_plugin_loader)
)
def create_plugin(self):
return TestPlugin(token=self.token,
endpoint=self.endpoint,
user_id=self.user_id,
project_id=self.project_id)
return TestPlugin(
token=self.token,
endpoint=self.endpoint,
user_id=self.user_id,
project_id=self.project_id,
)
def get_plugin_loader(self, auth_type):
plugin = self.create_plugin()
@ -171,6 +166,6 @@ class LoadingFixture(fixtures.Fixture):
endpoint = _format_endpoint(self.endpoint, **kwargs)
if path:
endpoint = "%s/%s" % (endpoint.rstrip('/'), path.lstrip('/'))
endpoint = "{}/{}".format(endpoint.rstrip('/'), path.lstrip('/'))
return endpoint

View File

@ -20,7 +20,7 @@ import yaml
def _should_use_block(value):
for c in u"\u000a\u000d\u001c\u001d\u001e\u0085\u2028\u2029":
for c in "\u000a\u000d\u001c\u001d\u001e\u0085\u2028\u2029":
if c in value:
return True
return False
@ -40,7 +40,7 @@ def _represent_scalar(self, tag, value, style=None):
def _unicode_representer(dumper, uni):
node = yaml.ScalarNode(tag=u'tag:yaml.org,2002:str', value=uni)
node = yaml.ScalarNode(tag='tag:yaml.org,2002:str', value=uni)
return node
@ -49,9 +49,12 @@ def _indent_json(val):
return ''
return json.dumps(
json.loads(val), indent=2,
separators=(',', ': '), sort_keys=False,
default=str)
json.loads(val),
indent=2,
separators=(',', ': '),
sort_keys=False,
default=str,
)
def _is_json_body(interaction):
@ -60,13 +63,11 @@ def _is_json_body(interaction):
class YamlJsonSerializer(betamax.serializers.base.BaseSerializer):
name = "yamljson"
@staticmethod
def generate_cassette_name(cassette_library_dir, cassette_name):
return os.path.join(
cassette_library_dir, "{name}.yaml".format(name=cassette_name))
return os.path.join(cassette_library_dir, f"{cassette_name}.yaml")
def serialize(self, cassette_data):
# Reserialize internal json with indentation
@ -74,7 +75,8 @@ class YamlJsonSerializer(betamax.serializers.base.BaseSerializer):
for key in ('request', 'response'):
if _is_json_body(interaction[key]):
interaction[key]['body']['string'] = _indent_json(
interaction[key]['body']['string'])
interaction[key]['body']['string']
)
class MyDumper(yaml.Dumper):
"""Specialized Dumper which does nice blocks and unicode."""
@ -84,7 +86,8 @@ class YamlJsonSerializer(betamax.serializers.base.BaseSerializer):
MyDumper.add_representer(str, _unicode_representer)
return yaml.dump(
cassette_data, Dumper=MyDumper, default_flow_style=False)
cassette_data, Dumper=MyDumper, default_flow_style=False
)
def deserialize(self, cassette_data):
try:

View File

@ -18,15 +18,23 @@ from keystoneauth1.fixture import exception
class _Service(dict):
def add_endpoint(self, public, admin=None, internal=None,
tenant_id=None, region=None, id=None):
data = {'tenantId': tenant_id or uuid.uuid4().hex,
'publicURL': public,
'adminURL': admin or public,
'internalURL': internal or public,
'region': region,
'id': id or uuid.uuid4().hex}
def add_endpoint(
self,
public,
admin=None,
internal=None,
tenant_id=None,
region=None,
id=None,
):
data = {
'tenantId': tenant_id or uuid.uuid4().hex,
'publicURL': public,
'adminURL': admin or public,
'internalURL': internal or public,
'region': region,
'id': id or uuid.uuid4().hex,
}
self.setdefault('endpoints', []).append(data)
return data
@ -41,11 +49,21 @@ class Token(dict):
that matter to them and not copy and paste sample.
"""
def __init__(self, token_id=None, expires=None, issued=None,
tenant_id=None, tenant_name=None, user_id=None,
user_name=None, trust_id=None, trustee_user_id=None,
audit_id=None, audit_chain_id=None):
super(Token, self).__init__()
def __init__(
self,
token_id=None,
expires=None,
issued=None,
tenant_id=None,
tenant_name=None,
user_id=None,
user_name=None,
trust_id=None,
trustee_user_id=None,
audit_id=None,
audit_chain_id=None,
):
super().__init__()
self.token_id = token_id or uuid.uuid4().hex
self.user_id = user_id or uuid.uuid4().hex
@ -75,8 +93,9 @@ class Token(dict):
if trust_id or trustee_user_id:
# the trustee_user_id will generally be the same as the user_id as
# the token is being issued to the trustee
self.set_trust(id=trust_id,
trustee_user_id=trustee_user_id or user_id)
self.set_trust(
id=trust_id, trustee_user_id=trustee_user_id or user_id
)
if audit_chain_id:
self.audit_chain_id = audit_chain_id
@ -237,8 +256,10 @@ class Token(dict):
def remove_service(self, type):
self.root['serviceCatalog'] = [
f for f in self.root.setdefault('serviceCatalog', [])
if f['type'] != type]
f
for f in self.root.setdefault('serviceCatalog', [])
if f['type'] != type
]
def set_scope(self, id=None, name=None):
self.tenant_id = id or uuid.uuid4().hex

View File

@ -25,16 +25,19 @@ class _Service(dict):
"""
def add_endpoint(self, interface, url, region=None, id=None):
data = {'id': id or uuid.uuid4().hex,
'interface': interface,
'url': url,
'region': region,
'region_id': region}
data = {
'id': id or uuid.uuid4().hex,
'interface': interface,
'url': url,
'region': region,
'region_id': region,
}
self.setdefault('endpoints', []).append(data)
return data
def add_standard_endpoints(self, public=None, admin=None, internal=None,
region=None):
def add_standard_endpoints(
self, public=None, admin=None, internal=None, region=None
):
ret = []
if public:
@ -56,18 +59,36 @@ class Token(dict):
that matter to them and not copy and paste sample.
"""
def __init__(self, expires=None, issued=None, user_id=None, user_name=None,
user_domain_id=None, user_domain_name=None, methods=None,
project_id=None, project_name=None, project_domain_id=None,
project_domain_name=None, domain_id=None, domain_name=None,
trust_id=None, trust_impersonation=None, trustee_user_id=None,
trustor_user_id=None, application_credential_id=None,
application_credential_access_rules=None,
oauth_access_token_id=None, oauth_consumer_id=None,
audit_id=None, audit_chain_id=None,
is_admin_project=None, project_is_domain=None,
oauth2_thumbprint=None):
super(Token, self).__init__()
def __init__(
self,
expires=None,
issued=None,
user_id=None,
user_name=None,
user_domain_id=None,
user_domain_name=None,
methods=None,
project_id=None,
project_name=None,
project_domain_id=None,
project_domain_name=None,
domain_id=None,
domain_name=None,
trust_id=None,
trust_impersonation=None,
trustee_user_id=None,
trustor_user_id=None,
application_credential_id=None,
application_credential_access_rules=None,
oauth_access_token_id=None,
oauth_consumer_id=None,
audit_id=None,
audit_chain_id=None,
is_admin_project=None,
project_is_domain=None,
oauth2_thumbprint=None,
):
super().__init__()
self.user_id = user_id or uuid.uuid4().hex
self.user_name = user_name or uuid.uuid4().hex
@ -97,32 +118,47 @@ class Token(dict):
# expires should be able to be passed as a string so ignore
self.expires_str = expires
if (project_id or project_name or
project_domain_id or project_domain_name):
self.set_project_scope(id=project_id,
name=project_name,
domain_id=project_domain_id,
domain_name=project_domain_name,
is_domain=project_is_domain)
if (
project_id
or project_name
or project_domain_id
or project_domain_name
):
self.set_project_scope(
id=project_id,
name=project_name,
domain_id=project_domain_id,
domain_name=project_domain_name,
is_domain=project_is_domain,
)
if domain_id or domain_name:
self.set_domain_scope(id=domain_id, name=domain_name)
if (trust_id or (trust_impersonation is not None) or
trustee_user_id or trustor_user_id):
self.set_trust_scope(id=trust_id,
impersonation=trust_impersonation,
trustee_user_id=trustee_user_id,
trustor_user_id=trustor_user_id)
if (
trust_id
or (trust_impersonation is not None)
or trustee_user_id
or trustor_user_id
):
self.set_trust_scope(
id=trust_id,
impersonation=trust_impersonation,
trustee_user_id=trustee_user_id,
trustor_user_id=trustor_user_id,
)
if application_credential_id:
self.set_application_credential(
application_credential_id,
access_rules=application_credential_access_rules)
access_rules=application_credential_access_rules,
)
if oauth_access_token_id or oauth_consumer_id:
self.set_oauth(access_token_id=oauth_access_token_id,
consumer_id=oauth_consumer_id)
self.set_oauth(
access_token_id=oauth_access_token_id,
consumer_id=oauth_consumer_id,
)
if audit_chain_id:
self.audit_chain_id = audit_chain_id
@ -326,7 +362,8 @@ class Token(dict):
@application_credential_id.setter
def application_credential_id(self, value):
application_credential = self.root.setdefault(
'application_credential', {})
'application_credential', {}
)
application_credential.setdefault('id', value)
@property
@ -336,7 +373,8 @@ class Token(dict):
@application_credential_access_rules.setter
def application_credential_access_rules(self, value):
application_credential = self.root.setdefault(
'application_credential', {})
'application_credential', {}
)
application_credential.setdefault('access_rules', value)
@property
@ -438,8 +476,7 @@ class Token(dict):
def add_role(self, name=None, id=None):
roles = self.root.setdefault('roles', [])
data = {'id': id or uuid.uuid4().hex,
'name': name or uuid.uuid4().hex}
data = {'id': id or uuid.uuid4().hex, 'name': name or uuid.uuid4().hex}
roles.append(data)
return data
@ -453,11 +490,17 @@ class Token(dict):
def remove_service(self, type):
self.root.setdefault('catalog', [])
self.root['catalog'] = [
f for f in self.root.setdefault('catalog', [])
if f['type'] != type]
f for f in self.root.setdefault('catalog', []) if f['type'] != type
]
def set_project_scope(self, id=None, name=None, domain_id=None,
domain_name=None, is_domain=None):
def set_project_scope(
self,
id=None,
name=None,
domain_id=None,
domain_name=None,
is_domain=None,
):
self.project_id = id or uuid.uuid4().hex
self.project_name = name or uuid.uuid4().hex
self.project_domain_id = domain_id or uuid.uuid4().hex
@ -477,8 +520,13 @@ class Token(dict):
# entire system.
self.system = {'all': True}
def set_trust_scope(self, id=None, impersonation=False,
trustee_user_id=None, trustor_user_id=None):
def set_trust_scope(
self,
id=None,
impersonation=False,
trustee_user_id=None,
trustor_user_id=None,
):
self.trust_id = id or uuid.uuid4().hex
self.trust_impersonation = impersonation
self.trustee_user_id = trustee_user_id or uuid.uuid4().hex
@ -488,8 +536,9 @@ class Token(dict):
self.oauth_access_token_id = access_token_id or uuid.uuid4().hex
self.oauth_consumer_id = consumer_id or uuid.uuid4().hex
def set_application_credential(self, application_credential_id,
access_rules=None):
def set_application_credential(
self, application_credential_id, access_rules=None
):
self.application_credential_id = application_credential_id
if access_rules is not None:
self.application_credential_access_rules = access_rules
@ -517,20 +566,22 @@ class V3FederationToken(Token):
FEDERATED_DOMAIN_ID = 'Federated'
def __init__(self, methods=None, identity_provider=None, protocol=None,
groups=None):
def __init__(
self, methods=None, identity_provider=None, protocol=None, groups=None
):
methods = methods or ['saml2']
super(V3FederationToken, self).__init__(methods=methods)
super().__init__(methods=methods)
self._user_domain = {'id': V3FederationToken.FEDERATED_DOMAIN_ID}
self.add_federation_info_to_user(identity_provider, protocol, groups)
def add_federation_info_to_user(self, identity_provider=None,
protocol=None, groups=None):
def add_federation_info_to_user(
self, identity_provider=None, protocol=None, groups=None
):
data = {
"OS-FEDERATION": {
"identity_provider": identity_provider or uuid.uuid4().hex,
"protocol": protocol or uuid.uuid4().hex,
"groups": groups or [{"id": uuid.uuid4().hex}]
"groups": groups or [{"id": uuid.uuid4().hex}],
}
}
self._user.update(data)

View File

@ -18,7 +18,6 @@ errors so that core devs don't have to.
"""
import re
from hacking import core
@ -27,10 +26,11 @@ from hacking import core
@core.flake8ext
def check_oslo_namespace_imports(logical_line, blank_before, filename):
oslo_namespace_imports = re.compile(
r"(((from)|(import))\s+oslo\.)|(from\s+oslo\s+import\s+)")
r"(((from)|(import))\s+oslo\.)|(from\s+oslo\s+import\s+)"
)
if re.match(oslo_namespace_imports, logical_line):
msg = ("K333: '%s' must be used instead of '%s'.") % (
logical_line.replace('oslo.', 'oslo_'),
logical_line)
msg = ("K333: '{}' must be used instead of '{}'.").format(
logical_line.replace('oslo.', 'oslo_'), logical_line
)
yield (0, msg)

View File

@ -25,15 +25,14 @@ class HTTPBasicAuth(plugin.FixedEndpointPlugin):
"""
def __init__(self, endpoint=None, username=None, password=None):
super(HTTPBasicAuth, self).__init__(endpoint)
super().__init__(endpoint)
self.username = username
self.password = password
def get_token(self, session, **kwargs):
if self.username is None or self.password is None:
return None
token = bytes('%s:%s' % (self.username, self.password),
encoding='utf-8')
token = bytes(f'{self.username}:{self.password}', encoding='utf-8')
encoded = base64.b64encode(token)
return str(encoded, encoding='utf-8')
@ -41,5 +40,5 @@ class HTTPBasicAuth(plugin.FixedEndpointPlugin):
token = self.get_token(session)
if not token:
return None
auth = 'Basic %s' % token
auth = f'Basic {token}'
return {AUTH_HEADER_NAME: auth}

View File

@ -70,20 +70,22 @@ V3OAuth2ClientCredential = v3.OAuth2ClientCredential
V3OAuth2mTlsClientCredential = v3.OAuth2mTlsClientCredential
"""See :class:`keystoneauth1.identity.v3.OAuth2mTlsClientCredential`"""
__all__ = ('BaseIdentityPlugin',
'Password',
'Token',
'V2Password',
'V2Token',
'V3Password',
'V3Token',
'V3OidcPassword',
'V3OidcAuthorizationCode',
'V3OidcAccessToken',
'V3OidcDeviceAuthorization',
'V3TOTP',
'V3TokenlessAuth',
'V3ApplicationCredential',
'V3MultiFactor',
'V3OAuth2ClientCredential',
'V3OAuth2mTlsClientCredential')
__all__ = (
'BaseIdentityPlugin',
'Password',
'Token',
'V2Password',
'V2Token',
'V3Password',
'V3Token',
'V3OidcPassword',
'V3OidcAuthorizationCode',
'V3OidcAccessToken',
'V3OidcDeviceAuthorization',
'V3TOTP',
'V3TokenlessAuth',
'V3ApplicationCredential',
'V3MultiFactor',
'V3OAuth2ClientCredential',
'V3OAuth2mTlsClientCredential',
)

View File

@ -31,8 +31,7 @@ class AccessInfoPlugin(base.BaseIdentityPlugin):
"""
def __init__(self, auth_ref, auth_url=None):
super(AccessInfoPlugin, self).__init__(auth_url=auth_url,
reauthenticate=False)
super().__init__(auth_url=auth_url, reauthenticate=False)
self.auth_ref = auth_ref
def get_auth_ref(self, session, **kwargs):

View File

@ -27,14 +27,12 @@ LOG = utils.get_logger(__name__)
class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
# we count a token as valid (not needing refreshing) if it is valid for at
# least this many seconds before the token expiry time
MIN_TOKEN_LIFE_SECONDS = 120
def __init__(self, auth_url=None, reauthenticate=True):
super(BaseIdentityPlugin, self).__init__()
super().__init__()
self.auth_url = auth_url
self.auth_ref = None
@ -152,11 +150,22 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
return False
def get_endpoint_data(self, session, service_type=None, interface=None,
region_name=None, service_name=None, allow=None,
allow_version_hack=True, discover_versions=True,
skip_discovery=False, min_version=None,
max_version=None, endpoint_override=None, **kwargs):
def get_endpoint_data(
self,
session,
service_type=None,
interface=None,
region_name=None,
service_name=None,
allow=None,
allow_version_hack=True,
discover_versions=True,
skip_discovery=False,
min_version=None,
max_version=None,
endpoint_override=None,
**kwargs,
):
"""Return a valid endpoint data for a service.
If a valid token is not present then a new one will be fetched using
@ -223,7 +232,8 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
allow = allow or {}
min_version, max_version = discover._normalize_version_args(
None, min_version, max_version, service_type=service_type)
None, min_version, max_version, service_type=service_type
)
# NOTE(jamielennox): if you specifically ask for requests to be sent to
# the auth url then we can ignore many of the checks. Typically if you
@ -233,7 +243,8 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
if interface is plugin.AUTH_INTERFACE:
endpoint_data = discover.EndpointData(
service_url=self.auth_url,
service_type=service_type or 'identity')
service_type=service_type or 'identity',
)
project_id = None
elif endpoint_override:
# TODO(mordred) Make a code path that will look for a
@ -246,7 +257,8 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
catalog_url=endpoint_override,
interface=interface,
region_name=region_name,
service_name=service_name)
service_name=service_name,
)
# Setting an endpoint_override then calling get_endpoint_data means
# you absolutely want the discovery info for the URL in question.
# There are no code flows where this will happen for any other
@ -255,9 +267,11 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
project_id = self.get_project_id(session)
else:
if not service_type:
LOG.warning('Plugin cannot return an endpoint without '
'knowing the service type that is required. Add '
'service_type to endpoint filtering data.')
LOG.warning(
'Plugin cannot return an endpoint without '
'knowing the service type that is required. Add '
'service_type to endpoint filtering data.'
)
return None
# It's possible for things higher in the stack, because of
@ -273,7 +287,8 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
service_type=service_type,
interface=interface,
region_name=region_name,
service_name=service_name)
service_name=service_name,
)
if not endpoint_data:
return None
@ -288,10 +303,14 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
max_version=max_version,
cache=self._discovery_cache,
discover_versions=discover_versions,
allow_version_hack=allow_version_hack, allow=allow)
except (exceptions.DiscoveryFailure,
exceptions.HttpError,
exceptions.ConnectionError):
allow_version_hack=allow_version_hack,
allow=allow,
)
except (
exceptions.DiscoveryFailure,
exceptions.HttpError,
exceptions.ConnectionError,
):
# If a version was requested, we didn't find it, return
# None.
if max_version or min_version:
@ -300,12 +319,21 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
# should be fine
return endpoint_data
def get_endpoint(self, session, service_type=None, interface=None,
region_name=None, service_name=None, version=None,
allow=None, allow_version_hack=True,
skip_discovery=False,
min_version=None, max_version=None,
**kwargs):
def get_endpoint(
self,
session,
service_type=None,
interface=None,
region_name=None,
service_name=None,
version=None,
allow=None,
allow_version_hack=True,
skip_discovery=False,
min_version=None,
max_version=None,
**kwargs,
):
"""Return a valid endpoint for a service.
If a valid token is not present then a new one will be fetched using
@ -363,26 +391,45 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
# Explode `version` into min_version and max_version - everything below
# here uses the latter rather than the former.
min_version, max_version = discover._normalize_version_args(
version, min_version, max_version, service_type=service_type)
version, min_version, max_version, service_type=service_type
)
# Set discover_versions to False since we're only going to return
# a URL. Fetching the microversion data would be needlessly
# expensive in the common case. However, discover_versions=False
# will still run discovery if the version requested is not the
# version in the catalog.
endpoint_data = self.get_endpoint_data(
session, service_type=service_type, interface=interface,
region_name=region_name, service_name=service_name,
allow=allow, min_version=min_version, max_version=max_version,
discover_versions=False, skip_discovery=skip_discovery,
allow_version_hack=allow_version_hack, **kwargs)
session,
service_type=service_type,
interface=interface,
region_name=region_name,
service_name=service_name,
allow=allow,
min_version=min_version,
max_version=max_version,
discover_versions=False,
skip_discovery=skip_discovery,
allow_version_hack=allow_version_hack,
**kwargs,
)
return endpoint_data.url if endpoint_data else None
def get_api_major_version(self, session, service_type=None, interface=None,
region_name=None, service_name=None,
version=None, allow=None,
allow_version_hack=True, skip_discovery=False,
discover_versions=False, min_version=None,
max_version=None, **kwargs):
def get_api_major_version(
self,
session,
service_type=None,
interface=None,
region_name=None,
service_name=None,
version=None,
allow=None,
allow_version_hack=True,
skip_discovery=False,
discover_versions=False,
min_version=None,
max_version=None,
**kwargs,
):
"""Return the major API version for a service.
If a valid token is not present then a new one will be fetched using
@ -456,8 +503,8 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
``block-storage`` service and they do::
client = adapter.Adapter(
session, service_type='block-storage', min_version=2,
max_version=3)
session, service_type='block-storage', min_version=2, max_version=3
)
volume_version = client.get_api_major_version()
The version actually be returned with no api calls other than getting
@ -485,15 +532,23 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
# Explode `version` into min_version and max_version - everything below
# here uses the latter rather than the former.
min_version, max_version = discover._normalize_version_args(
version, min_version, max_version, service_type=service_type)
version, min_version, max_version, service_type=service_type
)
# Using functools.partial here just to reduce copy-pasta of params
get_endpoint_data = functools.partial(
self.get_endpoint_data,
session, service_type=service_type, interface=interface,
region_name=region_name, service_name=service_name,
allow=allow, min_version=min_version, max_version=max_version,
session,
service_type=service_type,
interface=interface,
region_name=region_name,
service_name=service_name,
allow=allow,
min_version=min_version,
max_version=max_version,
skip_discovery=skip_discovery,
allow_version_hack=allow_version_hack, **kwargs)
allow_version_hack=allow_version_hack,
**kwargs,
)
data = get_endpoint_data(discover_versions=discover_versions)
if (not data or not data.api_version) and not discover_versions:
# It's possible that no version was requested and the endpoint
@ -505,9 +560,14 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
return None
return data.api_version
def get_all_version_data(self, session, interface='public',
region_name=None, service_type=None,
**kwargs):
def get_all_version_data(
self,
session,
interface='public',
region_name=None,
service_type=None,
**kwargs,
):
"""Get version data for all services in the catalog.
:param session: A session object that can be used for communication.
@ -539,12 +599,12 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
for endpoint_service_type, services in endpoints_data.items():
if service_types.is_known(endpoint_service_type):
endpoint_service_type = service_types.get_service_type(
endpoint_service_type)
endpoint_service_type
)
for service in services:
versions = service.get_all_version_string_data(
session=session,
project_id=self.get_project_id(session),
session=session, project_id=self.get_project_id(session)
)
if service.region_name not in version_data:
@ -566,15 +626,15 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
def get_sp_auth_url(self, session, sp_id, **kwargs):
try:
return self.get_access(
session).service_providers.get_auth_url(sp_id)
return self.get_access(session).service_providers.get_auth_url(
sp_id
)
except exceptions.ServiceProviderNotFound:
return None
def get_sp_url(self, session, sp_id, **kwargs):
try:
return self.get_access(
session).service_providers.get_sp_url(sp_id)
return self.get_access(session).service_providers.get_sp_url(sp_id)
except exceptions.ServiceProviderNotFound:
return None
@ -602,9 +662,12 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
:returns: A discovery object with the results of looking up that URL.
"""
return discover.get_discovery(session=session, url=url,
cache=self._discovery_cache,
authenticated=authenticated)
return discover.get_discovery(
session=session,
url=url,
cache=self._discovery_cache,
authenticated=authenticated,
)
def get_cache_id_elements(self):
"""Get the elements for this auth plugin that make it unique.
@ -667,8 +730,10 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
:rtype: str or None if no auth present.
"""
if self.auth_ref:
data = {'auth_token': self.auth_ref.auth_token,
'body': self.auth_ref._data}
data = {
'auth_token': self.auth_ref.auth_token,
'body': self.auth_ref._data,
}
return json.dumps(data)
@ -680,7 +745,8 @@ class BaseIdentityPlugin(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
"""
if data:
auth_data = json.loads(data)
self.auth_ref = access.create(body=auth_data['body'],
auth_token=auth_data['auth_token'])
self.auth_ref = access.create(
body=auth_data['body'], auth_token=auth_data['auth_token']
)
else:
self.auth_ref = None

View File

@ -15,7 +15,4 @@ from keystoneauth1.identity.generic.password import Password # noqa
from keystoneauth1.identity.generic.token import Token # noqa
__all__ = ('BaseGenericPlugin',
'Password',
'Token',
)
__all__ = ('BaseGenericPlugin', 'Password', 'Token')

View File

@ -29,22 +29,24 @@ class BaseGenericPlugin(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
URL and then proxy all calls from the base plugin to the versioned one.
"""
def __init__(self, auth_url,
tenant_id=None,
tenant_name=None,
project_id=None,
project_name=None,
project_domain_id=None,
project_domain_name=None,
domain_id=None,
domain_name=None,
system_scope=None,
trust_id=None,
default_domain_id=None,
default_domain_name=None,
reauthenticate=True):
super(BaseGenericPlugin, self).__init__(auth_url=auth_url,
reauthenticate=reauthenticate)
def __init__(
self,
auth_url,
tenant_id=None,
tenant_name=None,
project_id=None,
project_name=None,
project_domain_id=None,
project_domain_name=None,
domain_id=None,
domain_name=None,
system_scope=None,
trust_id=None,
default_domain_id=None,
default_domain_name=None,
reauthenticate=True,
):
super().__init__(auth_url=auth_url, reauthenticate=reauthenticate)
self._project_id = project_id or tenant_id
self._project_name = project_name or tenant_name
@ -86,29 +88,39 @@ class BaseGenericPlugin(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
:returns: True if a domain parameter is set, false otherwise.
"""
return any([self._domain_id, self._domain_name,
self._project_domain_id, self._project_domain_name])
return any(
[
self._domain_id,
self._domain_name,
self._project_domain_id,
self._project_domain_name,
]
)
@property
def _v2_params(self):
"""Return the parameters that are common to v2 plugins."""
return {'trust_id': self._trust_id,
'tenant_id': self._project_id,
'tenant_name': self._project_name,
'reauthenticate': self.reauthenticate}
return {
'trust_id': self._trust_id,
'tenant_id': self._project_id,
'tenant_name': self._project_name,
'reauthenticate': self.reauthenticate,
}
@property
def _v3_params(self):
"""Return the parameters that are common to v3 plugins."""
return {'trust_id': self._trust_id,
'system_scope': self._system_scope,
'project_id': self._project_id,
'project_name': self._project_name,
'project_domain_id': self.project_domain_id,
'project_domain_name': self.project_domain_name,
'domain_id': self._domain_id,
'domain_name': self._domain_name,
'reauthenticate': self.reauthenticate}
return {
'trust_id': self._trust_id,
'system_scope': self._system_scope,
'project_id': self._project_id,
'project_name': self._project_name,
'project_domain_id': self.project_domain_id,
'project_domain_name': self.project_domain_name,
'domain_id': self._domain_id,
'domain_name': self._domain_name,
'reauthenticate': self.reauthenticate,
}
@property
def project_domain_id(self):
@ -130,16 +142,20 @@ class BaseGenericPlugin(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
plugin = None
try:
disc = self.get_discovery(session,
self.auth_url,
authenticated=False)
except (exceptions.DiscoveryFailure,
exceptions.HttpError,
exceptions.SSLError,
exceptions.ConnectionError) as e:
LOG.warning('Failed to discover available identity versions when '
'contacting %s. Attempting to parse version from URL.',
self.auth_url)
disc = self.get_discovery(
session, self.auth_url, authenticated=False
)
except (
exceptions.DiscoveryFailure,
exceptions.HttpError,
exceptions.SSLError,
exceptions.ConnectionError,
) as e:
LOG.warning(
'Failed to discover available identity versions when '
'contacting %s. Attempting to parse version from URL.',
self.auth_url,
)
url_parts = urllib.parse.urlparse(self.auth_url)
path = url_parts.path.lower()
@ -147,7 +163,8 @@ class BaseGenericPlugin(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
if path.startswith('/v2.0'):
if self._has_domain_scope:
raise exceptions.DiscoveryFailure(
'Cannot use v2 authentication with domain scope')
'Cannot use v2 authentication with domain scope'
)
plugin = self.create_plugin(session, (2, 0), self.auth_url)
elif path.startswith('/v3'):
plugin = self.create_plugin(session, (3, 0), self.auth_url)
@ -155,7 +172,8 @@ class BaseGenericPlugin(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
raise exceptions.DiscoveryFailure(
'Could not find versioned identity endpoints when '
'attempting to authenticate. Please check that your '
'auth_url is correct. %s' % e)
f'auth_url is correct. {e}'
)
else:
# NOTE(jamielennox): version_data is always in oldest to newest
@ -172,23 +190,28 @@ class BaseGenericPlugin(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
for data in disc_data:
version = data['version']
if (discover.version_match((2,), version) and
self._has_domain_scope):
if (
discover.version_match((2,), version)
and self._has_domain_scope
):
# NOTE(jamielennox): if there are domain parameters there
# is no point even trying against v2 APIs.
v2_with_domain_scope = True
continue
plugin = self.create_plugin(session,
version,
data['url'],
raw_status=data['raw_status'])
plugin = self.create_plugin(
session,
version,
data['url'],
raw_status=data['raw_status'],
)
if plugin:
break
if not plugin and v2_with_domain_scope:
raise exceptions.DiscoveryFailure(
'Cannot use v2 authentication with domain scope')
'Cannot use v2 authentication with domain scope'
)
if plugin:
return plugin
@ -196,7 +219,8 @@ class BaseGenericPlugin(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
# so there were no URLs that i could use for auth of any version.
raise exceptions.DiscoveryFailure(
'Could not find versioned identity endpoints when attempting '
'to authenticate. Please check that your auth_url is correct.')
'to authenticate. Please check that your auth_url is correct.'
)
def get_auth_ref(self, session, **kwargs):
if not self._plugin:
@ -212,11 +236,13 @@ class BaseGenericPlugin(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
if not _implemented:
raise NotImplementedError()
return {'auth_url': self.auth_url,
'project_id': self._project_id,
'project_name': self._project_name,
'project_domain_id': self.project_domain_id,
'project_domain_name': self.project_domain_name,
'domain_id': self._domain_id,
'domain_name': self._domain_name,
'trust_id': self._trust_id}
return {
'auth_url': self.auth_url,
'project_id': self._project_id,
'project_name': self._project_name,
'project_domain_id': self.project_domain_id,
'project_domain_name': self.project_domain_name,
'domain_id': self._domain_id,
'domain_name': self._domain_name,
'trust_id': self._trust_id,
}

View File

@ -27,9 +27,17 @@ class Password(base.BaseGenericPlugin):
"""
def __init__(self, auth_url, username=None, user_id=None, password=None,
user_domain_id=None, user_domain_name=None, **kwargs):
super(Password, self).__init__(auth_url=auth_url, **kwargs)
def __init__(
self,
auth_url,
username=None,
user_id=None,
password=None,
user_domain_id=None,
user_domain_name=None,
**kwargs,
):
super().__init__(auth_url=auth_url, **kwargs)
self._username = username
self._user_id = user_id
@ -42,23 +50,27 @@ class Password(base.BaseGenericPlugin):
if self._user_domain_id or self._user_domain_name:
return None
return v2.Password(auth_url=url,
user_id=self._user_id,
username=self._username,
password=self._password,
**self._v2_params)
return v2.Password(
auth_url=url,
user_id=self._user_id,
username=self._username,
password=self._password,
**self._v2_params,
)
elif discover.version_match((3,), version):
u_domain_id = self._user_domain_id or self._default_domain_id
u_domain_name = self._user_domain_name or self._default_domain_name
return v3.Password(auth_url=url,
user_id=self._user_id,
username=self._username,
user_domain_id=u_domain_id,
user_domain_name=u_domain_name,
password=self._password,
**self._v3_params)
return v3.Password(
auth_url=url,
user_id=self._user_id,
username=self._username,
user_domain_id=u_domain_id,
user_domain_name=u_domain_name,
password=self._password,
**self._v3_params,
)
@property
def user_domain_id(self):
@ -77,8 +89,7 @@ class Password(base.BaseGenericPlugin):
self._user_domain_name = value
def get_cache_id_elements(self):
elements = super(Password, self).get_cache_id_elements(
_implemented=True)
elements = super().get_cache_id_elements(_implemented=True)
elements['username'] = self._username
elements['user_id'] = self._user_id
elements['password'] = self._password

View File

@ -23,7 +23,7 @@ class Token(base.BaseGenericPlugin):
"""
def __init__(self, auth_url, token=None, **kwargs):
super(Token, self).__init__(auth_url, **kwargs)
super().__init__(auth_url, **kwargs)
self._token = token
def create_plugin(self, session, version, url, raw_status=None):
@ -34,6 +34,6 @@ class Token(base.BaseGenericPlugin):
return v3.Token(url, self._token, **self._v3_params)
def get_cache_id_elements(self):
elements = super(Token, self).get_cache_id_elements(_implemented=True)
elements = super().get_cache_id_elements(_implemented=True)
elements['token'] = self._token
return elements

View File

@ -31,13 +31,15 @@ class Auth(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
is going to expire. (optional) default True
"""
def __init__(self, auth_url,
trust_id=None,
tenant_id=None,
tenant_name=None,
reauthenticate=True):
super(Auth, self).__init__(auth_url=auth_url,
reauthenticate=reauthenticate)
def __init__(
self,
auth_url,
trust_id=None,
tenant_id=None,
tenant_name=None,
reauthenticate=True,
):
super().__init__(auth_url=auth_url, reauthenticate=reauthenticate)
self.trust_id = trust_id
self.tenant_id = tenant_id
@ -56,8 +58,9 @@ class Auth(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
params['auth']['trust_id'] = self.trust_id
_logger.debug('Making authentication request to %s', url)
resp = session.post(url, json=params, headers=headers,
authenticated=False, log=False)
resp = session.post(
url, json=params, headers=headers, authenticated=False, log=False
)
try:
resp_data = resp.json()
@ -106,9 +109,15 @@ class Password(Auth):
:raises TypeError: if a user_id or username is not provided.
"""
def __init__(self, auth_url, username=_NOT_PASSED, password=None,
user_id=_NOT_PASSED, **kwargs):
super(Password, self).__init__(auth_url, **kwargs)
def __init__(
self,
auth_url,
username=_NOT_PASSED,
password=None,
user_id=_NOT_PASSED,
**kwargs,
):
super().__init__(auth_url, **kwargs)
if username is _NOT_PASSED and user_id is _NOT_PASSED:
msg = 'You need to specify either a username or user_id'
@ -134,13 +143,15 @@ class Password(Auth):
return {'passwordCredentials': auth}
def get_cache_id_elements(self):
return {'username': self.username,
'user_id': self.user_id,
'password': self.password,
'auth_url': self.auth_url,
'tenant_id': self.tenant_id,
'tenant_name': self.tenant_name,
'trust_id': self.trust_id}
return {
'username': self.username,
'user_id': self.user_id,
'password': self.password,
'auth_url': self.auth_url,
'tenant_id': self.tenant_id,
'tenant_name': self.tenant_name,
'trust_id': self.trust_id,
}
class Token(Auth):
@ -156,7 +167,7 @@ class Token(Auth):
"""
def __init__(self, auth_url, token, **kwargs):
super(Token, self).__init__(auth_url, **kwargs)
super().__init__(auth_url, **kwargs)
self.token = token
def get_auth_data(self, headers=None):
@ -165,8 +176,10 @@ class Token(Auth):
return {'token': {'id': self.token}}
def get_cache_id_elements(self):
return {'token': self.token,
'auth_url': self.auth_url,
'tenant_id': self.tenant_id,
'tenant_name': self.tenant_name,
'trust_id': self.trust_id}
return {
'token': self.token,
'auth_url': self.auth_url,
'tenant_id': self.tenant_id,
'tenant_name': self.tenant_name,
'trust_id': self.trust_id,
}

View File

@ -27,40 +27,29 @@ from keystoneauth1.identity.v3.oauth2_client_credential import * # noqa
from keystoneauth1.identity.v3.oauth2_mtls_client_credential import * # noqa
__all__ = ('ApplicationCredential',
'ApplicationCredentialMethod',
'Auth',
'AuthConstructor',
'AuthMethod',
'BaseAuth',
'FederationBaseAuth',
'Keystone2Keystone',
'Password',
'PasswordMethod',
'Token',
'TokenMethod',
'OidcAccessToken',
'OidcAuthorizationCode',
'OidcClientCredentials',
'OidcPassword',
'TOTPMethod',
'TOTP',
'TokenlessAuth',
'ReceiptMethod',
'MultiFactor',
'OAuth2ClientCredential',
'OAuth2ClientCredentialMethod',
'OAuth2mTlsClientCredential',
)
__all__ = (
'ApplicationCredential',
'ApplicationCredentialMethod',
'Auth',
'AuthConstructor',
'AuthMethod',
'BaseAuth',
'FederationBaseAuth',
'Keystone2Keystone',
'Password',
'PasswordMethod',
'Token',
'TokenMethod',
'OidcAccessToken',
'OidcAuthorizationCode',
'OidcClientCredentials',
'OidcPassword',
'TOTPMethod',
'TOTP',
'TokenlessAuth',
'ReceiptMethod',
'MultiFactor',
'OAuth2ClientCredential',
'OAuth2ClientCredentialMethod',
'OAuth2mTlsClientCredential',
)

View File

@ -37,13 +37,15 @@ class ApplicationCredentialMethod(base.AuthMethod):
provided.
"""
_method_parameters = ['application_credential_secret',
'application_credential_id',
'application_credential_name',
'user_id',
'username',
'user_domain_id',
'user_domain_name']
_method_parameters = [
'application_credential_secret',
'application_credential_id',
'application_credential_name',
'user_id',
'username',
'user_domain_id',
'user_domain_name',
]
def get_auth_data(self, session, auth, headers, **kwargs):
auth_data = {'secret': self.application_credential_secret}
@ -62,13 +64,16 @@ class ApplicationCredentialMethod(base.AuthMethod):
auth_data['user']['domain'] = {'id': self.user_domain_id}
elif self.user_domain_name:
auth_data['user']['domain'] = {
'name': self.user_domain_name}
'name': self.user_domain_name
}
return 'application_credential', auth_data
def get_cache_id_elements(self):
return dict(('application_credential_%s' % p, getattr(self, p))
for p in self._method_parameters)
return {
f'application_credential_{p}': getattr(self, p)
for p in self._method_parameters
}
class ApplicationCredential(base.AuthConstructor):

View File

@ -41,19 +41,21 @@ class BaseAuth(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
token. (optional) default True.
"""
def __init__(self, auth_url,
trust_id=None,
system_scope=None,
domain_id=None,
domain_name=None,
project_id=None,
project_name=None,
project_domain_id=None,
project_domain_name=None,
reauthenticate=True,
include_catalog=True):
super(BaseAuth, self).__init__(auth_url=auth_url,
reauthenticate=reauthenticate)
def __init__(
self,
auth_url,
trust_id=None,
system_scope=None,
domain_id=None,
domain_name=None,
project_id=None,
project_name=None,
project_domain_id=None,
project_domain_name=None,
reauthenticate=True,
include_catalog=True,
):
super().__init__(auth_url=auth_url, reauthenticate=reauthenticate)
self.trust_id = trust_id
self.system_scope = system_scope
self.domain_id = domain_id
@ -67,7 +69,7 @@ class BaseAuth(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
@property
def token_url(self):
"""The full URL where we will send authentication data."""
return '%s/auth/tokens' % self.auth_url.rstrip('/')
return '{}/auth/tokens'.format(self.auth_url.rstrip('/'))
@abc.abstractmethod
def get_auth_ref(self, session, **kwargs):
@ -76,9 +78,14 @@ class BaseAuth(base.BaseIdentityPlugin, metaclass=abc.ABCMeta):
@property
def has_scope_parameters(self):
"""Return true if parameters can be used to create a scoped token."""
return (self.domain_id or self.domain_name or
self.project_id or self.project_name or
self.trust_id or self.system_scope)
return (
self.domain_id
or self.domain_name
or self.project_id
or self.project_name
or self.trust_id
or self.system_scope
)
class Auth(BaseAuth):
@ -104,7 +111,7 @@ class Auth(BaseAuth):
def __init__(self, auth_url, auth_methods, **kwargs):
self.unscoped = kwargs.pop('unscoped', False)
super(Auth, self).__init__(auth_url=auth_url, **kwargs)
super().__init__(auth_url=auth_url, **kwargs)
self.auth_methods = auth_methods
def add_method(self, method):
@ -119,7 +126,8 @@ class Auth(BaseAuth):
for method in self.auth_methods:
name, auth_data = method.get_auth_data(
session, self, headers, request_kwargs=rkwargs)
session, self, headers, request_kwargs=rkwargs
)
# NOTE(adriant): Methods like ReceiptMethod don't
# want anything added to the request data, so they
# explicitly return None, which we check for.
@ -129,19 +137,23 @@ class Auth(BaseAuth):
if not ident:
raise exceptions.AuthorizationFailure(
'Authentication method required (e.g. password)')
'Authentication method required (e.g. password)'
)
mutual_exclusion = [bool(self.domain_id or self.domain_name),
bool(self.project_id or self.project_name),
bool(self.trust_id),
bool(self.system_scope),
bool(self.unscoped)]
mutual_exclusion = [
bool(self.domain_id or self.domain_name),
bool(self.project_id or self.project_name),
bool(self.trust_id),
bool(self.system_scope),
bool(self.unscoped),
]
if sum(mutual_exclusion) > 1:
raise exceptions.AuthorizationFailure(
message='Authentication cannot be scoped to multiple'
' targets. Pick one of: project, domain, '
'trust, system or unscoped')
' targets. Pick one of: project, domain, '
'trust, system or unscoped'
)
if self.domain_id:
body['auth']['scope'] = {'domain': {'id': self.domain_id}}
@ -174,7 +186,7 @@ class Auth(BaseAuth):
token_url = self.token_url
if not self.auth_url.rstrip('/').endswith('v3'):
token_url = '%s/v3/auth/tokens' % self.auth_url.rstrip('/')
token_url = '{}/v3/auth/tokens'.format(self.auth_url.rstrip('/'))
# NOTE(jamielennox): we add nocatalog here rather than in token_url
# directly as some federation plugins require the base token_url
@ -182,8 +194,14 @@ class Auth(BaseAuth):
token_url += '?nocatalog'
_logger.debug('Making authentication request to %s', token_url)
resp = session.post(token_url, json=body, headers=headers,
authenticated=False, log=False, **rkwargs)
resp = session.post(
token_url,
json=body,
headers=headers,
authenticated=False,
log=False,
**rkwargs,
)
try:
_logger.debug(json.dumps(resp.json()))
@ -194,21 +212,24 @@ class Auth(BaseAuth):
if 'token' not in resp_data:
raise exceptions.InvalidResponse(response=resp)
return access.AccessInfoV3(auth_token=resp.headers['X-Subject-Token'],
body=resp_data)
return access.AccessInfoV3(
auth_token=resp.headers['X-Subject-Token'], body=resp_data
)
def get_cache_id_elements(self):
if not self.auth_methods:
return None
params = {'auth_url': self.auth_url,
'domain_id': self.domain_id,
'domain_name': self.domain_name,
'project_id': self.project_id,
'project_name': self.project_name,
'project_domain_id': self.project_domain_id,
'project_domain_name': self.project_domain_name,
'trust_id': self.trust_id}
params = {
'auth_url': self.auth_url,
'domain_id': self.domain_id,
'domain_name': self.domain_name,
'project_id': self.project_id,
'project_name': self.project_name,
'project_domain_id': self.project_domain_id,
'project_domain_name': self.project_domain_name,
'trust_id': self.trust_id,
}
for method in self.auth_methods:
try:
@ -240,14 +261,13 @@ class AuthMethod(metaclass=abc.ABCMeta):
setattr(self, param, kwargs.pop(param, None))
if kwargs:
msg = "Unexpected Attributes: %s" % ", ".join(kwargs.keys())
msg = "Unexpected Attributes: {}".format(", ".join(kwargs.keys()))
raise AttributeError(msg)
@classmethod
def _extract_kwargs(cls, kwargs):
"""Remove parameters related to this method from other kwargs."""
return dict([(p, kwargs.pop(p, None))
for p in cls._method_parameters])
return {p: kwargs.pop(p, None) for p in cls._method_parameters}
@abc.abstractmethod
def get_auth_data(self, session, auth, headers, **kwargs):
@ -296,4 +316,4 @@ class AuthConstructor(Auth, metaclass=abc.ABCMeta):
def __init__(self, auth_url, *args, **kwargs):
method_kwargs = self._auth_method_class._extract_kwargs(kwargs)
method = self._auth_method_class(*args, **method_kwargs)
super(AuthConstructor, self).__init__(auth_url, [method], **kwargs)
super().__init__(auth_url, [method], **kwargs)

View File

@ -35,13 +35,15 @@ class _Rescoped(base.BaseAuth, metaclass=abc.ABCMeta):
rescoping_plugin = token.Token
def _get_scoping_data(self):
return {'trust_id': self.trust_id,
'domain_id': self.domain_id,
'domain_name': self.domain_name,
'project_id': self.project_id,
'project_name': self.project_name,
'project_domain_id': self.project_domain_id,
'project_domain_name': self.project_domain_name}
return {
'trust_id': self.trust_id,
'domain_id': self.domain_id,
'domain_name': self.domain_name,
'project_id': self.project_id,
'project_name': self.project_name,
'project_domain_id': self.project_domain_id,
'project_domain_name': self.project_domain_name,
}
def get_auth_ref(self, session, **kwargs):
"""Authenticate retrieve token information.
@ -63,9 +65,9 @@ class _Rescoped(base.BaseAuth, metaclass=abc.ABCMeta):
scoping = self._get_scoping_data()
if any(scoping.values()):
token_plugin = self.rescoping_plugin(self.auth_url,
token=auth_ref.auth_token,
**scoping)
token_plugin = self.rescoping_plugin(
self.auth_url, token=auth_ref.auth_token, **scoping
)
auth_ref = token_plugin.get_auth_ref(session)
@ -93,7 +95,7 @@ class FederationBaseAuth(_Rescoped):
"""
def __init__(self, auth_url, identity_provider, protocol, **kwargs):
super(FederationBaseAuth, self).__init__(auth_url=auth_url, **kwargs)
super().__init__(auth_url=auth_url, **kwargs)
self.identity_provider = identity_provider
self.protocol = protocol
@ -106,10 +108,12 @@ class FederationBaseAuth(_Rescoped):
values = {
'host': host,
'identity_provider': self.identity_provider,
'protocol': self.protocol
'protocol': self.protocol,
}
url = ("%(host)s/OS-FEDERATION/identity_providers/"
"%(identity_provider)s/protocols/%(protocol)s/auth")
url = (
"%(host)s/OS-FEDERATION/identity_providers/"
"%(identity_provider)s/protocols/%(protocol)s/auth"
)
url = url % values
return url

View File

@ -43,7 +43,7 @@ class Keystone2Keystone(federation._Rescoped):
HTTP_SEE_OTHER = 303
def __init__(self, base_plugin, service_provider, **kwargs):
super(Keystone2Keystone, self).__init__(auth_url=None, **kwargs)
super().__init__(auth_url=None, **kwargs)
self._local_cloud_plugin = base_plugin
self._sp_id = service_provider
@ -81,36 +81,38 @@ class Keystone2Keystone(federation._Rescoped):
'methods': ['token'],
'token': {
'id': self._local_cloud_plugin.get_token(session)
}
},
},
'scope': {
'service_provider': {
'id': self._sp_id
}
}
'scope': {'service_provider': {'id': self._sp_id}},
}
}
endpoint_filter = {'version': (3, 0),
'interface': plugin.AUTH_INTERFACE}
endpoint_filter = {
'version': (3, 0),
'interface': plugin.AUTH_INTERFACE,
}
headers = {'Accept': 'application/json'}
resp = session.post(self.REQUEST_ECP_URL,
json=body,
auth=self._local_cloud_plugin,
endpoint_filter=endpoint_filter,
headers=headers,
authenticated=False,
raise_exc=False)
resp = session.post(
self.REQUEST_ECP_URL,
json=body,
auth=self._local_cloud_plugin,
endpoint_filter=endpoint_filter,
headers=headers,
authenticated=False,
raise_exc=False,
)
# NOTE(marek-denis): I am not sure whether disabling exceptions in the
# Session object and testing if resp.ok is sufficient. An alternative
# would be catching locally all exceptions and reraising with custom
# warning.
if not resp.ok:
msg = ("Error while requesting ECP wrapped assertion: response "
"exit code: %(status_code)d, reason: %(err)s")
msg = (
"Error while requesting ECP wrapped assertion: response "
"exit code: %(status_code)d, reason: %(err)s"
)
msg = msg % {'status_code': resp.status_code, 'err': resp.reason}
raise exceptions.AuthorizationFailure(msg)
@ -119,8 +121,9 @@ class Keystone2Keystone(federation._Rescoped):
return str(resp.text)
def _send_service_provider_ecp_authn_response(self, session, sp_url,
sp_auth_url):
def _send_service_provider_ecp_authn_response(
self, session, sp_url, sp_auth_url
):
"""Present ECP wrapped SAML assertion to the keystone SP.
The assertion is issued by the keystone IdP and it is targeted to the
@ -145,27 +148,33 @@ class Keystone2Keystone(federation._Rescoped):
headers={'Content-Type': 'application/vnd.paos+xml'},
data=self._get_ecp_assertion(session),
authenticated=False,
redirect=False)
redirect=False,
)
# Don't follow HTTP specs - after the HTTP 302/303 response don't
# repeat the call directed to the Location URL. In this case, this is
# an indication that SAML2 session is now active and protected resource
# can be accessed.
if response.status_code in (self.HTTP_MOVED_TEMPORARILY,
self.HTTP_SEE_OTHER):
if response.status_code in (
self.HTTP_MOVED_TEMPORARILY,
self.HTTP_SEE_OTHER,
):
response = session.get(
sp_auth_url,
headers={'Content-Type': 'application/vnd.paos+xml'},
authenticated=False)
authenticated=False,
)
return response
def get_unscoped_auth_ref(self, session, **kwargs):
sp_auth_url = self._local_cloud_plugin.get_sp_auth_url(
session, self._sp_id)
session, self._sp_id
)
sp_url = self._local_cloud_plugin.get_sp_url(session, self._sp_id)
self.auth_url = self._remote_auth_url(sp_auth_url)
response = self._send_service_provider_ecp_authn_response(
session, sp_url, sp_auth_url)
session, sp_url, sp_auth_url
)
return access.create(resp=response)

View File

@ -14,7 +14,7 @@ from keystoneauth1.identity.v3 import base
from keystoneauth1 import loading
__all__ = ('MultiFactor', )
__all__ = ('MultiFactor',)
class MultiFactor(base.Auth):
@ -42,7 +42,8 @@ class MultiFactor(base.Auth):
for method in auth_methods:
# Using the loaders we pull the related auth method class
method_class = loading.get_plugin_loader(
method).plugin_class._auth_method_class
method
).plugin_class._auth_method_class
# We build some new kwargs for the method from required parameters
method_kwargs = {}
for key in method_class._method_parameters:
@ -56,4 +57,4 @@ class MultiFactor(base.Auth):
# to the super class and throw errors
for key in method_keys:
kwargs.pop(key, None)
super(MultiFactor, self).__init__(auth_url, method_instances, **kwargs)
super().__init__(auth_url, method_instances, **kwargs)

View File

@ -31,7 +31,7 @@ class OAuth2ClientCredentialMethod(base.AuthMethod):
_method_parameters = [
'oauth2_endpoint',
'oauth2_client_id',
'oauth2_client_secret'
'oauth2_client_secret',
]
def get_auth_data(self, session, auth, headers, **kwargs):
@ -48,7 +48,7 @@ class OAuth2ClientCredentialMethod(base.AuthMethod):
"""
auth_data = {
'id': self.oauth2_client_id,
'secret': self.oauth2_client_secret
'secret': self.oauth2_client_secret,
}
return 'application_credential', auth_data
@ -66,8 +66,10 @@ class OAuth2ClientCredentialMethod(base.AuthMethod):
should be prefixed with the plugin identifier. For example the password
plugin returns its username value as 'password_username'.
"""
return dict(('oauth2_client_credential_%s' % p, getattr(self, p))
for p in self._method_parameters)
return {
f'oauth2_client_credential_{p}': getattr(self, p)
for p in self._method_parameters
}
class OAuth2ClientCredential(base.AuthConstructor):
@ -82,7 +84,7 @@ class OAuth2ClientCredential(base.AuthConstructor):
_auth_method_class = OAuth2ClientCredentialMethod
def __init__(self, auth_url, *args, **kwargs):
super(OAuth2ClientCredential, self).__init__(auth_url, *args, **kwargs)
super().__init__(auth_url, *args, **kwargs)
self._oauth2_endpoint = kwargs['oauth2_endpoint']
self._oauth2_client_id = kwargs['oauth2_client_id']
self._oauth2_client_secret = kwargs['oauth2_client_secret']
@ -99,19 +101,21 @@ class OAuth2ClientCredential(base.AuthConstructor):
:rtype: dict
"""
# get headers for X-Auth-Token
headers = super(OAuth2ClientCredential, self).get_headers(
session, **kwargs)
headers = super().get_headers(session, **kwargs)
# Get OAuth2.0 access token and add the field 'Authorization'
data = {"grant_type": "client_credentials"}
auth = requests.auth.HTTPBasicAuth(self._oauth2_client_id,
self._oauth2_client_secret)
resp = session.request(self._oauth2_endpoint,
"POST",
authenticated=False,
raise_exc=False,
data=data,
requests_auth=auth)
auth = requests.auth.HTTPBasicAuth(
self._oauth2_client_id, self._oauth2_client_secret
)
resp = session.request(
self._oauth2_endpoint,
"POST",
authenticated=False,
raise_exc=False,
data=data,
requests_auth=auth,
)
if resp.status_code == 200:
oauth2 = resp.json()
oauth2_token = oauth2["access_token"]

View File

@ -27,10 +27,10 @@ class OAuth2mTlsClientCredential(base.BaseAuth, metaclass=abc.ABCMeta):
:param string oauth2_client_id: OAuth2.0 client credential id.
"""
def __init__(self, auth_url, oauth2_endpoint, oauth2_client_id,
*args, **kwargs):
super(OAuth2mTlsClientCredential, self).__init__(
auth_url, *args, **kwargs)
def __init__(
self, auth_url, oauth2_endpoint, oauth2_client_id, *args, **kwargs
):
super().__init__(auth_url, *args, **kwargs)
self.auth_url = auth_url
self.oauth2_endpoint = oauth2_endpoint
self.oauth2_client_id = oauth2_client_id
@ -64,12 +64,16 @@ class OAuth2mTlsClientCredential(base.BaseAuth, metaclass=abc.ABCMeta):
"""
# Get OAuth2.0 access token and add the field 'Authorization' when
# using the HTTPS protocol.
data = {'grant_type': 'client_credentials',
'client_id': self.oauth2_client_id}
resp = session.post(url=self.oauth2_endpoint,
authenticated=False,
raise_exc=False,
data=data)
data = {
'grant_type': 'client_credentials',
'client_id': self.oauth2_client_id,
}
resp = session.post(
url=self.oauth2_endpoint,
authenticated=False,
raise_exc=False,
data=data,
)
if resp.status_code == 200:
oauth2 = resp.json()
self.oauth2_access_token = oauth2.get('access_token')
@ -78,17 +82,18 @@ class OAuth2mTlsClientCredential(base.BaseAuth, metaclass=abc.ABCMeta):
msg = error.get('error_description')
raise exceptions.ClientException(msg)
headers = {'Accept': 'application/json',
'X-Auth-Token': self.oauth2_access_token,
'X-Subject-Token': self.oauth2_access_token}
headers = {
'Accept': 'application/json',
'X-Auth-Token': self.oauth2_access_token,
'X-Subject-Token': self.oauth2_access_token,
}
token_url = '%s/auth/tokens' % self.auth_url.rstrip('/')
token_url = '{}/auth/tokens'.format(self.auth_url.rstrip('/'))
if not self.auth_url.rstrip('/').endswith('v3'):
token_url = '%s/v3/auth/tokens' % self.auth_url.rstrip('/')
resp = session.get(url=token_url,
authenticated=False,
headers=headers,
log=False)
token_url = '{}/v3/auth/tokens'.format(self.auth_url.rstrip('/'))
resp = session.get(
url=token_url, authenticated=False, headers=headers, log=False
)
try:
resp_data = resp.json()
except ValueError:
@ -96,8 +101,9 @@ class OAuth2mTlsClientCredential(base.BaseAuth, metaclass=abc.ABCMeta):
if 'token' not in resp_data:
raise exceptions.InvalidResponse(response=resp)
return access.AccessInfoV3(auth_token=self.oauth2_access_token,
body=resp_data)
return access.AccessInfoV3(
auth_token=self.oauth2_access_token, body=resp_data
)
def get_headers(self, session, **kwargs):
"""Fetch authentication headers for message.
@ -111,8 +117,7 @@ class OAuth2mTlsClientCredential(base.BaseAuth, metaclass=abc.ABCMeta):
:rtype: dict
"""
# get headers for X-Auth-Token
headers = super(OAuth2mTlsClientCredential, self).get_headers(
session, **kwargs)
headers = super().get_headers(session, **kwargs)
# add OAuth2.0 access token to the headers
if headers:

View File

@ -27,10 +27,12 @@ from keystoneauth1.identity.v3 import federation
_logger = utils.get_logger(__name__)
__all__ = ('OidcAuthorizationCode',
'OidcClientCredentials',
'OidcPassword',
'OidcAccessToken')
__all__ = (
'OidcAuthorizationCode',
'OidcClientCredentials',
'OidcPassword',
'OidcAccessToken',
)
SENSITIVE_KEYS = ("password", "code", "token", "secret")
@ -44,14 +46,20 @@ class _OidcBase(federation.FederationBaseAuth, metaclass=abc.ABCMeta):
grant_type = None
def __init__(self, auth_url, identity_provider, protocol,
client_id, client_secret,
access_token_type,
scope="openid profile",
access_token_endpoint=None,
discovery_endpoint=None,
grant_type=None,
**kwargs):
def __init__(
self,
auth_url,
identity_provider,
protocol,
client_id,
client_secret,
access_token_type,
scope="openid profile",
access_token_endpoint=None,
discovery_endpoint=None,
grant_type=None,
**kwargs,
):
"""The OpenID Connect plugin expects the following.
:param auth_url: URL of the Identity Service
@ -96,8 +104,7 @@ class _OidcBase(federation.FederationBaseAuth, metaclass=abc.ABCMeta):
states that "openid" must be always specified.
:type scope: string
"""
super(_OidcBase, self).__init__(auth_url, identity_provider, protocol,
**kwargs)
super().__init__(auth_url, identity_provider, protocol, **kwargs)
self.client_id = client_id
self.client_secret = client_secret
@ -111,15 +118,17 @@ class _OidcBase(federation.FederationBaseAuth, metaclass=abc.ABCMeta):
if grant_type is not None:
if grant_type != self.grant_type:
raise exceptions.OidcGrantTypeMissmatch()
warnings.warn("Passing grant_type as an argument has been "
"deprecated as it is now defined in the plugin "
"itself. You should stop passing this argument "
"to the plugin, as it will be ignored, since you "
"cannot pass a free text string as a grant_type. "
"This argument will be dropped from the plugin in "
"July 2017 or with the next major release of "
"keystoneauth (3.0.0)",
DeprecationWarning)
warnings.warn(
"Passing grant_type as an argument has been "
"deprecated as it is now defined in the plugin "
"itself. You should stop passing this argument "
"to the plugin, as it will be ignored, since you "
"cannot pass a free text string as a grant_type. "
"This argument will be dropped from the plugin in "
"July 2017 or with the next major release of "
"keystoneauth (3.0.0)",
DeprecationWarning,
)
def _get_discovery_document(self, session):
"""Get the contents of the OpenID Connect Discovery Document.
@ -137,14 +146,18 @@ class _OidcBase(federation.FederationBaseAuth, metaclass=abc.ABCMeta):
otherwise it will return an empty dict.
:rtype: dict
"""
if (self.discovery_endpoint is not None
and not self._discovery_document):
if (
self.discovery_endpoint is not None
and not self._discovery_document
):
try:
resp = session.get(self.discovery_endpoint,
authenticated=False)
resp = session.get(
self.discovery_endpoint, authenticated=False
)
except exceptions.HttpError:
_logger.error("Cannot fetch discovery document %(discovery)s" %
{"discovery": self.discovery_endpoint})
_logger.error(
f"Cannot fetch discovery document {self.discovery_endpoint}"
)
raise
try:
@ -211,20 +224,25 @@ class _OidcBase(federation.FederationBaseAuth, metaclass=abc.ABCMeta):
sanitized_payload = self._sanitize(payload)
_logger.debug(
"Making OpenID-Connect authentication request to %s with "
"data %s", access_token_endpoint, sanitized_payload
"data %s",
access_token_endpoint,
sanitized_payload,
)
op_response = session.post(access_token_endpoint,
requests_auth=client_auth,
data=payload,
log=False,
authenticated=False)
op_response = session.post(
access_token_endpoint,
requests_auth=client_auth,
data=payload,
log=False,
authenticated=False,
)
response = op_response.json()
if _logger.isEnabledFor(logging.DEBUG):
sanitized_response = self._sanitize(response)
_logger.debug(
"OpenID-Connect authentication response from %s is %s",
access_token_endpoint, sanitized_response
access_token_endpoint,
sanitized_response,
)
return response[self.access_token_type]
@ -247,9 +265,9 @@ class _OidcBase(federation.FederationBaseAuth, metaclass=abc.ABCMeta):
"""
# use access token against protected URL
headers = {'Authorization': 'Bearer ' + access_token}
auth_response = session.post(self.federated_token_url,
headers=headers,
authenticated=False)
auth_response = session.post(
self.federated_token_url, headers=headers, authenticated=False
)
return auth_response
def get_unscoped_auth_ref(self, session):
@ -278,8 +296,11 @@ class _OidcBase(federation.FederationBaseAuth, metaclass=abc.ABCMeta):
# First of all, check if the grant type is supported
discovery = self._get_discovery_document(session)
grant_types = discovery.get("grant_types_supported")
if (grant_types and self.grant_type is not None
and self.grant_type not in grant_types):
if (
grant_types
and self.grant_type is not None
and self.grant_type not in grant_types
):
raise exceptions.OidcPluginNotSupported()
# Get the payload
@ -317,13 +338,21 @@ class OidcPassword(_OidcBase):
grant_type = "password"
def __init__(self, auth_url, identity_provider, protocol, # nosec
client_id, client_secret,
access_token_endpoint=None,
discovery_endpoint=None,
access_token_type='access_token',
username=None, password=None, idp_otp_key=None,
**kwargs):
def __init__(
self,
auth_url,
identity_provider,
protocol, # nosec
client_id,
client_secret,
access_token_endpoint=None,
discovery_endpoint=None,
access_token_type='access_token',
username=None,
password=None,
idp_otp_key=None,
**kwargs,
):
"""The OpenID Password plugin expects the following.
:param username: Username used to authenticate
@ -332,7 +361,7 @@ class OidcPassword(_OidcBase):
:param password: Password used to authenticate
:type password: string
"""
super(OidcPassword, self).__init__(
super().__init__(
auth_url=auth_url,
identity_provider=identity_provider,
protocol=protocol,
@ -341,7 +370,8 @@ class OidcPassword(_OidcBase):
access_token_endpoint=access_token_endpoint,
discovery_endpoint=discovery_endpoint,
access_token_type=access_token_type,
**kwargs)
**kwargs,
)
self.username = username
self.password = password
self.idp_otp_key = idp_otp_key
@ -355,10 +385,12 @@ class OidcPassword(_OidcBase):
:returns: a python dictionary containing the payload to be exchanged
:rtype: dict
"""
payload = {'username': self.username,
'password': self.password,
'scope': self.scope,
'client_id': self.client_id}
payload = {
'username': self.username,
'password': self.password,
'scope': self.scope,
'client_id': self.client_id,
}
self.manage_otp_from_session_or_request_to_the_user(payload, session)
@ -391,7 +423,8 @@ class OidcPassword(_OidcBase):
payload[self.idp_otp_key] = otp_from_session
else:
payload[self.idp_otp_key] = input(
"Please, enter the generated OTP code: ")
"Please, enter the generated OTP code: "
)
setattr(session, 'otp', payload[self.idp_otp_key])
@ -400,12 +433,18 @@ class OidcClientCredentials(_OidcBase):
grant_type = 'client_credentials'
def __init__(self, auth_url, identity_provider, protocol, # nosec
client_id, client_secret,
access_token_endpoint=None,
discovery_endpoint=None,
access_token_type='access_token',
**kwargs):
def __init__(
self,
auth_url,
identity_provider,
protocol, # nosec
client_id,
client_secret,
access_token_endpoint=None,
discovery_endpoint=None,
access_token_type='access_token',
**kwargs,
):
"""The OpenID Client Credentials expects the following.
:param client_id: Client ID used to authenticate
@ -414,7 +453,7 @@ class OidcClientCredentials(_OidcBase):
:param client_secret: Client Secret used to authenticate
:type password: string
"""
super(OidcClientCredentials, self).__init__(
super().__init__(
auth_url=auth_url,
identity_provider=identity_provider,
protocol=protocol,
@ -423,7 +462,8 @@ class OidcClientCredentials(_OidcBase):
access_token_endpoint=access_token_endpoint,
discovery_endpoint=discovery_endpoint,
access_token_type=access_token_type,
**kwargs)
**kwargs,
)
def get_payload(self, session):
"""Get an authorization grant for the client credentials grant type.
@ -443,12 +483,20 @@ class OidcAuthorizationCode(_OidcBase):
grant_type = 'authorization_code'
def __init__(self, auth_url, identity_provider, protocol, # nosec
client_id, client_secret,
access_token_endpoint=None,
discovery_endpoint=None,
access_token_type='access_token',
redirect_uri=None, code=None, **kwargs):
def __init__(
self,
auth_url,
identity_provider,
protocol, # nosec
client_id,
client_secret,
access_token_endpoint=None,
discovery_endpoint=None,
access_token_type='access_token',
redirect_uri=None,
code=None,
**kwargs,
):
"""The OpenID Authorization Code plugin expects the following.
:param redirect_uri: OpenID Connect Client Redirect URL
@ -458,7 +506,7 @@ class OidcAuthorizationCode(_OidcBase):
:type code: string
"""
super(OidcAuthorizationCode, self).__init__(
super().__init__(
auth_url=auth_url,
identity_provider=identity_provider,
protocol=protocol,
@ -467,7 +515,8 @@ class OidcAuthorizationCode(_OidcBase):
access_token_endpoint=access_token_endpoint,
discovery_endpoint=discovery_endpoint,
access_token_type=access_token_type,
**kwargs)
**kwargs,
)
self.redirect_uri = redirect_uri
self.code = code
@ -488,8 +537,9 @@ class OidcAuthorizationCode(_OidcBase):
class OidcAccessToken(_OidcBase):
"""Implementation for OpenID Connect access token reuse."""
def __init__(self, auth_url, identity_provider, protocol,
access_token, **kwargs):
def __init__(
self, auth_url, identity_provider, protocol, access_token, **kwargs
):
"""The OpenID Connect plugin based on the Access Token.
It expects the following:
@ -507,13 +557,16 @@ class OidcAccessToken(_OidcBase):
:param access_token: OpenID Connect Access token
:type access_token: string
"""
super(OidcAccessToken, self).__init__(auth_url, identity_provider,
protocol,
client_id=None,
client_secret=None,
access_token_endpoint=None,
access_token_type=None,
**kwargs)
super().__init__(
auth_url,
identity_provider,
protocol,
client_id=None,
client_secret=None,
access_token_endpoint=None,
access_token_type=None,
**kwargs,
)
self.access_token = access_token
def get_payload(self, session):
@ -546,13 +599,20 @@ class OidcDeviceAuthorization(_OidcBase):
grant_type = "urn:ietf:params:oauth:grant-type:device_code"
HEADER_X_FORM = {"Content-Type": "application/x-www-form-urlencoded"}
def __init__(self, auth_url, identity_provider, protocol, # nosec
client_id, client_secret=None,
access_token_endpoint=None,
device_authorization_endpoint=None,
discovery_endpoint=None,
code_challenge=None, code_challenge_method=None,
**kwargs):
def __init__(
self,
auth_url,
identity_provider,
protocol, # nosec
client_id,
client_secret=None,
access_token_endpoint=None,
device_authorization_endpoint=None,
discovery_endpoint=None,
code_challenge=None,
code_challenge_method=None,
**kwargs,
):
"""The OAuth 2.0 Device Authorization plugin expects the following.
:param device_authorization_endpoint: OAuth 2.0 Device Authorization
@ -571,7 +631,7 @@ class OidcDeviceAuthorization(_OidcBase):
self.device_authorization_endpoint = device_authorization_endpoint
self.code_challenge_method = code_challenge_method
super(OidcDeviceAuthorization, self).__init__(
super().__init__(
auth_url=auth_url,
identity_provider=identity_provider,
protocol=protocol,
@ -580,7 +640,8 @@ class OidcDeviceAuthorization(_OidcBase):
access_token_endpoint=access_token_endpoint,
discovery_endpoint=discovery_endpoint,
access_token_type=self.access_token_type,
**kwargs)
**kwargs,
)
def _get_device_authorization_endpoint(self, session):
"""Get the endpoint for the OAuth 2.0 Device Authorization flow.
@ -639,8 +700,9 @@ class OidcDeviceAuthorization(_OidcBase):
:returns: a python dictionary containing the payload to be exchanged
:rtype: dict
"""
device_authz_endpoint = \
self._get_device_authorization_endpoint(session)
device_authz_endpoint = self._get_device_authorization_endpoint(
session
)
if self.client_secret:
client_auth = (self.client_id, self.client_secret)
@ -651,8 +713,9 @@ class OidcDeviceAuthorization(_OidcBase):
if self.code_challenge_method:
self.code_challenge = self._generate_pkce_challenge()
payload.setdefault('code_challenge_method',
self.code_challenge_method)
payload.setdefault(
'code_challenge_method', self.code_challenge_method
)
payload.setdefault('code_challenge', self.code_challenge)
encoded_payload = urlparse.urlencode(payload)
@ -660,19 +723,24 @@ class OidcDeviceAuthorization(_OidcBase):
sanitized_payload = self._sanitize(payload)
_logger.debug(
"Making OpenID-Connect authentication request to %s with "
"data %s", device_authz_endpoint, sanitized_payload
"data %s",
device_authz_endpoint,
sanitized_payload,
)
op_response = session.post(device_authz_endpoint,
requests_auth=client_auth,
headers=self.HEADER_X_FORM,
data=encoded_payload,
log=False,
authenticated=False)
op_response = session.post(
device_authz_endpoint,
requests_auth=client_auth,
headers=self.HEADER_X_FORM,
data=encoded_payload,
log=False,
authenticated=False,
)
if _logger.isEnabledFor(logging.DEBUG):
sanitized_response = self._sanitize(op_response.json())
_logger.debug(
"OpenID-Connect authentication response from %s is %s",
device_authz_endpoint, sanitized_response
device_authz_endpoint,
sanitized_response,
)
self.expires_in = int(op_response.json()["expires_in"])
@ -681,8 +749,9 @@ class OidcDeviceAuthorization(_OidcBase):
self.interval = int(op_response.json()["interval"])
self.user_code = op_response.json()["user_code"]
self.verification_uri = op_response.json()["verification_uri"]
self.verification_uri_complete = \
op_response.json()["verification_uri_complete"]
self.verification_uri_complete = op_response.json()[
"verification_uri_complete"
]
payload = {'device_code': self.device_code}
if self.code_challenge_method:
@ -701,8 +770,10 @@ class OidcDeviceAuthorization(_OidcBase):
'device_code': self.device_code}
:type payload: dict
"""
_logger.warning(f"To authenticate please go to: "
f"{self.verification_uri_complete}")
_logger.warning(
f"To authenticate please go to: "
f"{self.verification_uri_complete}"
)
if self.client_secret:
client_auth = (self.client_id, self.client_secret)
@ -720,19 +791,23 @@ class OidcDeviceAuthorization(_OidcBase):
_logger.debug(
"Making OpenID-Connect authentication request to %s "
"with data %s",
access_token_endpoint, sanitized_payload
access_token_endpoint,
sanitized_payload,
)
op_response = session.post(access_token_endpoint,
requests_auth=client_auth,
data=encoded_payload,
headers=self.HEADER_X_FORM,
log=False,
authenticated=False)
op_response = session.post(
access_token_endpoint,
requests_auth=client_auth,
data=encoded_payload,
headers=self.HEADER_X_FORM,
log=False,
authenticated=False,
)
if _logger.isEnabledFor(logging.DEBUG):
sanitized_response = self._sanitize(op_response.json())
_logger.debug(
"OpenID-Connect authentication response from %s is %s",
access_token_endpoint, sanitized_response
access_token_endpoint,
sanitized_response,
)
except exceptions.http.BadRequest as exc:
error = exc.response.json().get("error")

View File

@ -26,11 +26,13 @@ class PasswordMethod(base.AuthMethod):
:param string user_domain_name: User's domain name for authentication.
"""
_method_parameters = ['user_id',
'username',
'user_domain_id',
'user_domain_name',
'password']
_method_parameters = [
'user_id',
'username',
'user_domain_id',
'user_domain_name',
'password',
]
def get_auth_data(self, session, auth, headers, **kwargs):
user = {'password': self.password}
@ -48,8 +50,9 @@ class PasswordMethod(base.AuthMethod):
return 'password', {'user': user}
def get_cache_id_elements(self):
return dict(('password_%s' % p, getattr(self, p))
for p in self._method_parameters)
return {
f'password_{p}': getattr(self, p) for p in self._method_parameters
}
class Password(base.AuthConstructor):

View File

@ -13,7 +13,7 @@
from keystoneauth1.identity.v3 import base
__all__ = ('ReceiptMethod', )
__all__ = ('ReceiptMethod',)
class ReceiptMethod(base.AuthMethod):

View File

@ -51,4 +51,4 @@ class Token(base.AuthConstructor):
_auth_method_class = TokenMethod
def __init__(self, auth_url, token, **kwargs):
super(Token, self).__init__(auth_url, token=token, **kwargs)
super().__init__(auth_url, token=token, **kwargs)

View File

@ -27,13 +27,16 @@ class TokenlessAuth(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
the provided HTTPS certificate along with the scope information.
"""
def __init__(self, auth_url,
domain_id=None,
domain_name=None,
project_id=None,
project_name=None,
project_domain_id=None,
project_domain_name=None):
def __init__(
self,
auth_url,
domain_id=None,
domain_name=None,
project_id=None,
project_name=None,
project_domain_id=None,
project_domain_name=None,
):
"""A init method for TokenlessAuth.
:param string auth_url: Identity service endpoint for authentication.
@ -75,23 +78,23 @@ class TokenlessAuth(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
elif self.project_name:
scope_headers['X-Project-Name'] = self.project_name
if self.project_domain_id:
scope_headers['X-Project-Domain-Id'] = (
self.project_domain_id)
scope_headers['X-Project-Domain-Id'] = self.project_domain_id
elif self.project_domain_name:
scope_headers['X-Project-Domain-Name'] = (
self.project_domain_name)
self.project_domain_name
)
else:
LOG.warning(
'Neither Project Domain ID nor Project Domain Name was '
'provided.')
'provided.'
)
return None
elif self.domain_id:
scope_headers['X-Domain-Id'] = self.domain_id
elif self.domain_name:
scope_headers['X-Domain-Name'] = self.domain_name
else:
LOG.warning(
'Neither Project nor Domain scope was provided.')
LOG.warning('Neither Project nor Domain scope was provided.')
return None
return scope_headers
@ -106,8 +109,10 @@ class TokenlessAuth(plugin.BaseAuthPlugin, metaclass=abc.ABCMeta):
:return: A valid endpoint URL or None if not available.
:rtype: string or None
"""
if (service_type is plugin.AUTH_INTERFACE or
service_type.lower() == 'identity'):
if (
service_type is plugin.AUTH_INTERFACE
or service_type.lower() == 'identity'
):
return self.auth_url
return None

View File

@ -28,11 +28,13 @@ class TOTPMethod(base.AuthMethod):
:param string user_domain_name: User's domain name for authentication.
"""
_method_parameters = ['user_id',
'username',
'user_domain_id',
'user_domain_name',
'passcode']
_method_parameters = [
'user_id',
'username',
'user_domain_id',
'user_domain_name',
'passcode',
]
def get_auth_data(self, session, auth, headers, **kwargs):
user = {'passcode': self.passcode}
@ -54,8 +56,7 @@ class TOTPMethod(base.AuthMethod):
# the key in caching.
params = copy.copy(self._method_parameters)
params.remove('passcode')
return dict(('totp_%s' % p, getattr(self, p))
for p in self._method_parameters)
return {f'totp_{p}': getattr(self, p) for p in self._method_parameters}
class TOTP(base.AuthConstructor):

View File

@ -37,7 +37,8 @@ get_session_conf_options = session.get_conf_options
register_adapter_argparse_arguments = adapter.register_argparse_arguments
register_service_adapter_argparse_arguments = (
adapter.register_service_argparse_arguments)
adapter.register_service_argparse_arguments
)
register_adapter_conf_options = adapter.register_conf_options
load_adapter_from_conf_options = adapter.load_from_conf_options
get_adapter_conf_options = adapter.get_conf_options
@ -50,38 +51,32 @@ __all__ = (
'get_available_plugin_loaders',
'get_plugin_loader',
'PLUGIN_NAMESPACE',
# loading.identity
'BaseIdentityLoader',
'BaseV2Loader',
'BaseV3Loader',
'BaseFederationLoader',
'BaseGenericLoader',
# auth cli
'register_auth_argparse_arguments',
'load_auth_from_argparse_arguments',
# auth conf
'get_auth_common_conf_options',
'get_auth_plugin_conf_options',
'register_auth_conf_options',
'load_auth_from_conf_options',
# session
'register_session_argparse_arguments',
'load_session_from_argparse_arguments',
'register_session_conf_options',
'load_session_from_conf_options',
'get_session_conf_options',
# adapter
'register_adapter_argparse_arguments',
'register_service_adapter_argparse_arguments',
'register_adapter_conf_options',
'load_adapter_from_conf_options',
'get_adapter_conf_options',
# loading.opts
'Opt',
)

View File

@ -32,15 +32,21 @@ class AdminToken(loading.BaseLoader):
return token_endpoint.Token
def get_options(self):
options = super(AdminToken, self).get_options()
options = super().get_options()
options.extend([
loading.Opt('endpoint',
deprecated=[loading.Opt('url')],
help='The endpoint that will always be used'),
loading.Opt('token',
secret=True,
help='The token that will always be used'),
])
options.extend(
[
loading.Opt(
'endpoint',
deprecated=[loading.Opt('url')],
help='The endpoint that will always be used',
),
loading.Opt(
'token',
secret=True,
help='The token that will always be used',
),
]
)
return options

View File

@ -31,18 +31,25 @@ class HTTPBasicAuth(loading.BaseLoader):
return http_basic.HTTPBasicAuth
def get_options(self):
options = super(HTTPBasicAuth, self).get_options()
options = super().get_options()
options.extend([
loading.Opt('username',
help='Username',
deprecated=[loading.Opt('user-name')]),
loading.Opt('password',
secret=True,
prompt='Password: ',
help="User's password"),
loading.Opt('endpoint',
help='The endpoint that will always be used'),
])
options.extend(
[
loading.Opt(
'username',
help='Username',
deprecated=[loading.Opt('user-name')],
),
loading.Opt(
'password',
secret=True,
prompt='Password: ',
help="User's password",
),
loading.Opt(
'endpoint', help='The endpoint that will always be used'
),
]
)
return options

View File

@ -33,12 +33,15 @@ class Token(loading.BaseGenericLoader):
return identity.Token
def get_options(self):
options = super(Token, self).get_options()
options = super().get_options()
options.extend([
loading.Opt('token', secret=True,
help='Token to authenticate with'),
])
options.extend(
[
loading.Opt(
'token', secret=True, help='Token to authenticate with'
)
]
)
return options
@ -59,17 +62,23 @@ class Password(loading.BaseGenericLoader):
return identity.Password
def get_options(cls):
options = super(Password, cls).get_options()
options.extend([
loading.Opt('user-id', help='User id'),
loading.Opt('username',
help='Username',
deprecated=[loading.Opt('user-name')]),
loading.Opt('user-domain-id', help="User's domain id"),
loading.Opt('user-domain-name', help="User's domain name"),
loading.Opt('password',
secret=True,
prompt='Password: ',
help="User's password"),
])
options = super().get_options()
options.extend(
[
loading.Opt('user-id', help='User id'),
loading.Opt(
'username',
help='Username',
deprecated=[loading.Opt('user-name')],
),
loading.Opt('user-domain-id', help="User's domain id"),
loading.Opt('user-domain-name', help="User's domain name"),
loading.Opt(
'password',
secret=True,
prompt='Password: ',
help="User's password",
),
]
)
return options

View File

@ -15,39 +15,41 @@ from keystoneauth1 import loading
class Token(loading.BaseV2Loader):
@property
def plugin_class(self):
return identity.V2Token
def get_options(self):
options = super(Token, self).get_options()
options = super().get_options()
options.extend([
loading.Opt('token', secret=True, help='Token'),
])
options.extend([loading.Opt('token', secret=True, help='Token')])
return options
class Password(loading.BaseV2Loader):
@property
def plugin_class(self):
return identity.V2Password
def get_options(self):
options = super(Password, self).get_options()
options = super().get_options()
options.extend([
loading.Opt('username',
deprecated=[loading.Opt('user-name')],
help='Username to login with'),
loading.Opt('user-id', help='User ID to login with'),
loading.Opt('password',
secret=True,
prompt='Password: ',
help='Password to use'),
])
options.extend(
[
loading.Opt(
'username',
deprecated=[loading.Opt('user-name')],
help='Username to login with',
),
loading.Opt('user-id', help='User ID to login with'),
loading.Opt(
'password',
secret=True,
prompt='Password: ',
help='Password to use',
),
]
)
return options

View File

@ -16,334 +16,409 @@ from keystoneauth1 import loading
def _add_common_identity_options(options):
options.extend([
loading.Opt('user-id', help='User ID'),
loading.Opt('username',
help='Username',
deprecated=[loading.Opt('user-name')]),
loading.Opt('user-domain-id', help="User's domain id"),
loading.Opt('user-domain-name', help="User's domain name"),
])
options.extend(
[
loading.Opt('user-id', help='User ID'),
loading.Opt(
'username',
help='Username',
deprecated=[loading.Opt('user-name')],
),
loading.Opt('user-domain-id', help="User's domain id"),
loading.Opt('user-domain-name', help="User's domain name"),
]
)
def _assert_identity_options(options):
if (options.get('username') and
not (options.get('user_domain_name') or
options.get('user_domain_id'))):
m = "You have provided a username. In the V3 identity API a " \
"username is only unique within a domain so you must " \
if options.get('username') and not (
options.get('user_domain_name') or options.get('user_domain_id')
):
m = (
"You have provided a username. In the V3 identity API a "
"username is only unique within a domain so you must "
"also provide either a user_domain_id or user_domain_name."
)
raise exceptions.OptionError(m)
class Password(loading.BaseV3Loader):
@property
def plugin_class(self):
return identity.V3Password
def get_options(self):
options = super(Password, self).get_options()
options = super().get_options()
_add_common_identity_options(options)
options.extend([
loading.Opt('password',
secret=True,
prompt='Password: ',
help="User's password"),
])
options.extend(
[
loading.Opt(
'password',
secret=True,
prompt='Password: ',
help="User's password",
)
]
)
return options
def load_from_options(self, **kwargs):
_assert_identity_options(kwargs)
return super(Password, self).load_from_options(**kwargs)
return super().load_from_options(**kwargs)
class Token(loading.BaseV3Loader):
@property
def plugin_class(self):
return identity.V3Token
def get_options(self):
options = super(Token, self).get_options()
options = super().get_options()
options.extend([
loading.Opt('token',
secret=True,
help='Token to authenticate with'),
])
options.extend(
[
loading.Opt(
'token', secret=True, help='Token to authenticate with'
)
]
)
return options
class _OpenIDConnectBase(loading.BaseFederationLoader):
def load_from_options(self, **kwargs):
if not (kwargs.get('access_token_endpoint') or
kwargs.get('discovery_endpoint')):
m = ("You have to specify either an 'access-token-endpoint' or "
"a 'discovery-endpoint'.")
if not (
kwargs.get('access_token_endpoint')
or kwargs.get('discovery_endpoint')
):
m = (
"You have to specify either an 'access-token-endpoint' or "
"a 'discovery-endpoint'."
)
raise exceptions.OptionError(m)
return super(_OpenIDConnectBase, self).load_from_options(**kwargs)
return super().load_from_options(**kwargs)
def get_options(self):
options = super(_OpenIDConnectBase, self).get_options()
options = super().get_options()
options.extend([
loading.Opt('client-id', help='OAuth 2.0 Client ID'),
loading.Opt('client-secret', secret=True,
help='OAuth 2.0 Client Secret'),
loading.Opt('openid-scope', default="openid profile",
dest="scope",
help='OpenID Connect scope that is requested from '
'authorization server. Note that the OpenID '
'Connect specification states that "openid" '
'must be always specified.'),
loading.Opt('access-token-endpoint',
help='OpenID Connect Provider Token Endpoint. Note '
'that if a discovery document is being passed this '
'option will override the endpoint provided by the '
'server in the discovery document.'),
loading.Opt('discovery-endpoint',
help='OpenID Connect Discovery Document URL. '
'The discovery document will be used to obtain the '
'values of the access token endpoint and the '
'authentication endpoint. This URL should look like '
'https://idp.example.org/.well-known/'
'openid-configuration'),
loading.Opt('access-token-type',
help='OAuth 2.0 Authorization Server Introspection '
'token type, it is used to decide which type '
'of token will be used when processing token '
'introspection. Valid values are: '
'"access_token" or "id_token"'),
])
options.extend(
[
loading.Opt('client-id', help='OAuth 2.0 Client ID'),
loading.Opt(
'client-secret',
secret=True,
help='OAuth 2.0 Client Secret',
),
loading.Opt(
'openid-scope',
default="openid profile",
dest="scope",
help='OpenID Connect scope that is requested from '
'authorization server. Note that the OpenID '
'Connect specification states that "openid" '
'must be always specified.',
),
loading.Opt(
'access-token-endpoint',
help='OpenID Connect Provider Token Endpoint. Note '
'that if a discovery document is being passed this '
'option will override the endpoint provided by the '
'server in the discovery document.',
),
loading.Opt(
'discovery-endpoint',
help='OpenID Connect Discovery Document URL. '
'The discovery document will be used to obtain the '
'values of the access token endpoint and the '
'authentication endpoint. This URL should look like '
'https://idp.example.org/.well-known/'
'openid-configuration',
),
loading.Opt(
'access-token-type',
help='OAuth 2.0 Authorization Server Introspection '
'token type, it is used to decide which type '
'of token will be used when processing token '
'introspection. Valid values are: '
'"access_token" or "id_token"',
),
]
)
return options
class OpenIDConnectClientCredentials(_OpenIDConnectBase):
@property
def plugin_class(self):
return identity.V3OidcClientCredentials
def get_options(self):
options = super(OpenIDConnectClientCredentials, self).get_options()
options = super().get_options()
return options
class OpenIDConnectPassword(_OpenIDConnectBase):
@property
def plugin_class(self):
return identity.V3OidcPassword
def get_options(self):
options = super(OpenIDConnectPassword, self).get_options()
options = super().get_options()
options.extend([
loading.Opt('username', help='Username', required=True),
loading.Opt('password', secret=True,
help='Password', required=True),
loading.Opt('idp_otp_key',
help='A key to be used in the Identity Provider access'
' token endpoint to pass the OTP value. '
'E.g. totp'),
])
options.extend(
[
loading.Opt('username', help='Username', required=True),
loading.Opt(
'password', secret=True, help='Password', required=True
),
loading.Opt(
'idp_otp_key',
help='A key to be used in the Identity Provider access'
' token endpoint to pass the OTP value. '
'E.g. totp',
),
]
)
return options
class OpenIDConnectAuthorizationCode(_OpenIDConnectBase):
@property
def plugin_class(self):
return identity.V3OidcAuthorizationCode
def get_options(self):
options = super(OpenIDConnectAuthorizationCode, self).get_options()
options = super().get_options()
options.extend([
loading.Opt('redirect-uri', help='OpenID Connect Redirect URL'),
loading.Opt('code', secret=True, required=True,
deprecated=[loading.Opt('authorization-code')],
help='OAuth 2.0 Authorization Code'),
])
options.extend(
[
loading.Opt(
'redirect-uri', help='OpenID Connect Redirect URL'
),
loading.Opt(
'code',
secret=True,
required=True,
deprecated=[loading.Opt('authorization-code')],
help='OAuth 2.0 Authorization Code',
),
]
)
return options
class OpenIDConnectAccessToken(loading.BaseFederationLoader):
@property
def plugin_class(self):
return identity.V3OidcAccessToken
def get_options(self):
options = super(OpenIDConnectAccessToken, self).get_options()
options = super().get_options()
options.extend([
loading.Opt('access-token', secret=True, required=True,
help='OAuth 2.0 Access Token'),
])
options.extend(
[
loading.Opt(
'access-token',
secret=True,
required=True,
help='OAuth 2.0 Access Token',
)
]
)
return options
class OpenIDConnectDeviceAuthorization(_OpenIDConnectBase):
@property
def plugin_class(self):
return identity.V3OidcDeviceAuthorization
def get_options(self):
options = super(OpenIDConnectDeviceAuthorization, self).get_options()
options = super().get_options()
# RFC 8628 doesn't support id_token
options = [opt for opt in options if opt.name != 'access-token-type']
options.extend([
loading.Opt('device-authorization-endpoint',
help='OAuth 2.0 Device Authorization Endpoint. Note '
'that if a discovery document is being passed this '
'option will override the endpoint provided by the '
'server in the discovery document.'),
loading.Opt('code-challenge-method',
help='PKCE Challenge Method (RFC 7636)'),
])
options.extend(
[
loading.Opt(
'device-authorization-endpoint',
help='OAuth 2.0 Device Authorization Endpoint. Note '
'that if a discovery document is being passed this '
'option will override the endpoint provided by the '
'server in the discovery document.',
),
loading.Opt(
'code-challenge-method',
help='PKCE Challenge Method (RFC 7636)',
),
]
)
return options
class TOTP(loading.BaseV3Loader):
@property
def plugin_class(self):
return identity.V3TOTP
def get_options(self):
options = super(TOTP, self).get_options()
options = super().get_options()
_add_common_identity_options(options)
options.extend([
loading.Opt(
'passcode',
secret=True,
prompt='TOTP passcode: ',
help="User's TOTP passcode"),
])
options.extend(
[
loading.Opt(
'passcode',
secret=True,
prompt='TOTP passcode: ',
help="User's TOTP passcode",
)
]
)
return options
def load_from_options(self, **kwargs):
_assert_identity_options(kwargs)
return super(TOTP, self).load_from_options(**kwargs)
return super().load_from_options(**kwargs)
class TokenlessAuth(loading.BaseLoader):
@property
def plugin_class(self):
return identity.V3TokenlessAuth
def get_options(self):
options = super(TokenlessAuth, self).get_options()
options = super().get_options()
options.extend([
loading.Opt('auth-url', required=True,
help='Authentication URL'),
loading.Opt('domain-id', help='Domain ID to scope to'),
loading.Opt('domain-name', help='Domain name to scope to'),
loading.Opt('project-id', help='Project ID to scope to'),
loading.Opt('project-name', help='Project name to scope to'),
loading.Opt('project-domain-id',
help='Domain ID containing project'),
loading.Opt('project-domain-name',
help='Domain name containing project'),
])
options.extend(
[
loading.Opt(
'auth-url', required=True, help='Authentication URL'
),
loading.Opt('domain-id', help='Domain ID to scope to'),
loading.Opt('domain-name', help='Domain name to scope to'),
loading.Opt('project-id', help='Project ID to scope to'),
loading.Opt('project-name', help='Project name to scope to'),
loading.Opt(
'project-domain-id', help='Domain ID containing project'
),
loading.Opt(
'project-domain-name',
help='Domain name containing project',
),
]
)
return options
def load_from_options(self, **kwargs):
if (not kwargs.get('domain_id') and
not kwargs.get('domain_name') and
not kwargs.get('project_id') and
not kwargs.get('project_name') or
(kwargs.get('project_name') and
not (kwargs.get('project_domain_name') or
kwargs.get('project_domain_id')))):
m = ('You need to provide either a domain_name, domain_id, '
'project_id or project_name. '
'If you have provided a project_name, in the V3 identity '
'API a project_name is only unique within a domain so '
'you must also provide either a project_domain_id or '
'project_domain_name.')
if (
not kwargs.get('domain_id')
and not kwargs.get('domain_name')
and not kwargs.get('project_id')
and not kwargs.get('project_name')
or (
kwargs.get('project_name')
and not (
kwargs.get('project_domain_name')
or kwargs.get('project_domain_id')
)
)
):
m = (
'You need to provide either a domain_name, domain_id, '
'project_id or project_name. '
'If you have provided a project_name, in the V3 identity '
'API a project_name is only unique within a domain so '
'you must also provide either a project_domain_id or '
'project_domain_name.'
)
raise exceptions.OptionError(m)
return super(TokenlessAuth, self).load_from_options(**kwargs)
return super().load_from_options(**kwargs)
class ApplicationCredential(loading.BaseV3Loader):
@property
def plugin_class(self):
return identity.V3ApplicationCredential
def get_options(self):
options = super(ApplicationCredential, self).get_options()
options = super().get_options()
_add_common_identity_options(options)
options.extend([
loading.Opt('application_credential_secret', secret=True,
required=True,
help="Application credential auth secret"),
loading.Opt('application_credential_id',
help='Application credential ID'),
loading.Opt('application_credential_name',
help='Application credential name'),
])
options.extend(
[
loading.Opt(
'application_credential_secret',
secret=True,
required=True,
help="Application credential auth secret",
),
loading.Opt(
'application_credential_id',
help='Application credential ID',
),
loading.Opt(
'application_credential_name',
help='Application credential name',
),
]
)
return options
def load_from_options(self, **kwargs):
_assert_identity_options(kwargs)
if (not kwargs.get('application_credential_id') and
not kwargs.get('application_credential_name')):
m = ('You must provide either an application credential ID or an '
'application credential name and user.')
if not kwargs.get('application_credential_id') and not kwargs.get(
'application_credential_name'
):
m = (
'You must provide either an application credential ID or an '
'application credential name and user.'
)
raise exceptions.OptionError(m)
if not kwargs.get('application_credential_secret'):
m = ('You must provide an auth secret.')
m = 'You must provide an auth secret.'
raise exceptions.OptionError(m)
return super(ApplicationCredential, self).load_from_options(**kwargs)
return super().load_from_options(**kwargs)
class MultiFactor(loading.BaseV3Loader):
def __init__(self, *args, **kwargs):
self._methods = None
return super(MultiFactor, self).__init__(*args, **kwargs)
return super().__init__(*args, **kwargs)
@property
def plugin_class(self):
return identity.V3MultiFactor
def get_options(self):
options = super(MultiFactor, self).get_options()
options = super().get_options()
options.extend([
loading.Opt(
'auth_methods',
required=True,
help="Methods to authenticate with."),
])
options.extend(
[
loading.Opt(
'auth_methods',
required=True,
help="Methods to authenticate with.",
)
]
)
if self._methods:
options_dict = {o.name: o for o in options}
@ -362,29 +437,36 @@ class MultiFactor(loading.BaseV3Loader):
self._methods = kwargs['auth_methods']
return super(MultiFactor, self).load_from_options(**kwargs)
return super().load_from_options(**kwargs)
class OAuth2ClientCredential(loading.BaseV3Loader):
@property
def plugin_class(self):
return identity.V3OAuth2ClientCredential
def get_options(self):
options = super(OAuth2ClientCredential, self).get_options()
options.extend([
loading.Opt('oauth2_endpoint',
required=True,
help='Endpoint for OAuth2.0'),
loading.Opt('oauth2_client_id',
required=True,
help='Client id for OAuth2.0'),
loading.Opt('oauth2_client_secret',
secret=True,
required=True,
help='Client secret for OAuth2.0'),
])
options = super().get_options()
options.extend(
[
loading.Opt(
'oauth2_endpoint',
required=True,
help='Endpoint for OAuth2.0',
),
loading.Opt(
'oauth2_client_id',
required=True,
help='Client id for OAuth2.0',
),
loading.Opt(
'oauth2_client_secret',
secret=True,
required=True,
help='Client secret for OAuth2.0',
),
]
)
return options
@ -399,26 +481,31 @@ class OAuth2ClientCredential(loading.BaseV3Loader):
m = 'You must provide an OAuth2.0 client credential auth secret.'
raise exceptions.OptionError(m)
return super(OAuth2ClientCredential, self).load_from_options(**kwargs)
return super().load_from_options(**kwargs)
class OAuth2mTlsClientCredential(loading.BaseV3Loader):
@property
def plugin_class(self):
return identity.V3OAuth2mTlsClientCredential
def get_options(self):
options = super(OAuth2mTlsClientCredential, self).get_options()
options.extend([
loading.Opt('oauth2-endpoint',
required=True,
help='Endpoint for OAuth2.0 Mutual-TLS Authorization'),
loading.Opt('oauth2-client-id',
required=True,
help='Client credential ID for OAuth2.0 Mutual-TLS '
'Authorization')
])
options = super().get_options()
options.extend(
[
loading.Opt(
'oauth2-endpoint',
required=True,
help='Endpoint for OAuth2.0 Mutual-TLS Authorization',
),
loading.Opt(
'oauth2-client-id',
required=True,
help='Client credential ID for OAuth2.0 Mutual-TLS '
'Authorization',
),
]
)
return options
def load_from_options(self, **kwargs):
@ -426,8 +513,9 @@ class OAuth2mTlsClientCredential(loading.BaseV3Loader):
m = 'You must provide an OAuth2.0 Mutual-TLS endpoint.'
raise exceptions.OptionError(m)
if not kwargs.get('oauth2_client_id'):
m = ('You must provide an client credential ID for '
'OAuth2.0 Mutual-TLS Authorization.')
m = (
'You must provide an client credential ID for '
'OAuth2.0 Mutual-TLS Authorization.'
)
raise exceptions.OptionError(m)
return super(OAuth2mTlsClientCredential,
self).load_from_options(**kwargs)
return super().load_from_options(**kwargs)

View File

@ -31,11 +31,14 @@ class NoAuth(loading.BaseLoader):
return noauth.NoAuth
def get_options(self):
options = super(NoAuth, self).get_options()
options = super().get_options()
options.extend([
loading.Opt('endpoint',
help='The endpoint that will always be used'),
])
options.extend(
[
loading.Opt(
'endpoint', help='The endpoint that will always be used'
)
]
)
return options

View File

@ -32,9 +32,11 @@ def get_oslo_config():
cfg = _NOT_FOUND
if cfg is _NOT_FOUND:
raise ImportError("oslo.config is not an automatic dependency of "
"keystoneauth. If you wish to use oslo.config "
"you need to import it into your application's "
"requirements file. ")
raise ImportError(
"oslo.config is not an automatic dependency of "
"keystoneauth. If you wish to use oslo.config "
"you need to import it into your application's "
"requirements file. "
)
return cfg

View File

@ -15,15 +15,16 @@ from keystoneauth1.loading import _utils
from keystoneauth1.loading import base
__all__ = ('register_argparse_arguments',
'register_service_argparse_arguments',
'register_conf_options',
'load_from_conf_options',
'get_conf_options')
__all__ = (
'register_argparse_arguments',
'register_service_argparse_arguments',
'register_conf_options',
'load_from_conf_options',
'get_conf_options',
)
class Adapter(base.BaseLoader):
@property
def plugin_class(self):
return adapter.Adapter
@ -76,7 +77,7 @@ class Adapter(base.BaseLoader):
the new ``endpoint_override`` option name::
old_opt = oslo_cfg.DeprecatedOpt('api_endpoint', 'old_group')
deprecated_opts={'endpoint_override': [old_opt]}
deprecated_opts = {'endpoint_override': [old_opt]}
:returns: A list of oslo_config options.
"""
@ -86,106 +87,127 @@ class Adapter(base.BaseLoader):
deprecated_opts = {}
# This is goofy, but need to support hyphens *or* underscores
deprecated_opts = {name.replace('_', '-'): opt
for name, opt in deprecated_opts.items()}
deprecated_opts = {
name.replace('_', '-'): opt
for name, opt in deprecated_opts.items()
}
opts = [cfg.StrOpt('service-type',
deprecated_opts=deprecated_opts.get('service-type'),
help='The default service_type for endpoint URL '
'discovery.'),
cfg.StrOpt('service-name',
deprecated_opts=deprecated_opts.get('service-name'),
help='The default service_name for endpoint URL '
'discovery.'),
cfg.ListOpt('valid-interfaces',
deprecated_opts=deprecated_opts.get(
'valid-interfaces'),
help='List of interfaces, in order of preference, '
'for endpoint URL.'),
cfg.StrOpt('region-name',
deprecated_opts=deprecated_opts.get('region-name'),
help='The default region_name for endpoint URL '
'discovery.'),
cfg.StrOpt('endpoint-override',
deprecated_opts=deprecated_opts.get(
'endpoint-override'),
help='Always use this endpoint URL for requests '
'for this client. NOTE: The unversioned '
'endpoint should be specified here; to '
'request a particular API version, use the '
'`version`, `min-version`, and/or '
'`max-version` options.'),
cfg.StrOpt('version',
deprecated_opts=deprecated_opts.get('version'),
help='Minimum Major API version within a given '
'Major API version for endpoint URL '
'discovery. Mutually exclusive with '
'min_version and max_version'),
cfg.StrOpt('min-version',
deprecated_opts=deprecated_opts.get('min-version'),
help='The minimum major version of a given API, '
'intended to be used as the lower bound of a '
'range with max_version. Mutually exclusive '
'with version. If min_version is given with '
'no max_version it is as if max version is '
'"latest".'),
cfg.StrOpt('max-version',
deprecated_opts=deprecated_opts.get('max-version'),
help='The maximum major version of a given API, '
'intended to be used as the upper bound of a '
'range with min_version. Mutually exclusive '
'with version.'),
cfg.IntOpt('connect-retries',
deprecated_opts=deprecated_opts.get(
'connect-retries'),
help='The maximum number of retries that should be '
'attempted for connection errors.'),
cfg.FloatOpt('connect-retry-delay',
deprecated_opts=deprecated_opts.get(
'connect-retry-delay'),
help='Delay (in seconds) between two retries '
'for connection errors. If not set, '
'exponential retry starting with 0.5 '
'seconds up to a maximum of 60 seconds '
'is used.'),
cfg.IntOpt('status-code-retries',
deprecated_opts=deprecated_opts.get(
'status-code-retries'),
help='The maximum number of retries that should be '
'attempted for retriable HTTP status codes.'),
cfg.FloatOpt('status-code-retry-delay',
deprecated_opts=deprecated_opts.get(
'status-code-retry-delay'),
help='Delay (in seconds) between two retries '
'for retriable status codes. If not set, '
'exponential retry starting with 0.5 '
'seconds up to a maximum of 60 seconds '
'is used.'),
cfg.ListOpt('retriable-status-codes',
deprecated_opts=deprecated_opts.get(
'retriable-status-codes'),
item_type=cfg.types.Integer(),
help='List of retriable HTTP status codes that '
'should be retried. If not set default to '
' [503]'
),
]
opts = [
cfg.StrOpt(
'service-type',
deprecated_opts=deprecated_opts.get('service-type'),
help='The default service_type for endpoint URL discovery.',
),
cfg.StrOpt(
'service-name',
deprecated_opts=deprecated_opts.get('service-name'),
help='The default service_name for endpoint URL discovery.',
),
cfg.ListOpt(
'valid-interfaces',
deprecated_opts=deprecated_opts.get('valid-interfaces'),
help='List of interfaces, in order of preference, '
'for endpoint URL.',
),
cfg.StrOpt(
'region-name',
deprecated_opts=deprecated_opts.get('region-name'),
help='The default region_name for endpoint URL discovery.',
),
cfg.StrOpt(
'endpoint-override',
deprecated_opts=deprecated_opts.get('endpoint-override'),
help='Always use this endpoint URL for requests '
'for this client. NOTE: The unversioned '
'endpoint should be specified here; to '
'request a particular API version, use the '
'`version`, `min-version`, and/or '
'`max-version` options.',
),
cfg.StrOpt(
'version',
deprecated_opts=deprecated_opts.get('version'),
help='Minimum Major API version within a given '
'Major API version for endpoint URL '
'discovery. Mutually exclusive with '
'min_version and max_version',
),
cfg.StrOpt(
'min-version',
deprecated_opts=deprecated_opts.get('min-version'),
help='The minimum major version of a given API, '
'intended to be used as the lower bound of a '
'range with max_version. Mutually exclusive '
'with version. If min_version is given with '
'no max_version it is as if max version is '
'"latest".',
),
cfg.StrOpt(
'max-version',
deprecated_opts=deprecated_opts.get('max-version'),
help='The maximum major version of a given API, '
'intended to be used as the upper bound of a '
'range with min_version. Mutually exclusive '
'with version.',
),
cfg.IntOpt(
'connect-retries',
deprecated_opts=deprecated_opts.get('connect-retries'),
help='The maximum number of retries that should be '
'attempted for connection errors.',
),
cfg.FloatOpt(
'connect-retry-delay',
deprecated_opts=deprecated_opts.get('connect-retry-delay'),
help='Delay (in seconds) between two retries '
'for connection errors. If not set, '
'exponential retry starting with 0.5 '
'seconds up to a maximum of 60 seconds '
'is used.',
),
cfg.IntOpt(
'status-code-retries',
deprecated_opts=deprecated_opts.get('status-code-retries'),
help='The maximum number of retries that should be '
'attempted for retriable HTTP status codes.',
),
cfg.FloatOpt(
'status-code-retry-delay',
deprecated_opts=deprecated_opts.get('status-code-retry-delay'),
help='Delay (in seconds) between two retries '
'for retriable status codes. If not set, '
'exponential retry starting with 0.5 '
'seconds up to a maximum of 60 seconds '
'is used.',
),
cfg.ListOpt(
'retriable-status-codes',
deprecated_opts=deprecated_opts.get('retriable-status-codes'),
item_type=cfg.types.Integer(),
help='List of retriable HTTP status codes that '
'should be retried. If not set default to '
' [503]',
),
]
if include_deprecated:
opts += [
cfg.StrOpt('interface',
help='The default interface for endpoint URL '
'discovery.',
deprecated_for_removal=True,
deprecated_reason='Using valid-interfaces is'
' preferrable because it is'
' capable of accepting a list of'
' possible interfaces.'),
cfg.StrOpt(
'interface',
help='The default interface for endpoint URL '
'discovery.',
deprecated_for_removal=True,
deprecated_reason='Using valid-interfaces is'
' preferrable because it is'
' capable of accepting a list of'
' possible interfaces.',
)
]
return opts
def register_conf_options(self, conf, group, include_deprecated=True,
deprecated_opts=None):
def register_conf_options(
self, conf, group, include_deprecated=True, deprecated_opts=None
):
"""Register the oslo_config options that are needed for an Adapter.
The options that are set are:
@ -231,12 +253,14 @@ class Adapter(base.BaseLoader):
the new ``endpoint_override`` option name::
old_opt = oslo_cfg.DeprecatedOpt('api_endpoint', 'old_group')
deprecated_opts={'endpoint_override': [old_opt]}
deprecated_opts = {'endpoint_override': [old_opt]}
:returns: The list of options that was registered.
"""
opts = self.get_conf_options(include_deprecated=include_deprecated,
deprecated_opts=deprecated_opts)
opts = self.get_conf_options(
include_deprecated=include_deprecated,
deprecated_opts=deprecated_opts,
)
conf.register_group(_utils.get_oslo_config().OptGroup(group))
conf.register_opts(opts, group=group)
return opts
@ -270,16 +294,19 @@ def process_conf_options(confgrp, kwargs):
:raise TypeError: If invalid conf option values or combinations are found.
"""
if confgrp.valid_interfaces and getattr(confgrp, 'interface', None):
raise TypeError("interface and valid_interfaces are mutually"
" exclusive. Please use valid_interfaces.")
raise TypeError(
"interface and valid_interfaces are mutually"
" exclusive. Please use valid_interfaces."
)
if confgrp.valid_interfaces:
for iface in confgrp.valid_interfaces:
if iface not in ('public', 'internal', 'admin'):
# TODO(efried): s/valies/values/ - are we allowed to fix this?
raise TypeError("'{iface}' is not a valid value for"
" valid_interfaces. Valid valies are"
" public, internal or admin".format(
iface=iface))
raise TypeError(
f"'{iface}' is not a valid value for"
" valid_interfaces. Valid valies are"
" public, internal or admin"
)
kwargs.setdefault('interface', confgrp.valid_interfaces)
elif hasattr(confgrp, 'interface'):
kwargs.setdefault('interface', confgrp.interface)
@ -290,16 +317,16 @@ def process_conf_options(confgrp, kwargs):
kwargs.setdefault('version', confgrp.version)
kwargs.setdefault('min_version', confgrp.min_version)
kwargs.setdefault('max_version', confgrp.max_version)
if kwargs['version'] and (
kwargs['max_version'] or kwargs['min_version']):
if kwargs['version'] and (kwargs['max_version'] or kwargs['min_version']):
raise TypeError(
"version is mutually exclusive with min_version and"
" max_version")
"version is mutually exclusive with min_version and max_version"
)
kwargs.setdefault('connect_retries', confgrp.connect_retries)
kwargs.setdefault('connect_retry_delay', confgrp.connect_retry_delay)
kwargs.setdefault('status_code_retries', confgrp.status_code_retries)
kwargs.setdefault('status_code_retry_delay',
confgrp.status_code_retry_delay)
kwargs.setdefault(
'status_code_retry_delay', confgrp.status_code_retry_delay
)
kwargs.setdefault('retriable_status_codes', confgrp.retriable_status_codes)

View File

@ -19,12 +19,14 @@ from keystoneauth1 import exceptions
PLUGIN_NAMESPACE = 'keystoneauth1.plugin'
__all__ = ('get_available_plugin_names',
'get_available_plugin_loaders',
'get_plugin_loader',
'get_plugin_options',
'BaseLoader',
'PLUGIN_NAMESPACE')
__all__ = (
'get_available_plugin_names',
'get_available_plugin_loaders',
'get_plugin_loader',
'get_plugin_options',
'BaseLoader',
'PLUGIN_NAMESPACE',
)
def _auth_plugin_available(ext):
@ -41,10 +43,12 @@ def get_available_plugin_names():
:returns: A list of names.
:rtype: frozenset
"""
mgr = stevedore.EnabledExtensionManager(namespace=PLUGIN_NAMESPACE,
check_func=_auth_plugin_available,
invoke_on_load=True,
propagate_map_exceptions=True)
mgr = stevedore.EnabledExtensionManager(
namespace=PLUGIN_NAMESPACE,
check_func=_auth_plugin_available,
invoke_on_load=True,
propagate_map_exceptions=True,
)
return frozenset(mgr.names())
@ -55,10 +59,12 @@ def get_available_plugin_loaders():
loader as the value.
:rtype: dict
"""
mgr = stevedore.EnabledExtensionManager(namespace=PLUGIN_NAMESPACE,
check_func=_auth_plugin_available,
invoke_on_load=True,
propagate_map_exceptions=True)
mgr = stevedore.EnabledExtensionManager(
namespace=PLUGIN_NAMESPACE,
check_func=_auth_plugin_available,
invoke_on_load=True,
propagate_map_exceptions=True,
)
return dict(mgr.map(lambda ext: (ext.entry_point.name, ext.obj)))
@ -75,9 +81,9 @@ def get_plugin_loader(name):
if a plugin cannot be created.
"""
try:
mgr = stevedore.DriverManager(namespace=PLUGIN_NAMESPACE,
invoke_on_load=True,
name=name)
mgr = stevedore.DriverManager(
namespace=PLUGIN_NAMESPACE, invoke_on_load=True, name=name
)
except RuntimeError:
raise exceptions.NoMatchingPlugin(name)
@ -99,7 +105,6 @@ def get_plugin_options(name):
class BaseLoader(metaclass=abc.ABCMeta):
@property
def plugin_class(self):
raise NotImplementedError()
@ -153,8 +158,11 @@ class BaseLoader(metaclass=abc.ABCMeta):
handle differences between the registered options and what is required
to create the plugin.
"""
missing_required = [o for o in self.get_options()
if o.required and kwargs.get(o.dest) is None]
missing_required = [
o
for o in self.get_options()
if o.required and kwargs.get(o.dest) is None
]
if missing_required:
raise exceptions.MissingRequiredOptions(missing_required)

View File

@ -16,17 +16,18 @@ import os
from keystoneauth1.loading import base
__all__ = ('register_argparse_arguments',
'load_from_argparse_arguments')
__all__ = ('register_argparse_arguments', 'load_from_argparse_arguments')
def _register_plugin_argparse_arguments(parser, plugin):
for opt in plugin.get_options():
parser.add_argument(*opt.argparse_args,
default=opt.argparse_default,
metavar=opt.metavar,
help=opt.help,
dest='os_%s' % opt.dest)
parser.add_argument(
*opt.argparse_args,
default=opt.argparse_default,
metavar=opt.metavar,
help=opt.help,
dest=f'os_{opt.dest}',
)
def register_argparse_arguments(parser, argv, default=None):
@ -48,14 +49,17 @@ def register_argparse_arguments(parser, argv, default=None):
if a plugin cannot be created.
"""
in_parser = argparse.ArgumentParser(add_help=False)
env_plugin = os.environ.get('OS_AUTH_TYPE',
os.environ.get('OS_AUTH_PLUGIN', default))
env_plugin = os.environ.get(
'OS_AUTH_TYPE', os.environ.get('OS_AUTH_PLUGIN', default)
)
for p in (in_parser, parser):
p.add_argument('--os-auth-type',
'--os-auth-plugin',
metavar='<name>',
default=env_plugin,
help='Authentication type to use')
p.add_argument(
'--os-auth-type',
'--os-auth-plugin',
metavar='<name>',
default=env_plugin,
help='Authentication type to use',
)
options, _args = in_parser.parse_known_args(argv)
@ -66,7 +70,7 @@ def register_argparse_arguments(parser, argv, default=None):
msg = 'Default Authentication options'
plugin = options.os_auth_type
else:
msg = 'Options specific to the %s plugin.' % options.os_auth_type
msg = f'Options specific to the {options.os_auth_type} plugin.'
plugin = base.get_plugin_loader(options.os_auth_type)
group = parser.add_argument_group('Authentication Options', msg)
@ -97,6 +101,6 @@ def load_from_argparse_arguments(namespace, **kwargs):
plugin = base.get_plugin_loader(namespace.os_auth_type)
def _getter(opt):
return getattr(namespace, 'os_%s' % opt.dest)
return getattr(namespace, f'os_{opt.dest}')
return plugin.load_from_options_getter(_getter, **kwargs)

View File

@ -13,18 +13,22 @@
from keystoneauth1.loading import base
from keystoneauth1.loading import opts
_AUTH_TYPE_OPT = opts.Opt('auth_type',
deprecated=[opts.Opt('auth_plugin')],
help='Authentication type to load')
_AUTH_TYPE_OPT = opts.Opt(
'auth_type',
deprecated=[opts.Opt('auth_plugin')],
help='Authentication type to load',
)
_section_help = 'Config Section from which to load plugin specific options'
_AUTH_SECTION_OPT = opts.Opt('auth_section', help=_section_help)
__all__ = ('get_common_conf_options',
'get_plugin_conf_options',
'register_conf_options',
'load_from_conf_options')
__all__ = (
'get_common_conf_options',
'get_plugin_conf_options',
'register_conf_options',
'load_from_conf_options',
)
def get_common_conf_options():

View File

@ -14,11 +14,13 @@ from keystoneauth1 import exceptions
from keystoneauth1.loading import base
from keystoneauth1.loading import opts
__all__ = ('BaseIdentityLoader',
'BaseV2Loader',
'BaseV3Loader',
'BaseFederationLoader',
'BaseGenericLoader')
__all__ = (
'BaseIdentityLoader',
'BaseV2Loader',
'BaseV3Loader',
'BaseFederationLoader',
'BaseGenericLoader',
)
class BaseIdentityLoader(base.BaseLoader):
@ -31,13 +33,11 @@ class BaseIdentityLoader(base.BaseLoader):
"""
def get_options(self):
options = super(BaseIdentityLoader, self).get_options()
options = super().get_options()
options.extend([
opts.Opt('auth-url',
required=True,
help='Authentication URL'),
])
options.extend(
[opts.Opt('auth-url', required=True, help='Authentication URL')]
)
return options
@ -51,14 +51,17 @@ class BaseV2Loader(BaseIdentityLoader):
"""
def get_options(self):
options = super(BaseV2Loader, self).get_options()
options = super().get_options()
options.extend([
opts.Opt('tenant-id', help='Tenant ID'),
opts.Opt('tenant-name', help='Tenant Name'),
opts.Opt('trust-id',
help='ID of the trust to use as a trustee use'),
])
options.extend(
[
opts.Opt('tenant-id', help='Tenant ID'),
opts.Opt('tenant-name', help='Tenant Name'),
opts.Opt(
'trust-id', help='ID of the trust to use as a trustee use'
),
]
)
return options
@ -72,35 +75,44 @@ class BaseV3Loader(BaseIdentityLoader):
"""
def get_options(self):
options = super(BaseV3Loader, self).get_options()
options = super().get_options()
options.extend([
opts.Opt('system-scope', help='Scope for system operations'),
opts.Opt('domain-id', help='Domain ID to scope to'),
opts.Opt('domain-name', help='Domain name to scope to'),
opts.Opt('project-id', help='Project ID to scope to'),
opts.Opt('project-name', help='Project name to scope to'),
opts.Opt('project-domain-id',
help='Domain ID containing project'),
opts.Opt('project-domain-name',
help='Domain name containing project'),
opts.Opt('trust-id',
help='ID of the trust to use as a trustee use'),
])
options.extend(
[
opts.Opt('system-scope', help='Scope for system operations'),
opts.Opt('domain-id', help='Domain ID to scope to'),
opts.Opt('domain-name', help='Domain name to scope to'),
opts.Opt('project-id', help='Project ID to scope to'),
opts.Opt('project-name', help='Project name to scope to'),
opts.Opt(
'project-domain-id', help='Domain ID containing project'
),
opts.Opt(
'project-domain-name',
help='Domain name containing project',
),
opts.Opt(
'trust-id', help='ID of the trust to use as a trustee use'
),
]
)
return options
def load_from_options(self, **kwargs):
if (kwargs.get('project_name') and
not (kwargs.get('project_domain_name') or
kwargs.get('project_domain_id'))):
m = "You have provided a project_name. In the V3 identity API a " \
"project_name is only unique within a domain so you must " \
"also provide either a project_domain_id or " \
if kwargs.get('project_name') and not (
kwargs.get('project_domain_name')
or kwargs.get('project_domain_id')
):
m = (
"You have provided a project_name. In the V3 identity API a "
"project_name is only unique within a domain so you must "
"also provide either a project_domain_id or "
"project_domain_name."
)
raise exceptions.OptionError(m)
return super(BaseV3Loader, self).load_from_options(**kwargs)
return super().load_from_options(**kwargs)
class BaseFederationLoader(BaseV3Loader):
@ -112,16 +124,22 @@ class BaseFederationLoader(BaseV3Loader):
"""
def get_options(self):
options = super(BaseFederationLoader, self).get_options()
options = super().get_options()
options.extend([
opts.Opt('identity-provider',
help="Identity Provider's name",
required=True),
opts.Opt('protocol',
help='Protocol for federated plugin',
required=True),
])
options.extend(
[
opts.Opt(
'identity-provider',
help="Identity Provider's name",
required=True,
),
opts.Opt(
'protocol',
help='Protocol for federated plugin',
required=True,
),
]
)
return options
@ -136,32 +154,48 @@ class BaseGenericLoader(BaseIdentityLoader):
"""
def get_options(self):
options = super(BaseGenericLoader, self).get_options()
options = super().get_options()
options.extend([
opts.Opt('system-scope', help='Scope for system operations'),
opts.Opt('domain-id', help='Domain ID to scope to'),
opts.Opt('domain-name', help='Domain name to scope to'),
opts.Opt('project-id', help='Project ID to scope to',
deprecated=[opts.Opt('tenant-id')]),
opts.Opt('project-name', help='Project name to scope to',
deprecated=[opts.Opt('tenant-name')]),
opts.Opt('project-domain-id',
help='Domain ID containing project'),
opts.Opt('project-domain-name',
help='Domain name containing project'),
opts.Opt('trust-id',
help='ID of the trust to use as a trustee use'),
opts.Opt('default-domain-id',
help='Optional domain ID to use with v3 and v2 '
'parameters. It will be used for both the user '
'and project domain in v3 and ignored in '
'v2 authentication.'),
opts.Opt('default-domain-name',
help='Optional domain name to use with v3 API and v2 '
'parameters. It will be used for both the user '
'and project domain in v3 and ignored in '
'v2 authentication.'),
])
options.extend(
[
opts.Opt('system-scope', help='Scope for system operations'),
opts.Opt('domain-id', help='Domain ID to scope to'),
opts.Opt('domain-name', help='Domain name to scope to'),
opts.Opt(
'project-id',
help='Project ID to scope to',
deprecated=[opts.Opt('tenant-id')],
),
opts.Opt(
'project-name',
help='Project name to scope to',
deprecated=[opts.Opt('tenant-name')],
),
opts.Opt(
'project-domain-id', help='Domain ID containing project'
),
opts.Opt(
'project-domain-name',
help='Domain name containing project',
),
opts.Opt(
'trust-id', help='ID of the trust to use as a trustee use'
),
opts.Opt(
'default-domain-id',
help='Optional domain ID to use with v3 and v2 '
'parameters. It will be used for both the user '
'and project domain in v3 and ignored in '
'v2 authentication.',
),
opts.Opt(
'default-domain-name',
help='Optional domain name to use with v3 API and v2 '
'parameters. It will be used for both the user '
'and project domain in v3 and ignored in '
'v2 authentication.',
),
]
)
return options

View File

@ -19,7 +19,7 @@ from keystoneauth1.loading import _utils
__all__ = ('Opt',)
class Opt(object):
class Opt:
"""An option required by an authentication plugin.
Opts provide a means for authentication plugins that are going to be
@ -60,17 +60,19 @@ class Opt(object):
appropriate) set the string that should be used to prompt with.
"""
def __init__(self,
name,
type=str,
help=None,
secret=False,
dest=None,
deprecated=None,
default=None,
metavar=None,
required=False,
prompt=None):
def __init__(
self,
name,
type=str,
help=None,
secret=False,
dest=None,
deprecated=None,
default=None,
metavar=None,
required=False,
prompt=None,
):
if not callable(type):
raise TypeError('type must be callable')
@ -95,33 +97,37 @@ class Opt(object):
def __repr__(self):
"""Return string representation of option name."""
return '<Opt: %s>' % self.name
return f'<Opt: {self.name}>'
def _to_oslo_opt(self):
cfg = _utils.get_oslo_config()
deprecated_opts = [cfg.DeprecatedOpt(o.name) for o in self.deprecated]
return cfg.Opt(name=self.name,
type=self.type,
help=self.help,
secret=self.secret,
required=self.required,
dest=self.dest,
deprecated_opts=deprecated_opts,
metavar=self.metavar)
return cfg.Opt(
name=self.name,
type=self.type,
help=self.help,
secret=self.secret,
required=self.required,
dest=self.dest,
deprecated_opts=deprecated_opts,
metavar=self.metavar,
)
def __eq__(self, other):
"""Define equality operator on option parameters."""
return (type(self) is type(other) and
self.name == other.name and
self.type == other.type and
self.help == other.help and
self.secret == other.secret and
self.required == other.required and
self.dest == other.dest and
self.deprecated == other.deprecated and
self.default == other.default and
self.metavar == other.metavar)
return (
type(self) is type(other)
and self.name == other.name
and self.type == other.type
and self.help == other.help
and self.secret == other.secret
and self.required == other.required
and self.dest == other.dest
and self.deprecated == other.deprecated
and self.default == other.default
and self.metavar == other.metavar
)
# NOTE: This function is only needed by Python 2. If we get to point where
# we don't support Python 2 anymore, this function should be removed.
@ -135,13 +141,15 @@ class Opt(object):
@property
def argparse_args(self):
return ['--os-%s' % o.name for o in self._all_opts]
return [f'--os-{o.name}' for o in self._all_opts]
@property
def argparse_default(self):
# select the first ENV that is not false-y or return None
for o in self._all_opts:
v = os.environ.get('OS_%s' % o.name.replace('-', '_').upper())
v = os.environ.get(
'OS_{}'.format(o.name.replace('-', '_').upper())
)
if v:
return v

View File

@ -18,11 +18,13 @@ from keystoneauth1.loading import base
from keystoneauth1 import session
__all__ = ('register_argparse_arguments',
'load_from_argparse_arguments',
'register_conf_options',
'load_from_conf_options',
'get_conf_options')
__all__ = (
'register_argparse_arguments',
'load_from_argparse_arguments',
'register_conf_options',
'load_from_conf_options',
'get_conf_options',
)
def _positive_non_zero_float(argument_value):
@ -31,16 +33,15 @@ def _positive_non_zero_float(argument_value):
try:
value = float(argument_value)
except ValueError:
msg = "%s must be a float" % argument_value
msg = f"{argument_value} must be a float"
raise argparse.ArgumentTypeError(msg)
if value <= 0:
msg = "%s must be greater than 0" % argument_value
msg = f"{argument_value} must be greater than 0"
raise argparse.ArgumentTypeError(msg)
return value
class Session(base.BaseLoader):
@property
def plugin_class(self):
return session.Session
@ -48,13 +49,15 @@ class Session(base.BaseLoader):
def get_options(self):
return []
def load_from_options(self,
insecure=False,
verify=None,
cacert=None,
cert=None,
key=None,
**kwargs):
def load_from_options(
self,
insecure=False,
verify=None,
cacert=None,
cert=None,
key=None,
**kwargs,
):
"""Create a session with individual certificate parameters.
Some parameters used to create a session don't lend themselves to be
@ -72,14 +75,13 @@ class Session(base.BaseLoader):
# requests lib form of having the cert and key as a tuple
cert = (cert, key)
return super(Session, self).load_from_options(verify=verify,
cert=cert,
**kwargs)
return super().load_from_options(verify=verify, cert=cert, **kwargs)
def register_argparse_arguments(self, parser):
session_group = parser.add_argument_group(
'API Connection Options',
'Options controlling the HTTP API Connections')
'Options controlling the HTTP API Connections',
)
session_group.add_argument(
'--insecure',
@ -89,7 +91,8 @@ class Session(base.BaseLoader):
'"insecure" TLS (https) requests. The '
'server\'s certificate will not be verified '
'against any certificate authorities. This '
'option should be used with caution.')
'option should be used with caution.',
)
session_group.add_argument(
'--os-cacert',
@ -97,36 +100,41 @@ class Session(base.BaseLoader):
default=os.environ.get('OS_CACERT'),
help='Specify a CA bundle file to use in '
'verifying a TLS (https) server certificate. '
'Defaults to env[OS_CACERT].')
'Defaults to env[OS_CACERT].',
)
session_group.add_argument(
'--os-cert',
metavar='<certificate>',
default=os.environ.get('OS_CERT'),
help='The location for the keystore (PEM formatted) '
'containing the public key of this client. '
'Defaults to env[OS_CERT].')
'containing the public key of this client. '
'Defaults to env[OS_CERT].',
)
session_group.add_argument(
'--os-key',
metavar='<key>',
default=os.environ.get('OS_KEY'),
help='The location for the keystore (PEM formatted) '
'containing the private key of this client. '
'Defaults to env[OS_KEY].')
'containing the private key of this client. '
'Defaults to env[OS_KEY].',
)
session_group.add_argument(
'--timeout',
default=600,
type=_positive_non_zero_float,
metavar='<seconds>',
help='Set request timeout (in seconds).')
help='Set request timeout (in seconds).',
)
session_group.add_argument(
'--collect-timing',
default=False,
action='store_true',
help='Collect per-API call timing information.')
help='Collect per-API call timing information.',
)
def load_from_argparse_arguments(self, namespace, **kwargs):
kwargs.setdefault('insecure', namespace.insecure)
@ -162,7 +170,7 @@ class Session(base.BaseLoader):
``cafile`` option name::
old_opt = oslo_cfg.DeprecatedOpt('ca_file', 'old_group')
deprecated_opts={'cafile': [old_opt]}
deprecated_opts = {'cafile': [old_opt]}
:returns: A list of oslo_config options.
"""
@ -171,34 +179,47 @@ class Session(base.BaseLoader):
if deprecated_opts is None:
deprecated_opts = {}
return [cfg.StrOpt('cafile',
deprecated_opts=deprecated_opts.get('cafile'),
help='PEM encoded Certificate Authority to use '
'when verifying HTTPs connections.'),
cfg.StrOpt('certfile',
deprecated_opts=deprecated_opts.get('certfile'),
help='PEM encoded client certificate cert file'),
cfg.StrOpt('keyfile',
deprecated_opts=deprecated_opts.get('keyfile'),
help='PEM encoded client certificate key file'),
cfg.BoolOpt('insecure',
default=False,
deprecated_opts=deprecated_opts.get('insecure'),
help='Verify HTTPS connections.'),
cfg.IntOpt('timeout',
deprecated_opts=deprecated_opts.get('timeout'),
help='Timeout value for http requests'),
cfg.BoolOpt('collect-timing',
deprecated_opts=deprecated_opts.get(
'collect-timing'),
default=False,
help='Collect per-API call timing information.'),
cfg.BoolOpt('split-loggers',
deprecated_opts=deprecated_opts.get(
'split-loggers'),
default=False,
help='Log requests to multiple loggers.')
]
return [
cfg.StrOpt(
'cafile',
deprecated_opts=deprecated_opts.get('cafile'),
help='PEM encoded Certificate Authority to use '
'when verifying HTTPs connections.',
),
cfg.StrOpt(
'certfile',
deprecated_opts=deprecated_opts.get('certfile'),
help='PEM encoded client certificate cert file',
),
cfg.StrOpt(
'keyfile',
deprecated_opts=deprecated_opts.get('keyfile'),
help='PEM encoded client certificate key file',
),
cfg.BoolOpt(
'insecure',
default=False,
deprecated_opts=deprecated_opts.get('insecure'),
help='Verify HTTPS connections.',
),
cfg.IntOpt(
'timeout',
deprecated_opts=deprecated_opts.get('timeout'),
help='Timeout value for http requests',
),
cfg.BoolOpt(
'collect-timing',
deprecated_opts=deprecated_opts.get('collect-timing'),
default=False,
help='Collect per-API call timing information.',
),
cfg.BoolOpt(
'split-loggers',
deprecated_opts=deprecated_opts.get('split-loggers'),
default=False,
help='Log requests to multiple loggers.',
),
]
def register_conf_options(self, conf, group, deprecated_opts=None):
"""Register the oslo_config options that are needed for a session.
@ -223,7 +244,7 @@ class Session(base.BaseLoader):
``cafile`` option name::
old_opt = oslo_cfg.DeprecatedOpt('ca_file', 'old_group')
deprecated_opts={'cafile': [old_opt]}
deprecated_opts = {'cafile': [old_opt]}
:returns: The list of options that was registered.
"""

View File

@ -20,7 +20,7 @@ AUTH_INTERFACE = object()
IDENTITY_AUTH_HEADER_NAME = 'X-Auth-Token'
class BaseAuthPlugin(object):
class BaseAuthPlugin:
"""The basic structure of an authentication plugin.
.. note::
@ -110,10 +110,9 @@ class BaseAuthPlugin(object):
return {IDENTITY_AUTH_HEADER_NAME: token}
def get_endpoint_data(self, session,
endpoint_override=None,
discover_versions=True,
**kwargs):
def get_endpoint_data(
self, session, endpoint_override=None, discover_versions=True, **kwargs
):
"""Return a valid endpoint data for a the service.
:param session: A session object that can be used for communication.
@ -140,8 +139,10 @@ class BaseAuthPlugin(object):
return endpoint_data
return endpoint_data.get_versioned_data(
session, cache=self._discovery_cache,
discover_versions=discover_versions)
session,
cache=self._discovery_cache,
discover_versions=discover_versions,
)
def get_api_major_version(self, session, endpoint_override=None, **kwargs):
"""Get the major API version from the endpoint.
@ -158,16 +159,22 @@ class BaseAuthPlugin(object):
:rtype: `keystoneauth1.discover.EndpointData` or None
"""
endpoint_data = self.get_endpoint_data(
session, endpoint_override=endpoint_override,
discover_versions=False, **kwargs)
session,
endpoint_override=endpoint_override,
discover_versions=False,
**kwargs,
)
if endpoint_data is None:
return
if endpoint_data.api_version is None:
# No version detected from the URL, trying full discovery.
endpoint_data = self.get_endpoint_data(
session, endpoint_override=endpoint_override,
discover_versions=True, **kwargs)
session,
endpoint_override=endpoint_override,
discover_versions=True,
**kwargs,
)
if endpoint_data and endpoint_data.api_version:
return endpoint_data.api_version
@ -195,7 +202,8 @@ class BaseAuthPlugin(object):
:rtype: string
"""
endpoint_data = self.get_endpoint_data(
session, discover_versions=False, **kwargs)
session, discover_versions=False, **kwargs
)
if not endpoint_data:
return None
return endpoint_data.url
@ -340,7 +348,7 @@ class FixedEndpointPlugin(BaseAuthPlugin):
"""A base class for plugins that have one fixed endpoint."""
def __init__(self, endpoint=None):
super(FixedEndpointPlugin, self).__init__()
super().__init__()
self.endpoint = endpoint
def get_endpoint(self, session, **kwargs):
@ -352,10 +360,9 @@ class FixedEndpointPlugin(BaseAuthPlugin):
"""
return kwargs.get('endpoint_override') or self.endpoint
def get_endpoint_data(self, session,
endpoint_override=None,
discover_versions=True,
**kwargs):
def get_endpoint_data(
self, session, endpoint_override=None, discover_versions=True, **kwargs
):
"""Return a valid endpoint data for a the service.
:param session: A session object that can be used for communication.
@ -374,8 +381,9 @@ class FixedEndpointPlugin(BaseAuthPlugin):
:return: Valid EndpointData or None if not available.
:rtype: `keystoneauth1.discover.EndpointData` or None
"""
return super(FixedEndpointPlugin, self).get_endpoint_data(
return super().get_endpoint_data(
session,
endpoint_override=endpoint_override or self.endpoint,
discover_versions=discover_versions,
**kwargs)
**kwargs,
)

View File

@ -18,9 +18,8 @@ __all__ = ('ServiceTokenAuthWrapper',)
class ServiceTokenAuthWrapper(plugin.BaseAuthPlugin):
def __init__(self, user_auth, service_auth):
super(ServiceTokenAuthWrapper, self).__init__()
super().__init__()
self.user_auth = user_auth
self.service_auth = service_auth

View File

@ -40,14 +40,12 @@ try:
except ImportError:
osprofiler_web = None
DEFAULT_USER_AGENT = 'keystoneauth1/%s %s %s/%s' % (
keystoneauth1.__version__, requests.utils.default_user_agent(),
platform.python_implementation(), platform.python_version())
DEFAULT_USER_AGENT = f'keystoneauth1/{keystoneauth1.__version__} {requests.utils.default_user_agent()} {platform.python_implementation()}/{platform.python_version()}'
# NOTE(jamielennox): Clients will likely want to print more than json. Please
# propose a patch if you have a content type you think is reasonable to print
# here and we'll add it to the list as required.
_LOG_CONTENT_TYPES = set(['application/json', 'text/plain'])
_LOG_CONTENT_TYPES = {'application/json', 'text/plain'}
_MAX_RETRY_INTERVAL = 60.0
_EXPONENTIAL_DELAY_START = 0.5
@ -101,7 +99,7 @@ def _sanitize_headers(headers):
return str_dict
class NoOpSemaphore(object):
class NoOpSemaphore:
"""Empty context manager for use as a default semaphore."""
def __enter__(self):
@ -114,7 +112,6 @@ class NoOpSemaphore(object):
class _JSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, datetime.datetime):
return o.isoformat()
@ -123,10 +120,10 @@ class _JSONEncoder(json.JSONEncoder):
if netaddr and isinstance(o, netaddr.IPAddress):
return str(o)
return super(_JSONEncoder, self).default(o)
return super().default(o)
class _StringFormatter(object):
class _StringFormatter:
"""A String formatter that fetches values on demand."""
def __init__(self, session, auth):
@ -142,8 +139,10 @@ class _StringFormatter(object):
raise AttributeError(item)
if not value:
raise ValueError("This type of authentication does not provide a "
"%s that can be substituted" % item)
raise ValueError(
"This type of authentication does not provide a "
f"{item} that can be substituted"
)
return value
@ -159,8 +158,11 @@ def _determine_calling_package():
# because sys.modules can change during iteration, which results
# in a RuntimeError
# https://docs.python.org/3/library/sys.html#sys.modules
mod_lookup = dict((m.__file__, n) for n, m in sys.modules.copy().items()
if hasattr(m, '__file__'))
mod_lookup = {
m.__file__: n
for n, m in sys.modules.copy().items()
if hasattr(m, '__file__')
}
# NOTE(shaleh): these are not useful because they hide the real
# user of the code. debtcollector did not import keystoneauth but
@ -205,7 +207,7 @@ def _determine_user_agent():
# NOTE(shaleh): mod_wsgi is not any more useful than just
# reporting "keystoneauth". Ignore it and perform the package name
# heuristic.
ignored = ('mod_wsgi', )
ignored = ('mod_wsgi',)
try:
name = sys.argv[0]
@ -222,7 +224,7 @@ def _determine_user_agent():
return name
class RequestTiming(object):
class RequestTiming:
"""Contains timing information for an HTTP interaction."""
#: HTTP method used for the call (GET, POST, etc)
@ -240,7 +242,7 @@ class RequestTiming(object):
self.elapsed = elapsed
class _Retries(object):
class _Retries:
__slots__ = ('_fixed_delay', '_current')
def __init__(self, fixed_delay=None):
@ -263,7 +265,7 @@ class _Retries(object):
next = __next__
class Session(object):
class Session:
"""Maintains client communication state and common functionality.
As much as possible the parameters to this class reflect and are passed
@ -341,14 +343,26 @@ class Session(object):
_DEFAULT_REDIRECT_LIMIT = 30
def __init__(self, auth=None, session=None, original_ip=None, verify=True,
cert=None, timeout=None, user_agent=None,
redirect=_DEFAULT_REDIRECT_LIMIT, additional_headers=None,
app_name=None, app_version=None, additional_user_agent=None,
discovery_cache=None, split_loggers=None,
collect_timing=False, rate_semaphore=None,
connect_retries=0):
def __init__(
self,
auth=None,
session=None,
original_ip=None,
verify=True,
cert=None,
timeout=None,
user_agent=None,
redirect=_DEFAULT_REDIRECT_LIMIT,
additional_headers=None,
app_name=None,
app_version=None,
additional_user_agent=None,
discovery_cache=None,
split_loggers=None,
collect_timing=False,
rate_semaphore=None,
connect_retries=0,
):
self.auth = auth
self.session = _construct_session(session)
# NOTE(mwhahaha): keep a reference to the session object so we can
@ -383,7 +397,7 @@ class Session(object):
self.timeout = float(timeout)
if user_agent is not None:
self.user_agent = "%s %s" % (user_agent, DEFAULT_USER_AGENT)
self.user_agent = f"{user_agent} {DEFAULT_USER_AGENT}"
self._json = _JSONEncoder()
@ -431,13 +445,17 @@ class Session(object):
@staticmethod
def _process_header(header):
"""Redact the secure headers to be logged."""
secure_headers = ('authorization', 'x-auth-token',
'x-subject-token', 'x-service-token')
secure_headers = (
'authorization',
'x-auth-token',
'x-subject-token',
'x-service-token',
)
if header[0].lower() in secure_headers:
token_hasher = hashlib.sha256()
token_hasher.update(header[1].encode('utf-8'))
token_hash = token_hasher.hexdigest()
return (header[0], '{SHA256}%s' % token_hash)
return (header[0], f'{{SHA256}}{token_hash}')
return header
def _get_split_loggers(self, split_loggers):
@ -458,9 +476,17 @@ class Session(object):
split_loggers = False
return split_loggers
def _http_log_request(self, url, method=None, data=None,
json=None, headers=None, query_params=None,
logger=None, split_loggers=None):
def _http_log_request(
self,
url,
method=None,
data=None,
json=None,
headers=None,
query_params=None,
logger=None,
split_loggers=None,
):
string_parts = []
if self._get_split_loggers(split_loggers):
@ -484,7 +510,7 @@ class Session(object):
if self.verify is False:
string_parts.append('--insecure')
elif isinstance(self.verify, str):
string_parts.append('--cacert "%s"' % self.verify)
string_parts.append(f'--cacert "{self.verify}"')
if method:
string_parts.extend(['-X', method])
@ -495,15 +521,16 @@ class Session(object):
url = url + '?' + urllib.parse.urlencode(query_params)
# URLs with query strings need to be wrapped in quotes in order
# for the CURL command to run properly.
string_parts.append('"%s"' % url)
string_parts.append(f'"{url}"')
else:
string_parts.append(url)
if headers:
# Sort headers so that testing can work consistently.
for header in sorted(headers.items()):
string_parts.append('-H "%s: %s"'
% self._process_header(header))
string_parts.append(
'-H "{}: {}"'.format(*self._process_header(header))
)
if json:
data = self._json.encode(json)
if data:
@ -512,13 +539,20 @@ class Session(object):
data = data.decode("ascii")
except UnicodeDecodeError:
data = "<binary_data>"
string_parts.append("-d '%s'" % data)
string_parts.append(f"-d '{data}'")
logger.debug(' '.join(string_parts))
def _http_log_response(self, response=None, json=None,
status_code=None, headers=None, text=None,
logger=None, split_loggers=True):
def _http_log_response(
self,
response=None,
json=None,
status_code=None,
headers=None,
text=None,
logger=None,
split_loggers=True,
):
string_parts = []
body_parts = []
if self._get_split_loggers(split_loggers):
@ -540,11 +574,13 @@ class Session(object):
headers = response.headers
if status_code:
string_parts.append('[%s]' % status_code)
string_parts.append(f'[{status_code}]')
if headers:
# Sort headers so that testing can work consistently.
for header in sorted(headers.items()):
string_parts.append('%s: %s' % self._process_header(header))
string_parts.append(
'{}: {}'.format(*self._process_header(header))
)
logger.debug(' '.join(string_parts))
if not body_logger.isEnabledFor(logging.DEBUG):
@ -565,12 +601,15 @@ class Session(object):
# [1] https://www.w3.org/Protocols/rfc1341/4_Content-Type.html
for log_type in _LOG_CONTENT_TYPES:
if content_type is not None and content_type.startswith(
log_type):
log_type
):
text = self._remove_service_catalog(response.text)
break
else:
text = ('Omitted, Content-Type is set to %s. Only '
'%s responses have their bodies logged.')
text = (
'Omitted, Content-Type is set to %s. Only '
'%s responses have their bodies logged.'
)
text = text % (content_type, ', '.join(_LOG_CONTENT_TYPES))
if json:
text = self._json.encode(json)
@ -581,7 +620,8 @@ class Session(object):
@staticmethod
def _set_microversion_headers(
headers, microversion, service_type, endpoint_filter):
headers, microversion, service_type, endpoint_filter
):
# We're converting it to normalized version number for two reasons.
# First, to validate it's a real version number. Second, so that in
# the future we can pre-validate that it is within the range of
@ -592,27 +632,32 @@ class Session(object):
# with the microversion range we found in discovery.
microversion = discover.normalize_version_number(microversion)
# Can't specify a M.latest microversion
if (microversion[0] != discover.LATEST and
discover.LATEST in microversion[1:]):
if (
microversion[0] != discover.LATEST
and discover.LATEST in microversion[1:]
):
raise TypeError(
"Specifying a '{major}.latest' microversion is not allowed.")
"Specifying a '{major}.latest' microversion is not allowed."
)
microversion = discover.version_to_string(microversion)
if not service_type:
if endpoint_filter and 'service_type' in endpoint_filter:
service_type = endpoint_filter['service_type']
else:
raise TypeError(
"microversion {microversion} was requested but no"
f"microversion {microversion} was requested but no"
" service_type information is available. Either provide a"
" service_type in endpoint_filter or pass"
" microversion_service_type as an argument.".format(
microversion=microversion))
" microversion_service_type as an argument."
)
# TODO(mordred) cinder uses volume in its microversion header. This
# logic should be handled in the future by os-service-types but for
# now hard-code for cinder.
if (service_type.startswith('volume') or
service_type == 'block-storage'):
if (
service_type.startswith('volume')
or service_type == 'block-storage'
):
service_type = 'volume'
elif service_type.startswith('share'):
# NOTE(gouthamr) manila doesn't honor the "OpenStack-API-Version"
@ -622,25 +667,44 @@ class Session(object):
# service catalog
service_type = 'shared-file-system'
headers.setdefault('OpenStack-API-Version',
'{service_type} {microversion}'.format(
service_type=service_type,
microversion=microversion))
headers.setdefault(
'OpenStack-API-Version', f'{service_type} {microversion}'
)
header_names = _mv_legacy_headers_for_service(service_type)
for h in header_names:
headers.setdefault(h, microversion)
def request(self, url, method, json=None, original_ip=None,
user_agent=None, redirect=None, authenticated=None,
endpoint_filter=None, auth=None, requests_auth=None,
raise_exc=True, allow_reauth=True, log=True,
endpoint_override=None, connect_retries=None, logger=None,
allow=None, client_name=None, client_version=None,
microversion=None, microversion_service_type=None,
status_code_retries=0, retriable_status_codes=None,
rate_semaphore=None, global_request_id=None,
connect_retry_delay=None, status_code_retry_delay=None,
**kwargs):
def request(
self,
url,
method,
json=None,
original_ip=None,
user_agent=None,
redirect=None,
authenticated=None,
endpoint_filter=None,
auth=None,
requests_auth=None,
raise_exc=True,
allow_reauth=True,
log=True,
endpoint_override=None,
connect_retries=None,
logger=None,
allow=None,
client_name=None,
client_version=None,
microversion=None,
microversion_service_type=None,
status_code_retries=0,
retriable_status_codes=None,
rate_semaphore=None,
global_request_id=None,
connect_retry_delay=None,
status_code_retry_delay=None,
**kwargs,
):
"""Send an HTTP request with the specified characteristics.
Wrapper around `requests.Session.request` to handle tasks such as
@ -766,21 +830,26 @@ class Session(object):
# case insensitive.
if kwargs.get('headers'):
kwargs['headers'] = requests.structures.CaseInsensitiveDict(
kwargs['headers'])
kwargs['headers']
)
else:
kwargs['headers'] = requests.structures.CaseInsensitiveDict()
if connect_retries is None:
connect_retries = self._connect_retries
# HTTP 503 - Service Unavailable
retriable_status_codes = retriable_status_codes or \
_RETRIABLE_STATUS_CODES
retriable_status_codes = (
retriable_status_codes or _RETRIABLE_STATUS_CODES
)
rate_semaphore = rate_semaphore or self._rate_semaphore
headers = kwargs.setdefault('headers', dict())
headers = kwargs.setdefault('headers', {})
if microversion:
self._set_microversion_headers(
headers, microversion, microversion_service_type,
endpoint_filter)
headers,
microversion,
microversion_service_type,
endpoint_filter,
)
if authenticated is None:
authenticated = bool(auth or self.auth)
@ -807,13 +876,14 @@ class Session(object):
if endpoint_override:
base_url = endpoint_override % _StringFormatter(self, auth)
elif endpoint_filter:
base_url = self.get_endpoint(auth, allow=allow,
**endpoint_filter)
base_url = self.get_endpoint(
auth, allow=allow, **endpoint_filter
)
if not base_url:
raise exceptions.EndpointNotFound()
url = '%s/%s' % (base_url.rstrip('/'), url.lstrip('/'))
url = '{}/{}'.format(base_url.rstrip('/'), url.lstrip('/'))
if self.cert:
kwargs.setdefault('cert', self.cert)
@ -835,17 +905,17 @@ class Session(object):
agent = []
if self.app_name and self.app_version:
agent.append('%s/%s' % (self.app_name, self.app_version))
agent.append(f'{self.app_name}/{self.app_version}')
elif self.app_name:
agent.append(self.app_name)
if client_name and client_version:
agent.append('%s/%s' % (client_name, client_version))
agent.append(f'{client_name}/{client_version}')
elif client_name:
agent.append(client_name)
for additional in self.additional_user_agent:
agent.append('%s/%s' % additional)
agent.append('{}/{}'.format(*additional))
if not agent:
# NOTE(jamielennox): determine_user_agent will return an empty
@ -861,8 +931,9 @@ class Session(object):
user_agent = headers.setdefault('User-Agent', ' '.join(agent))
if self.original_ip:
headers.setdefault('Forwarded',
'for=%s;by=%s' % (self.original_ip, user_agent))
headers.setdefault(
'Forwarded', f'for={self.original_ip};by={user_agent}'
)
if json is not None:
headers.setdefault('Content-Type', 'application/json')
@ -890,14 +961,18 @@ class Session(object):
# be logged properly, but those sent in the `params` parameter
# (which the requests library handles) need to be explicitly
# picked out so they can be included in the URL that gets loggged.
query_params = kwargs.get('params', dict())
query_params = kwargs.get('params', {})
if log:
self._http_log_request(url, method=method,
data=kwargs.get('data'),
headers=headers,
query_params=query_params,
logger=logger, split_loggers=split_loggers)
self._http_log_request(
url,
method=method,
data=kwargs.get('data'),
headers=headers,
query_params=query_params,
logger=logger,
split_loggers=split_loggers,
)
# Force disable requests redirect handling. We will manage this below.
kwargs['allow_redirects'] = False
@ -908,12 +983,21 @@ class Session(object):
connect_retry_delays = _Retries(connect_retry_delay)
status_code_retry_delays = _Retries(status_code_retry_delay)
send = functools.partial(self._send_request,
url, method, redirect, log, logger,
split_loggers, connect_retries,
status_code_retries, retriable_status_codes,
rate_semaphore, connect_retry_delays,
status_code_retry_delays)
send = functools.partial(
self._send_request,
url,
method,
redirect,
log,
logger,
split_loggers,
connect_retries,
status_code_retries,
retriable_status_codes,
rate_semaphore,
connect_retry_delays,
status_code_retry_delays,
)
try:
connection_params = self.get_auth_connection_params(auth=auth)
@ -942,8 +1026,9 @@ class Session(object):
# Nova uses 'x-compute-request-id' and other services like
# Glance, Cinder etc are using 'x-openstack-request-id' to store
# request-id in the header
request_id = (resp.headers.get('x-openstack-request-id') or
resp.headers.get('x-compute-request-id'))
request_id = resp.headers.get(
'x-openstack-request-id'
) or resp.headers.get('x-compute-request-id')
if request_id:
if self._get_split_loggers(split_loggers):
id_logger = utils.get_logger(__name__ + '.request-id')
@ -953,21 +1038,25 @@ class Session(object):
id_logger.debug(
'%(method)s call to %(service_name)s for '
'%(url)s used request id '
'%(response_request_id)s', {
'%(response_request_id)s',
{
'method': resp.request.method,
'service_name': service_name,
'url': resp.url,
'response_request_id': request_id
})
'response_request_id': request_id,
},
)
else:
id_logger.debug(
'%(method)s call to '
'%(url)s used request id '
'%(response_request_id)s', {
'%(response_request_id)s',
{
'method': resp.request.method,
'url': resp.url,
'response_request_id': request_id
})
'response_request_id': request_id,
},
)
# handle getting a 401 Unauthorized response by invalidating the plugin
# and then retrying the request. This is only tried once.
@ -980,30 +1069,46 @@ class Session(object):
resp = send(**kwargs)
if raise_exc and resp.status_code >= 400:
logger.debug('Request returned failure status: %s',
resp.status_code)
logger.debug(
'Request returned failure status: %s', resp.status_code
)
raise exceptions.from_response(resp, method, url)
if self._collect_timing:
for h in resp.history:
self._api_times.append(RequestTiming(
method=h.request.method,
url=h.request.url,
elapsed=h.elapsed,
))
self._api_times.append(RequestTiming(
method=resp.request.method,
url=resp.request.url,
elapsed=resp.elapsed,
))
self._api_times.append(
RequestTiming(
method=h.request.method,
url=h.request.url,
elapsed=h.elapsed,
)
)
self._api_times.append(
RequestTiming(
method=resp.request.method,
url=resp.request.url,
elapsed=resp.elapsed,
)
)
return resp
def _send_request(self, url, method, redirect, log, logger, split_loggers,
connect_retries, status_code_retries,
retriable_status_codes, rate_semaphore,
connect_retry_delays, status_code_retry_delays,
**kwargs):
def _send_request(
self,
url,
method,
redirect,
log,
logger,
split_loggers,
connect_retries,
status_code_retries,
retriable_status_codes,
rate_semaphore,
connect_retry_delays,
status_code_retry_delays,
**kwargs,
):
# NOTE(jamielennox): We handle redirection manually because the
# requests lib follows some browser patterns where it will redirect
# POSTs as GETs for certain statuses which is not want we want for an
@ -1020,11 +1125,10 @@ class Session(object):
with rate_semaphore:
resp = self.session.request(method, url, **kwargs)
except requests.exceptions.SSLError as e:
msg = 'SSL exception connecting to %(url)s: %(error)s' % {
'url': url, 'error': e}
msg = f'SSL exception connecting to {url}: {e}'
raise exceptions.SSLError(msg)
except requests.exceptions.Timeout:
msg = 'Request to %s timed out' % url
msg = f'Request to {url} timed out'
raise exceptions.ConnectTimeout(msg)
except requests.exceptions.ConnectionError as e:
# NOTE(sdague): urllib3/requests connection error is a
@ -1033,11 +1137,10 @@ class Session(object):
# level message is often really important in figuring
# out the difference between network misconfigurations
# and firewall blocking.
msg = 'Unable to establish connection to %s: %s' % (url, e)
msg = f'Unable to establish connection to {url}: {e}'
raise exceptions.ConnectFailure(msg)
except requests.exceptions.RequestException as e:
msg = 'Unexpected exception for %(url)s: %(error)s' % {
'url': url, 'error': e}
msg = f'Unexpected exception for {url}: {e}'
raise exceptions.UnknownConnectionError(msg, e)
except exceptions.RetriableConnectionFailure as e:
@ -1045,26 +1148,33 @@ class Session(object):
raise
delay = next(connect_retry_delays)
logger.warning('Failure: %(e)s. Retrying in %(delay).1fs.'
'%(retries)s retries left',
{'e': e, 'delay': delay,
'retries': connect_retries})
logger.warning(
'Failure: %(e)s. Retrying in %(delay).1fs.'
'%(retries)s retries left',
{'e': e, 'delay': delay, 'retries': connect_retries},
)
time.sleep(delay)
return self._send_request(
url, method, redirect, log, logger, split_loggers,
url,
method,
redirect,
log,
logger,
split_loggers,
status_code_retries=status_code_retries,
retriable_status_codes=retriable_status_codes,
rate_semaphore=rate_semaphore,
connect_retries=connect_retries - 1,
connect_retry_delays=connect_retry_delays,
status_code_retry_delays=status_code_retry_delays,
**kwargs)
**kwargs,
)
if log:
self._http_log_response(
response=resp, logger=logger,
split_loggers=split_loggers)
response=resp, logger=logger, split_loggers=split_loggers
)
if resp.status_code in self._REDIRECT_STATUSES:
# be careful here in python True == 1 and False == 0
@ -1080,8 +1190,11 @@ class Session(object):
try:
location = resp.headers['location']
except KeyError:
logger.warning("Failed to redirect request to %s as new "
"location was not provided.", resp.url)
logger.warning(
"Failed to redirect request to %s as new "
"location was not provided.",
resp.url,
)
else:
# NOTE(TheJulia): Location redirects generally should have
# URI's to the destination.
@ -1090,50 +1203,69 @@ class Session(object):
kwargs['params'] = {}
if 'x-openstack-request-id' in resp.headers:
kwargs['headers'].setdefault('x-openstack-request-id',
resp.headers[
'x-openstack-request-id'])
kwargs['headers'].setdefault(
'x-openstack-request-id',
resp.headers['x-openstack-request-id'],
)
# NOTE(jamielennox): We don't keep increasing delays.
# This request actually worked so we can reset the delay count.
connect_retry_delays.reset()
status_code_retry_delays.reset()
new_resp = self._send_request(
location, method, redirect, log, logger, split_loggers,
location,
method,
redirect,
log,
logger,
split_loggers,
rate_semaphore=rate_semaphore,
connect_retries=connect_retries,
status_code_retries=status_code_retries,
retriable_status_codes=retriable_status_codes,
connect_retry_delays=connect_retry_delays,
status_code_retry_delays=status_code_retry_delays,
**kwargs)
**kwargs,
)
if not isinstance(new_resp.history, list):
new_resp.history = list(new_resp.history)
new_resp.history.insert(0, resp)
resp = new_resp
elif (resp.status_code in retriable_status_codes and
status_code_retries > 0):
elif (
resp.status_code in retriable_status_codes
and status_code_retries > 0
):
delay = next(status_code_retry_delays)
logger.warning('Retriable status code %(code)s. Retrying in '
'%(delay).1fs. %(retries)s retries left',
{'code': resp.status_code, 'delay': delay,
'retries': status_code_retries})
logger.warning(
'Retriable status code %(code)s. Retrying in '
'%(delay).1fs. %(retries)s retries left',
{
'code': resp.status_code,
'delay': delay,
'retries': status_code_retries,
},
)
time.sleep(delay)
# NOTE(jamielennox): We don't keep increasing connection delays.
# This request actually worked so we can reset the delay count.
connect_retry_delays.reset()
return self._send_request(
url, method, redirect, log, logger, split_loggers,
url,
method,
redirect,
log,
logger,
split_loggers,
connect_retries=connect_retries,
status_code_retries=status_code_retries - 1,
retriable_status_codes=retriable_status_codes,
rate_semaphore=rate_semaphore,
connect_retry_delays=connect_retry_delays,
status_code_retry_delays=status_code_retry_delays,
**kwargs)
**kwargs,
)
return resp
@ -1288,9 +1420,14 @@ class Session(object):
auth = self._auth_required(auth, 'determine endpoint URL')
return auth.get_api_major_version(self, **kwargs)
def get_all_version_data(self, auth=None, interface='public',
region_name=None, service_type=None,
**kwargs):
def get_all_version_data(
self,
auth=None,
interface='public',
region_name=None,
service_type=None,
**kwargs,
):
"""Get version data for all services in the catalog.
:param auth:
@ -1318,7 +1455,8 @@ class Session(object):
interface=interface,
region_name=region_name,
service_type=service_type,
**kwargs)
**kwargs,
)
def get_auth_connection_params(self, auth=None, **kwargs):
"""Return auth connection params as provided by the auth plugin.
@ -1461,17 +1599,19 @@ class TCPKeepAliveAdapter(requests.adapters.HTTPAdapter):
]
# Windows subsystem for Linux does not support this feature
if (hasattr(socket, 'TCP_KEEPCNT') and
not utils.is_windows_linux_subsystem):
if (
hasattr(socket, 'TCP_KEEPCNT')
and not utils.is_windows_linux_subsystem
):
socket_options += [
# Set the maximum number of keep-alive probes
(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4),
(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
]
if hasattr(socket, 'TCP_KEEPINTVL'):
socket_options += [
# Send keep-alive probes every 15 seconds
(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15),
(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
]
# After waiting 60 seconds, and then sending a probe once every 15
@ -1479,4 +1619,4 @@ class TCPKeepAliveAdapter(requests.adapters.HTTPAdapter):
# hands for no longer than 2 minutes before a ConnectionError is
# raised.
kwargs['socket_options'] = socket_options
super(TCPKeepAliveAdapter, self).init_poolmanager(*args, **kwargs)
super().init_poolmanager(*args, **kwargs)

View File

@ -21,7 +21,6 @@ from keystoneauth1.tests.unit import utils
class AccessV2Test(utils.TestCase):
def test_building_unscoped_accessinfo(self):
token = fixture.V2Token(expires='2012-10-03T16:58:01Z')
@ -115,12 +114,9 @@ class AccessV2Test(utils.TestCase):
'user': {
'id': 'user_id1',
'name': 'user_name1',
'roles': [
{'name': 'role1'},
{'name': 'role2'},
],
'roles': [{'name': 'role1'}, {'name': 'role2'}],
},
},
}
}
auth_ref = access.create(body=diablo_token)
@ -148,12 +144,9 @@ class AccessV2Test(utils.TestCase):
'name': 'user_name1',
'tenantId': 'tenant_id1',
'tenantName': 'tenant_name1',
'roles': [
{'name': 'role1'},
{'name': 'role2'},
],
'roles': [{'name': 'role1'}, {'name': 'role2'}],
},
},
}
}
auth_ref = access.create(body=grizzly_token)
@ -179,11 +172,13 @@ class AccessV2Test(utils.TestCase):
self.assertIsInstance(auth_ref, access.AccessInfoV2)
self.assertEqual([role_id], auth_ref.role_ids)
self.assertEqual([role_id],
auth_ref._data['access']['metadata']['roles'])
self.assertEqual(
[role_id], auth_ref._data['access']['metadata']['roles']
)
self.assertEqual([role_name], auth_ref.role_names)
self.assertEqual([{'name': role_name}],
auth_ref._data['access']['user']['roles'])
self.assertEqual(
[{'name': role_name}], auth_ref._data['access']['user']['roles']
)
def test_trusts(self):
user_id = uuid.uuid4().hex

View File

@ -20,7 +20,7 @@ from keystoneauth1.tests.unit import utils
class ServiceCatalogTest(utils.TestCase):
def setUp(self):
super(ServiceCatalogTest, self).setUp()
super().setUp()
self.AUTH_RESPONSE_BODY = fixture.V2Token(
token_id='ab48a9efdfedb23ty3494',
@ -29,18 +29,21 @@ class ServiceCatalogTest(utils.TestCase):
tenant_name='My Project',
user_id='123',
user_name='jqsmith',
audit_chain_id=uuid.uuid4().hex)
audit_chain_id=uuid.uuid4().hex,
)
self.AUTH_RESPONSE_BODY.add_role(id='234', name='compute:admin')
role = self.AUTH_RESPONSE_BODY.add_role(id='235',
name='object-store:admin')
role = self.AUTH_RESPONSE_BODY.add_role(
id='235', name='object-store:admin'
)
role['tenantId'] = '1'
s = self.AUTH_RESPONSE_BODY.add_service('compute', 'Cloud Servers')
endpoint = s.add_endpoint(
public='https://compute.north.host/v1/1234',
internal='https://compute.north.host/v1/1234',
region='North')
region='North',
)
endpoint['tenantId'] = '1'
endpoint['versionId'] = '1.0'
endpoint['versionInfo'] = 'https://compute.north.host/v1.0/'
@ -49,16 +52,19 @@ class ServiceCatalogTest(utils.TestCase):
endpoint = s.add_endpoint(
public='https://compute.north.host/v1.1/3456',
internal='https://compute.north.host/v1.1/3456',
region='North')
region='North',
)
endpoint['tenantId'] = '2'
endpoint['versionId'] = '1.1'
endpoint['versionInfo'] = 'https://compute.north.host/v1.1/'
endpoint['versionList'] = 'https://compute.north.host/'
s = self.AUTH_RESPONSE_BODY.add_service('object-store', 'Cloud Files')
endpoint = s.add_endpoint(public='https://swift.north.host/v1/blah',
internal='https://swift.north.host/v1/blah',
region='South')
endpoint = s.add_endpoint(
public='https://swift.north.host/v1/blah',
internal='https://swift.north.host/v1/blah',
region='South',
)
endpoint['tenantId'] = '11'
endpoint['versionId'] = '1.0'
endpoint['versionInfo'] = 'uri'
@ -67,48 +73,62 @@ class ServiceCatalogTest(utils.TestCase):
endpoint = s.add_endpoint(
public='https://swift.north.host/v1.1/blah',
internal='https://compute.north.host/v1.1/blah',
region='South')
region='South',
)
endpoint['tenantId'] = '2'
endpoint['versionId'] = '1.1'
endpoint['versionInfo'] = 'https://swift.north.host/v1.1/'
endpoint['versionList'] = 'https://swift.north.host/'
s = self.AUTH_RESPONSE_BODY.add_service('image', 'Image Servers')
s.add_endpoint(public='https://image.north.host/v1/',
internal='https://image-internal.north.host/v1/',
region='North')
s.add_endpoint(public='https://image.south.host/v1/',
internal='https://image-internal.south.host/v1/',
region='South')
s.add_endpoint(
public='https://image.north.host/v1/',
internal='https://image-internal.north.host/v1/',
region='North',
)
s.add_endpoint(
public='https://image.south.host/v1/',
internal='https://image-internal.south.host/v1/',
region='South',
)
def test_building_a_service_catalog(self):
auth_ref = access.create(body=self.AUTH_RESPONSE_BODY)
sc = auth_ref.service_catalog
self.assertEqual(sc.url_for(service_type='compute'),
"https://compute.north.host/v1/1234")
self.assertRaises(exceptions.EndpointNotFound,
sc.url_for,
region_name="South",
service_type='compute')
self.assertEqual(
sc.url_for(service_type='compute'),
"https://compute.north.host/v1/1234",
)
self.assertRaises(
exceptions.EndpointNotFound,
sc.url_for,
region_name="South",
service_type='compute',
)
def test_service_catalog_endpoints(self):
auth_ref = access.create(body=self.AUTH_RESPONSE_BODY)
sc = auth_ref.service_catalog
public_ep = sc.get_endpoints(service_type='compute',
interface='publicURL')
public_ep = sc.get_endpoints(
service_type='compute', interface='publicURL'
)
self.assertEqual(public_ep['compute'][1]['tenantId'], '2')
self.assertEqual(public_ep['compute'][1]['versionId'], '1.1')
self.assertEqual(public_ep['compute'][1]['internalURL'],
"https://compute.north.host/v1.1/3456")
self.assertEqual(
public_ep['compute'][1]['internalURL'],
"https://compute.north.host/v1.1/3456",
)
def test_service_catalog_empty(self):
self.AUTH_RESPONSE_BODY['access']['serviceCatalog'] = []
auth_ref = access.create(body=self.AUTH_RESPONSE_BODY)
self.assertRaises(exceptions.EmptyCatalog,
auth_ref.service_catalog.url_for,
service_type='image',
interface='internalURL')
self.assertRaises(
exceptions.EmptyCatalog,
auth_ref.service_catalog.url_for,
service_type='image',
interface='internalURL',
)
def test_service_catalog_get_endpoints_region_names(self):
auth_ref = access.create(body=self.AUTH_RESPONSE_BODY)
@ -116,23 +136,27 @@ class ServiceCatalogTest(utils.TestCase):
endpoints = sc.get_endpoints(service_type='image', region_name='North')
self.assertEqual(len(endpoints), 1)
self.assertEqual(endpoints['image'][0]['publicURL'],
'https://image.north.host/v1/')
self.assertEqual(
endpoints['image'][0]['publicURL'], 'https://image.north.host/v1/'
)
endpoints = sc.get_endpoints(service_type='image', region_name='South')
self.assertEqual(len(endpoints), 1)
self.assertEqual(endpoints['image'][0]['publicURL'],
'https://image.south.host/v1/')
self.assertEqual(
endpoints['image'][0]['publicURL'], 'https://image.south.host/v1/'
)
endpoints = sc.get_endpoints(service_type='compute')
self.assertEqual(len(endpoints['compute']), 2)
endpoints = sc.get_endpoints(service_type='compute',
region_name='North')
endpoints = sc.get_endpoints(
service_type='compute', region_name='North'
)
self.assertEqual(len(endpoints['compute']), 2)
endpoints = sc.get_endpoints(service_type='compute',
region_name='West')
endpoints = sc.get_endpoints(
service_type='compute', region_name='West'
)
self.assertEqual(len(endpoints['compute']), 0)
def test_service_catalog_url_for_region_names(self):
@ -145,8 +169,12 @@ class ServiceCatalogTest(utils.TestCase):
url = sc.url_for(service_type='image', region_name='South')
self.assertEqual(url, 'https://image.south.host/v1/')
self.assertRaises(exceptions.EndpointNotFound, sc.url_for,
service_type='image', region_name='West')
self.assertRaises(
exceptions.EndpointNotFound,
sc.url_for,
service_type='image',
region_name='West',
)
def test_servcie_catalog_get_url_region_names(self):
auth_ref = access.create(body=self.AUTH_RESPONSE_BODY)
@ -170,21 +198,33 @@ class ServiceCatalogTest(utils.TestCase):
auth_ref = access.create(body=self.AUTH_RESPONSE_BODY)
sc = auth_ref.service_catalog
url = sc.url_for(service_name='Image Servers', interface='public',
service_type='image', region_name='North')
url = sc.url_for(
service_name='Image Servers',
interface='public',
service_type='image',
region_name='North',
)
self.assertEqual('https://image.north.host/v1/', url)
self.assertRaises(exceptions.EndpointNotFound, sc.url_for,
service_name='Image Servers', service_type='compute')
self.assertRaises(
exceptions.EndpointNotFound,
sc.url_for,
service_name='Image Servers',
service_type='compute',
)
urls = sc.get_urls(service_type='image', service_name='Image Servers',
interface='public')
urls = sc.get_urls(
service_type='image',
service_name='Image Servers',
interface='public',
)
self.assertIn('https://image.north.host/v1/', urls)
self.assertIn('https://image.south.host/v1/', urls)
urls = sc.get_urls(service_type='image', service_name='Servers',
interface='public')
urls = sc.get_urls(
service_type='image', service_name='Servers', interface='public'
)
self.assertEqual(0, len(urls))
@ -194,23 +234,28 @@ class ServiceCatalogTest(utils.TestCase):
for i in range(3):
s = token.add_service('compute')
s.add_endpoint(public='public-%d' % i,
admin='admin-%d' % i,
internal='internal-%d' % i,
region='region-%d' % i)
s.add_endpoint(
public='public-%d' % i,
admin='admin-%d' % i,
internal='internal-%d' % i,
region='region-%d' % i,
)
auth_ref = access.create(body=token)
urls = auth_ref.service_catalog.get_urls(service_type='compute',
interface='publicURL')
urls = auth_ref.service_catalog.get_urls(
service_type='compute', interface='publicURL'
)
self.assertEqual(set(['public-0', 'public-1', 'public-2']), set(urls))
self.assertEqual({'public-0', 'public-1', 'public-2'}, set(urls))
urls = auth_ref.service_catalog.get_urls(service_type='compute',
interface='publicURL',
region_name='region-1')
urls = auth_ref.service_catalog.get_urls(
service_type='compute',
interface='publicURL',
region_name='region-1',
)
self.assertEqual(('public-1', ), urls)
self.assertEqual(('public-1',), urls)
def test_service_catalog_endpoint_id(self):
token = fixture.V2Token()
@ -228,23 +273,27 @@ class ServiceCatalogTest(utils.TestCase):
urls = auth_ref.service_catalog.get_urls(interface='public')
self.assertEqual(2, len(urls))
urls = auth_ref.service_catalog.get_urls(endpoint_id=endpoint_id,
interface='public')
urls = auth_ref.service_catalog.get_urls(
endpoint_id=endpoint_id, interface='public'
)
self.assertEqual((public_url, ), urls)
self.assertEqual((public_url,), urls)
# with bad endpoint_id nothing should be found
urls = auth_ref.service_catalog.get_urls(endpoint_id=uuid.uuid4().hex,
interface='public')
urls = auth_ref.service_catalog.get_urls(
endpoint_id=uuid.uuid4().hex, interface='public'
)
self.assertEqual(0, len(urls))
# we ignore a service_id because v2 doesn't know what it is
urls = auth_ref.service_catalog.get_urls(endpoint_id=endpoint_id,
service_id=uuid.uuid4().hex,
interface='public')
urls = auth_ref.service_catalog.get_urls(
endpoint_id=endpoint_id,
service_id=uuid.uuid4().hex,
interface='public',
)
self.assertEqual((public_url, ), urls)
self.assertEqual((public_url,), urls)
def test_service_catalog_without_service_type(self):
token = fixture.V2Token()
@ -260,8 +309,9 @@ class ServiceCatalogTest(utils.TestCase):
s.add_endpoint(public=public_url)
auth_ref = access.create(body=token)
urls = auth_ref.service_catalog.get_urls(service_type=None,
interface='public')
urls = auth_ref.service_catalog.get_urls(
service_type=None, interface='public'
)
self.assertEqual(3, len(urls))

View File

@ -21,7 +21,6 @@ from keystoneauth1.tests.unit import utils
class AccessV3Test(utils.TestCase):
def test_building_unscoped_accessinfo(self):
token = fixture.V3Token()
token_id = uuid.uuid4().hex
@ -52,10 +51,14 @@ class AccessV3Test(utils.TestCase):
self.assertIsNone(auth_ref.project_domain_id)
self.assertIsNone(auth_ref.project_domain_name)
self.assertEqual(auth_ref.expires, timeutils.parse_isotime(
token['token']['expires_at']))
self.assertEqual(auth_ref.issued, timeutils.parse_isotime(
token['token']['issued_at']))
self.assertEqual(
auth_ref.expires,
timeutils.parse_isotime(token['token']['expires_at']),
)
self.assertEqual(
auth_ref.issued,
timeutils.parse_isotime(token['token']['issued_at']),
)
self.assertEqual(auth_ref.expires, token.expires)
self.assertEqual(auth_ref.issued, token.issued)
@ -197,8 +200,9 @@ class AccessV3Test(utils.TestCase):
self.assertEqual(auth_ref.tenant_id, auth_ref.project_id)
self.assertEqual(token.project_domain_id, auth_ref.project_domain_id)
self.assertEqual(token.project_domain_name,
auth_ref.project_domain_name)
self.assertEqual(
token.project_domain_name, auth_ref.project_domain_name
)
self.assertEqual(token.user_domain_id, auth_ref.user_domain_id)
self.assertEqual(token.user_domain_name, auth_ref.user_domain_name)
@ -243,8 +247,9 @@ class AccessV3Test(utils.TestCase):
self.assertEqual(auth_ref.tenant_id, auth_ref.project_id)
self.assertEqual(token.project_domain_id, auth_ref.project_domain_id)
self.assertEqual(token.project_domain_name,
auth_ref.project_domain_name)
self.assertEqual(
token.project_domain_name, auth_ref.project_domain_name
)
self.assertEqual(token.user_domain_id, auth_ref.user_domain_id)
self.assertEqual(token.user_domain_name, auth_ref.user_domain_name)
@ -262,19 +267,22 @@ class AccessV3Test(utils.TestCase):
token = fixture.V3Token()
token.set_project_scope()
token.set_oauth(access_token_id=access_token_id,
consumer_id=consumer_id)
token.set_oauth(
access_token_id=access_token_id, consumer_id=consumer_id
)
auth_ref = access.create(body=token)
self.assertEqual(consumer_id, auth_ref.oauth_consumer_id)
self.assertEqual(access_token_id, auth_ref.oauth_access_token_id)
self.assertEqual(consumer_id,
auth_ref._data['token']['OS-OAUTH1']['consumer_id'])
self.assertEqual(
consumer_id, auth_ref._data['token']['OS-OAUTH1']['consumer_id']
)
self.assertEqual(
access_token_id,
auth_ref._data['token']['OS-OAUTH1']['access_token_id'])
auth_ref._data['token']['OS-OAUTH1']['access_token_id'],
)
def test_federated_property_standard_token(self):
"""Check if is_federated property returns expected value."""

View File

@ -20,10 +20,11 @@ from keystoneauth1.tests.unit import utils
class ServiceCatalogTest(utils.TestCase):
def setUp(self):
super(ServiceCatalogTest, self).setUp()
super().setUp()
self.AUTH_RESPONSE_BODY = fixture.V3Token(
audit_chain_id=uuid.uuid4().hex)
audit_chain_id=uuid.uuid4().hex
)
self.AUTH_RESPONSE_BODY.set_project_scope()
self.AUTH_RESPONSE_BODY.add_role(name='admin')
@ -34,207 +35,251 @@ class ServiceCatalogTest(utils.TestCase):
public='https://compute.north.host/novapi/public',
internal='https://compute.north.host/novapi/internal',
admin='https://compute.north.host/novapi/admin',
region='North')
region='North',
)
s = self.AUTH_RESPONSE_BODY.add_service('object-store', name='swift')
s.add_standard_endpoints(
public='http://swift.north.host/swiftapi/public',
internal='http://swift.north.host/swiftapi/internal',
admin='http://swift.north.host/swiftapi/admin',
region='South')
region='South',
)
s = self.AUTH_RESPONSE_BODY.add_service('image', name='glance')
s.add_standard_endpoints(
public='http://glance.north.host/glanceapi/public',
internal='http://glance.north.host/glanceapi/internal',
admin='http://glance.north.host/glanceapi/admin',
region='North')
region='North',
)
s.add_standard_endpoints(
public='http://glance.south.host/glanceapi/public',
internal='http://glance.south.host/glanceapi/internal',
admin='http://glance.south.host/glanceapi/admin',
region='South')
region='South',
)
s = self.AUTH_RESPONSE_BODY.add_service('block-storage', name='cinder')
s.add_standard_endpoints(
public='http://cinder.north.host/cinderapi/public',
internal='http://cinder.north.host/cinderapi/internal',
admin='http://cinder.north.host/cinderapi/admin',
region='North')
region='North',
)
s = self.AUTH_RESPONSE_BODY.add_service('volumev2', name='cinder')
s.add_standard_endpoints(
public='http://cinder.south.host/cinderapi/public/v2',
internal='http://cinder.south.host/cinderapi/internal/v2',
admin='http://cinder.south.host/cinderapi/admin/v2',
region='South')
region='South',
)
s = self.AUTH_RESPONSE_BODY.add_service('volumev3', name='cinder')
s.add_standard_endpoints(
public='http://cinder.south.host/cinderapi/public/v3',
internal='http://cinder.south.host/cinderapi/internal/v3',
admin='http://cinder.south.host/cinderapi/admin/v3',
region='South')
region='South',
)
self.north_endpoints = {'public':
'http://glance.north.host/glanceapi/public',
'internal':
'http://glance.north.host/glanceapi/internal',
'admin':
'http://glance.north.host/glanceapi/admin'}
self.north_endpoints = {
'public': 'http://glance.north.host/glanceapi/public',
'internal': 'http://glance.north.host/glanceapi/internal',
'admin': 'http://glance.north.host/glanceapi/admin',
}
self.south_endpoints = {'public':
'http://glance.south.host/glanceapi/public',
'internal':
'http://glance.south.host/glanceapi/internal',
'admin':
'http://glance.south.host/glanceapi/admin'}
self.south_endpoints = {
'public': 'http://glance.south.host/glanceapi/public',
'internal': 'http://glance.south.host/glanceapi/internal',
'admin': 'http://glance.south.host/glanceapi/admin',
}
def test_building_a_service_catalog(self):
auth_ref = access.create(auth_token=uuid.uuid4().hex,
body=self.AUTH_RESPONSE_BODY)
auth_ref = access.create(
auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
)
sc = auth_ref.service_catalog
self.assertEqual(sc.url_for(service_type='compute'),
"https://compute.north.host/novapi/public")
self.assertEqual(sc.url_for(service_type='compute',
interface='internal'),
"https://compute.north.host/novapi/internal")
self.assertEqual(
sc.url_for(service_type='compute'),
"https://compute.north.host/novapi/public",
)
self.assertEqual(
sc.url_for(service_type='compute', interface='internal'),
"https://compute.north.host/novapi/internal",
)
self.assertRaises(exceptions.EndpointNotFound,
sc.url_for,
region_name='South',
service_type='compute')
self.assertRaises(
exceptions.EndpointNotFound,
sc.url_for,
region_name='South',
service_type='compute',
)
def test_service_catalog_endpoints(self):
auth_ref = access.create(auth_token=uuid.uuid4().hex,
body=self.AUTH_RESPONSE_BODY)
auth_ref = access.create(
auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
)
sc = auth_ref.service_catalog
public_ep = sc.get_endpoints(service_type='compute',
interface='public')
public_ep = sc.get_endpoints(
service_type='compute', interface='public'
)
self.assertEqual(public_ep['compute'][0]['region'], 'North')
self.assertEqual(public_ep['compute'][0]['url'],
"https://compute.north.host/novapi/public")
self.assertEqual(
public_ep['compute'][0]['url'],
"https://compute.north.host/novapi/public",
)
def test_service_catalog_alias_find_official(self):
auth_ref = access.create(auth_token=uuid.uuid4().hex,
body=self.AUTH_RESPONSE_BODY)
auth_ref = access.create(
auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
)
sc = auth_ref.service_catalog
# Tests that we find the block-storage endpoint when we request
# the volume endpoint.
public_ep = sc.get_endpoints(service_type='volume',
interface='public',
region_name='North')
public_ep = sc.get_endpoints(
service_type='volume', interface='public', region_name='North'
)
self.assertEqual(public_ep['block-storage'][0]['region'], 'North')
self.assertEqual(public_ep['block-storage'][0]['url'],
"http://cinder.north.host/cinderapi/public")
self.assertEqual(
public_ep['block-storage'][0]['url'],
"http://cinder.north.host/cinderapi/public",
)
def test_service_catalog_alias_find_exact_match(self):
auth_ref = access.create(auth_token=uuid.uuid4().hex,
body=self.AUTH_RESPONSE_BODY)
auth_ref = access.create(
auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
)
sc = auth_ref.service_catalog
# Tests that we find the volumev3 endpoint when we request it.
public_ep = sc.get_endpoints(service_type='volumev3',
interface='public')
public_ep = sc.get_endpoints(
service_type='volumev3', interface='public'
)
self.assertEqual(public_ep['volumev3'][0]['region'], 'South')
self.assertEqual(public_ep['volumev3'][0]['url'],
"http://cinder.south.host/cinderapi/public/v3")
self.assertEqual(
public_ep['volumev3'][0]['url'],
"http://cinder.south.host/cinderapi/public/v3",
)
def test_service_catalog_alias_find_best_match(self):
auth_ref = access.create(auth_token=uuid.uuid4().hex,
body=self.AUTH_RESPONSE_BODY)
auth_ref = access.create(
auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
)
sc = auth_ref.service_catalog
# Tests that we find the volumev3 endpoint when we request
# block-storage when only volumev2 and volumev3 are present since
# volumev3 comes first in the list.
public_ep = sc.get_endpoints(service_type='block-storage',
interface='public',
region_name='South')
public_ep = sc.get_endpoints(
service_type='block-storage',
interface='public',
region_name='South',
)
self.assertEqual(public_ep['volumev3'][0]['region'], 'South')
self.assertEqual(public_ep['volumev3'][0]['url'],
"http://cinder.south.host/cinderapi/public/v3")
self.assertEqual(
public_ep['volumev3'][0]['url'],
"http://cinder.south.host/cinderapi/public/v3",
)
def test_service_catalog_alias_all_by_name(self):
auth_ref = access.create(auth_token=uuid.uuid4().hex,
body=self.AUTH_RESPONSE_BODY)
auth_ref = access.create(
auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
)
sc = auth_ref.service_catalog
# Tests that we find all the cinder endpoints since we request
# them by name and that no filtering related to aliases happens.
public_ep = sc.get_endpoints(service_name='cinder',
interface='public')
public_ep = sc.get_endpoints(service_name='cinder', interface='public')
self.assertEqual(public_ep['volumev2'][0]['region'], 'South')
self.assertEqual(public_ep['volumev2'][0]['url'],
"http://cinder.south.host/cinderapi/public/v2")
self.assertEqual(
public_ep['volumev2'][0]['url'],
"http://cinder.south.host/cinderapi/public/v2",
)
self.assertEqual(public_ep['volumev3'][0]['region'], 'South')
self.assertEqual(public_ep['volumev3'][0]['url'],
"http://cinder.south.host/cinderapi/public/v3")
self.assertEqual(
public_ep['volumev3'][0]['url'],
"http://cinder.south.host/cinderapi/public/v3",
)
self.assertEqual(public_ep['block-storage'][0]['region'], 'North')
self.assertEqual(public_ep['block-storage'][0]['url'],
"http://cinder.north.host/cinderapi/public")
self.assertEqual(
public_ep['block-storage'][0]['url'],
"http://cinder.north.host/cinderapi/public",
)
def test_service_catalog_regions(self):
self.AUTH_RESPONSE_BODY['token']['region_name'] = "North"
auth_ref = access.create(auth_token=uuid.uuid4().hex,
body=self.AUTH_RESPONSE_BODY)
auth_ref = access.create(
auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
)
sc = auth_ref.service_catalog
url = sc.url_for(service_type='image', interface='public')
self.assertEqual(url, "http://glance.north.host/glanceapi/public")
self.AUTH_RESPONSE_BODY['token']['region_name'] = "South"
auth_ref = access.create(auth_token=uuid.uuid4().hex,
body=self.AUTH_RESPONSE_BODY)
auth_ref = access.create(
auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
)
sc = auth_ref.service_catalog
url = sc.url_for(service_type='image',
region_name="South",
interface='internal')
url = sc.url_for(
service_type='image', region_name="South", interface='internal'
)
self.assertEqual(url, "http://glance.south.host/glanceapi/internal")
def test_service_catalog_empty(self):
self.AUTH_RESPONSE_BODY['token']['catalog'] = []
auth_ref = access.create(auth_token=uuid.uuid4().hex,
body=self.AUTH_RESPONSE_BODY)
self.assertRaises(exceptions.EmptyCatalog,
auth_ref.service_catalog.url_for,
service_type='image',
interface='internalURL')
auth_ref = access.create(
auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
)
self.assertRaises(
exceptions.EmptyCatalog,
auth_ref.service_catalog.url_for,
service_type='image',
interface='internalURL',
)
def test_service_catalog_get_endpoints_region_names(self):
sc = access.create(auth_token=uuid.uuid4().hex,
body=self.AUTH_RESPONSE_BODY).service_catalog
sc = access.create(
auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
).service_catalog
endpoints = sc.get_endpoints(service_type='image', region_name='North')
self.assertEqual(len(endpoints), 1)
for endpoint in endpoints['image']:
self.assertEqual(endpoint['url'],
self.north_endpoints[endpoint['interface']])
self.assertEqual(
endpoint['url'], self.north_endpoints[endpoint['interface']]
)
endpoints = sc.get_endpoints(service_type='image', region_name='South')
self.assertEqual(len(endpoints), 1)
for endpoint in endpoints['image']:
self.assertEqual(endpoint['url'],
self.south_endpoints[endpoint['interface']])
self.assertEqual(
endpoint['url'], self.south_endpoints[endpoint['interface']]
)
endpoints = sc.get_endpoints(service_type='compute')
self.assertEqual(len(endpoints['compute']), 3)
endpoints = sc.get_endpoints(service_type='compute',
region_name='North')
endpoints = sc.get_endpoints(
service_type='compute', region_name='North'
)
self.assertEqual(len(endpoints['compute']), 3)
endpoints = sc.get_endpoints(service_type='compute',
region_name='West')
endpoints = sc.get_endpoints(
service_type='compute', region_name='West'
)
self.assertEqual(len(endpoints['compute']), 0)
def test_service_catalog_url_for_region_names(self):
sc = access.create(auth_token=uuid.uuid4().hex,
body=self.AUTH_RESPONSE_BODY).service_catalog
sc = access.create(
auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
).service_catalog
url = sc.url_for(service_type='image', region_name='North')
self.assertEqual(url, self.north_endpoints['public'])
@ -242,12 +287,17 @@ class ServiceCatalogTest(utils.TestCase):
url = sc.url_for(service_type='image', region_name='South')
self.assertEqual(url, self.south_endpoints['public'])
self.assertRaises(exceptions.EndpointNotFound, sc.url_for,
service_type='image', region_name='West')
self.assertRaises(
exceptions.EndpointNotFound,
sc.url_for,
service_type='image',
region_name='West',
)
def test_service_catalog_get_url_region_names(self):
sc = access.create(auth_token=uuid.uuid4().hex,
body=self.AUTH_RESPONSE_BODY).service_catalog
sc = access.create(
auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
).service_catalog
urls = sc.get_urls(service_type='image')
self.assertEqual(len(urls), 2)
@ -264,29 +314,43 @@ class ServiceCatalogTest(utils.TestCase):
self.assertEqual(len(urls), 0)
def test_service_catalog_service_name(self):
sc = access.create(auth_token=uuid.uuid4().hex,
body=self.AUTH_RESPONSE_BODY).service_catalog
sc = access.create(
auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
).service_catalog
url = sc.url_for(service_name='glance', interface='public',
service_type='image', region_name='North')
url = sc.url_for(
service_name='glance',
interface='public',
service_type='image',
region_name='North',
)
self.assertEqual('http://glance.north.host/glanceapi/public', url)
url = sc.url_for(service_name='glance', interface='public',
service_type='image', region_name='South')
url = sc.url_for(
service_name='glance',
interface='public',
service_type='image',
region_name='South',
)
self.assertEqual('http://glance.south.host/glanceapi/public', url)
self.assertRaises(exceptions.EndpointNotFound, sc.url_for,
service_name='glance', service_type='compute')
self.assertRaises(
exceptions.EndpointNotFound,
sc.url_for,
service_name='glance',
service_type='compute',
)
urls = sc.get_urls(service_type='image', service_name='glance',
interface='public')
urls = sc.get_urls(
service_type='image', service_name='glance', interface='public'
)
self.assertIn('http://glance.north.host/glanceapi/public', urls)
self.assertIn('http://glance.south.host/glanceapi/public', urls)
urls = sc.get_urls(service_type='image',
service_name='Servers',
interface='public')
urls = sc.get_urls(
service_type='image', service_name='Servers', interface='public'
)
self.assertEqual(0, len(urls))
@ -304,81 +368,102 @@ class ServiceCatalogTest(utils.TestCase):
s = f.add_service('volume')
s.add_standard_endpoints(
public='http://public.com:8776/v1/%s' % tenant,
internal='http://internal:8776/v1/%s' % tenant,
admin='http://admin:8776/v1/%s' % tenant,
region=region)
public=f'http://public.com:8776/v1/{tenant}',
internal=f'http://internal:8776/v1/{tenant}',
admin=f'http://admin:8776/v1/{tenant}',
region=region,
)
s = f.add_service('image')
s.add_standard_endpoints(public='http://public.com:9292/v1',
internal='http://internal:9292/v1',
admin='http://admin:9292/v1',
region=region)
s.add_standard_endpoints(
public='http://public.com:9292/v1',
internal='http://internal:9292/v1',
admin='http://admin:9292/v1',
region=region,
)
s = f.add_service('compute')
s.add_standard_endpoints(
public='http://public.com:8774/v2/%s' % tenant,
internal='http://internal:8774/v2/%s' % tenant,
admin='http://admin:8774/v2/%s' % tenant,
region=region)
public=f'http://public.com:8774/v2/{tenant}',
internal=f'http://internal:8774/v2/{tenant}',
admin=f'http://admin:8774/v2/{tenant}',
region=region,
)
s = f.add_service('ec2')
s.add_standard_endpoints(
public='http://public.com:8773/services/Cloud',
internal='http://internal:8773/services/Cloud',
admin='http://admin:8773/services/Admin',
region=region)
region=region,
)
s = f.add_service('identity')
s.add_standard_endpoints(public='http://public.com:5000/v3',
internal='http://internal:5000/v3',
admin='http://admin:35357/v3',
region=region)
s.add_standard_endpoints(
public='http://public.com:5000/v3',
internal='http://internal:5000/v3',
admin='http://admin:35357/v3',
region=region,
)
pr_auth_ref = access.create(body=f)
pr_sc = pr_auth_ref.service_catalog
# this will work because there are no service names on that token
url_ref = 'http://public.com:8774/v2/225da22d3ce34b15877ea70b2a575f58'
url = pr_sc.url_for(service_type='compute', service_name='NotExist',
interface='public')
url = pr_sc.url_for(
service_type='compute', service_name='NotExist', interface='public'
)
self.assertEqual(url_ref, url)
ab_auth_ref = access.create(body=self.AUTH_RESPONSE_BODY)
ab_sc = ab_auth_ref.service_catalog
# this won't work because there is a name and it's not this one
self.assertRaises(exceptions.EndpointNotFound, ab_sc.url_for,
service_type='compute', service_name='NotExist',
interface='public')
self.assertRaises(
exceptions.EndpointNotFound,
ab_sc.url_for,
service_type='compute',
service_name='NotExist',
interface='public',
)
class ServiceCatalogV3Test(ServiceCatalogTest):
def test_building_a_service_catalog(self):
sc = access.create(auth_token=uuid.uuid4().hex,
body=self.AUTH_RESPONSE_BODY).service_catalog
sc = access.create(
auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
).service_catalog
self.assertEqual(sc.url_for(service_type='compute'),
'https://compute.north.host/novapi/public')
self.assertEqual(sc.url_for(service_type='compute',
interface='internal'),
'https://compute.north.host/novapi/internal')
self.assertEqual(
sc.url_for(service_type='compute'),
'https://compute.north.host/novapi/public',
)
self.assertEqual(
sc.url_for(service_type='compute', interface='internal'),
'https://compute.north.host/novapi/internal',
)
self.assertRaises(exceptions.EndpointNotFound,
sc.url_for,
region_name='South',
service_type='compute')
self.assertRaises(
exceptions.EndpointNotFound,
sc.url_for,
region_name='South',
service_type='compute',
)
def test_service_catalog_endpoints(self):
sc = access.create(auth_token=uuid.uuid4().hex,
body=self.AUTH_RESPONSE_BODY).service_catalog
sc = access.create(
auth_token=uuid.uuid4().hex, body=self.AUTH_RESPONSE_BODY
).service_catalog
public_ep = sc.get_endpoints(service_type='compute',
interface='public')
public_ep = sc.get_endpoints(
service_type='compute', interface='public'
)
self.assertEqual(public_ep['compute'][0]['region_id'], 'North')
self.assertEqual(public_ep['compute'][0]['url'],
'https://compute.north.host/novapi/public')
self.assertEqual(
public_ep['compute'][0]['url'],
'https://compute.north.host/novapi/public',
)
def test_service_catalog_multiple_service_types(self):
token = fixture.V3Token()
@ -386,23 +471,26 @@ class ServiceCatalogV3Test(ServiceCatalogTest):
for i in range(3):
s = token.add_service('compute')
s.add_standard_endpoints(public='public-%d' % i,
admin='admin-%d' % i,
internal='internal-%d' % i,
region='region-%d' % i)
s.add_standard_endpoints(
public='public-%d' % i,
admin='admin-%d' % i,
internal='internal-%d' % i,
region='region-%d' % i,
)
auth_ref = access.create(resp=None, body=token)
urls = auth_ref.service_catalog.get_urls(service_type='compute',
interface='public')
urls = auth_ref.service_catalog.get_urls(
service_type='compute', interface='public'
)
self.assertEqual(set(['public-0', 'public-1', 'public-2']), set(urls))
self.assertEqual({'public-0', 'public-1', 'public-2'}, set(urls))
urls = auth_ref.service_catalog.get_urls(service_type='compute',
interface='public',
region_name='region-1')
urls = auth_ref.service_catalog.get_urls(
service_type='compute', interface='public', region_name='region-1'
)
self.assertEqual(('public-1', ), urls)
self.assertEqual(('public-1',), urls)
def test_service_catalog_endpoint_id(self):
token = fixture.V3Token()
@ -419,37 +507,42 @@ class ServiceCatalogV3Test(ServiceCatalogTest):
auth_ref = access.create(body=token)
# initially assert that we get back all our urls for a simple filter
urls = auth_ref.service_catalog.get_urls(service_type='compute',
interface='public')
urls = auth_ref.service_catalog.get_urls(
service_type='compute', interface='public'
)
self.assertEqual(2, len(urls))
# with bad endpoint_id nothing should be found
urls = auth_ref.service_catalog.get_urls(service_type='compute',
endpoint_id=uuid.uuid4().hex,
interface='public')
urls = auth_ref.service_catalog.get_urls(
service_type='compute',
endpoint_id=uuid.uuid4().hex,
interface='public',
)
self.assertEqual(0, len(urls))
# with service_id we get back both public endpoints
urls = auth_ref.service_catalog.get_urls(service_type='compute',
service_id=service_id,
interface='public')
urls = auth_ref.service_catalog.get_urls(
service_type='compute', service_id=service_id, interface='public'
)
self.assertEqual(2, len(urls))
# with service_id and endpoint_id we get back the url we want
urls = auth_ref.service_catalog.get_urls(service_type='compute',
service_id=service_id,
endpoint_id=endpoint_id,
interface='public')
urls = auth_ref.service_catalog.get_urls(
service_type='compute',
service_id=service_id,
endpoint_id=endpoint_id,
interface='public',
)
self.assertEqual((public_url, ), urls)
self.assertEqual((public_url,), urls)
# with service_id and endpoint_id we get back the url we want
urls = auth_ref.service_catalog.get_urls(service_type='compute',
endpoint_id=endpoint_id,
interface='public')
urls = auth_ref.service_catalog.get_urls(
service_type='compute', endpoint_id=endpoint_id, interface='public'
)
self.assertEqual((public_url, ), urls)
self.assertEqual((public_url,), urls)
def test_service_catalog_without_service_type(self):
token = fixture.V3Token()

View File

@ -23,7 +23,8 @@ def project_scoped_token():
project_id='225da22d3ce34b15877ea70b2a575f58',
project_name='exampleproject',
project_domain_id='4e6893b7ba0b4006840c3845660b86ed',
project_domain_name='exampledomain')
project_domain_name='exampledomain',
)
fixture.add_role(id='76e72a', name='admin')
fixture.add_role(id='f4f392', name='member')
@ -33,36 +34,43 @@ def project_scoped_token():
service = fixture.add_service('volume')
service.add_standard_endpoints(
public='http://public.com:8776/v1/%s' % tenant,
internal='http://internal:8776/v1/%s' % tenant,
admin='http://admin:8776/v1/%s' % tenant,
region=region)
public=f'http://public.com:8776/v1/{tenant}',
internal=f'http://internal:8776/v1/{tenant}',
admin=f'http://admin:8776/v1/{tenant}',
region=region,
)
service = fixture.add_service('image')
service.add_standard_endpoints(public='http://public.com:9292/v1',
internal='http://internal:9292/v1',
admin='http://admin:9292/v1',
region=region)
service.add_standard_endpoints(
public='http://public.com:9292/v1',
internal='http://internal:9292/v1',
admin='http://admin:9292/v1',
region=region,
)
service = fixture.add_service('compute')
service.add_standard_endpoints(
public='http://public.com:8774/v2/%s' % tenant,
internal='http://internal:8774/v2/%s' % tenant,
admin='http://admin:8774/v2/%s' % tenant,
region=region)
public=f'http://public.com:8774/v2/{tenant}',
internal=f'http://internal:8774/v2/{tenant}',
admin=f'http://admin:8774/v2/{tenant}',
region=region,
)
service = fixture.add_service('ec2')
service.add_standard_endpoints(
public='http://public.com:8773/services/Cloud',
internal='http://internal:8773/services/Cloud',
admin='http://admin:8773/services/Admin',
region=region)
region=region,
)
service = fixture.add_service('identity')
service.add_standard_endpoints(public='http://public.com:5000/v3',
internal='http://internal:5000/v3',
admin='http://admin:35357/v3',
region=region)
service.add_standard_endpoints(
public='http://public.com:5000/v3',
internal='http://internal:5000/v3',
admin='http://admin:35357/v3',
region=region,
)
return fixture
@ -75,48 +83,56 @@ def domain_scoped_token():
user_domain_name='exampledomain',
expires='2010-11-01T03:32:15-05:00',
domain_id='8e9283b7ba0b1038840c3842058b86ab',
domain_name='anotherdomain')
domain_name='anotherdomain',
)
fixture.add_role(id='76e72a', name='admin')
fixture.add_role(id='f4f392', name='member')
region = 'RegionOne'
service = fixture.add_service('volume')
service.add_standard_endpoints(public='http://public.com:8776/v1/None',
internal='http://internal.com:8776/v1/None',
admin='http://admin.com:8776/v1/None',
region=region)
service.add_standard_endpoints(
public='http://public.com:8776/v1/None',
internal='http://internal.com:8776/v1/None',
admin='http://admin.com:8776/v1/None',
region=region,
)
service = fixture.add_service('image')
service.add_standard_endpoints(public='http://public.com:9292/v1',
internal='http://internal:9292/v1',
admin='http://admin:9292/v1',
region=region)
service.add_standard_endpoints(
public='http://public.com:9292/v1',
internal='http://internal:9292/v1',
admin='http://admin:9292/v1',
region=region,
)
service = fixture.add_service('compute')
service.add_standard_endpoints(public='http://public.com:8774/v1.1/None',
internal='http://internal:8774/v1.1/None',
admin='http://admin:8774/v1.1/None',
region=region)
service.add_standard_endpoints(
public='http://public.com:8774/v1.1/None',
internal='http://internal:8774/v1.1/None',
admin='http://admin:8774/v1.1/None',
region=region,
)
service = fixture.add_service('ec2')
service.add_standard_endpoints(
public='http://public.com:8773/services/Cloud',
internal='http://internal:8773/services/Cloud',
admin='http://admin:8773/services/Admin',
region=region)
region=region,
)
service = fixture.add_service('identity')
service.add_standard_endpoints(public='http://public.com:5000/v3',
internal='http://internal:5000/v3',
admin='http://admin:35357/v3',
region=region)
service.add_standard_endpoints(
public='http://public.com:5000/v3',
internal='http://internal:5000/v3',
admin='http://admin:35357/v3',
region=region,
)
return fixture
AUTH_SUBJECT_TOKEN = '3e2813b7ba0b4006840c3825860b86ed'
AUTH_RESPONSE_HEADERS = {
'X-Subject-Token': AUTH_SUBJECT_TOKEN,
}
AUTH_RESPONSE_HEADERS = {'X-Subject-Token': AUTH_SUBJECT_TOKEN}

View File

@ -15,7 +15,6 @@ from keystoneauth1.tests.unit import utils
class ExceptionTests(utils.TestCase):
def test_clientexception_with_message(self):
test_message = 'Unittest exception message.'
exc = exceptions.ClientException(message=test_message)
@ -23,10 +22,8 @@ class ExceptionTests(utils.TestCase):
def test_clientexception_with_no_message(self):
exc = exceptions.ClientException()
self.assertEqual(exceptions.ClientException.__name__,
exc.message)
self.assertEqual(exceptions.ClientException.__name__, exc.message)
def test_using_default_message(self):
exc = exceptions.AuthorizationFailure()
self.assertEqual(exceptions.AuthorizationFailure.message,
exc.message)
self.assertEqual(exceptions.AuthorizationFailure.message, exc.message)

View File

@ -17,8 +17,7 @@ from keystoneauth1.tests.unit.extras.kerberos import utils
from keystoneauth1.tests.unit import utils as test_utils
REQUEST = {'auth': {'identity': {'methods': ['kerberos'],
'kerberos': {}}}}
REQUEST = {'auth': {'identity': {'methods': ['kerberos'], 'kerberos': {}}}}
class TestCase(test_utils.TestCase):
@ -27,7 +26,7 @@ class TestCase(test_utils.TestCase):
TEST_V3_URL = test_utils.TestCase.TEST_ROOT_URL + 'v3'
def setUp(self):
super(TestCase, self).setUp()
super().setUp()
km = utils.KerberosMock(self.requests_mock)
self.kerberos_mock = self.useFixture(km)

View File

@ -16,24 +16,26 @@ from keystoneauth1.tests.unit import utils as test_utils
class FedKerbLoadingTests(test_utils.TestCase):
def test_options(self):
opts = [o.name for o in
loading.get_plugin_loader('v3fedkerb').get_options()]
opts = [
o.name
for o in loading.get_plugin_loader('v3fedkerb').get_options()
]
allowed_opts = ['system-scope',
'domain-id',
'domain-name',
'identity-provider',
'project-id',
'project-name',
'project-domain-id',
'project-domain-name',
'protocol',
'trust-id',
'auth-url',
'mutual-auth',
]
allowed_opts = [
'system-scope',
'domain-id',
'domain-name',
'identity-provider',
'project-id',
'project-name',
'project-domain-id',
'project-domain-name',
'protocol',
'trust-id',
'auth-url',
'mutual-auth',
]
self.assertCountEqual(allowed_opts, opts)
@ -45,6 +47,6 @@ class FedKerbLoadingTests(test_utils.TestCase):
self.assertRaises(exceptions.MissingRequiredOptions, self.create)
def test_load(self):
self.create(auth_url='auth_url',
identity_provider='idp',
protocol='protocol')
self.create(
auth_url='auth_url', identity_provider='idp', protocol='protocol'
)

View File

@ -15,21 +15,23 @@ from keystoneauth1.tests.unit import utils as test_utils
class KerberosLoadingTests(test_utils.TestCase):
def test_options(self):
opts = [o.name for o in
loading.get_plugin_loader('v3kerberos').get_options()]
opts = [
o.name
for o in loading.get_plugin_loader('v3kerberos').get_options()
]
allowed_opts = ['system-scope',
'domain-id',
'domain-name',
'project-id',
'project-name',
'project-domain-id',
'project-domain-name',
'trust-id',
'auth-url',
'mutual-auth',
]
allowed_opts = [
'system-scope',
'domain-id',
'domain-name',
'project-id',
'project-name',
'project-domain-id',
'project-domain-name',
'trust-id',
'auth-url',
'mutual-auth',
]
self.assertCountEqual(allowed_opts, opts)

View File

@ -19,12 +19,11 @@ from keystoneauth1.tests.unit.extras.kerberos import base
class TestMappedAuth(base.TestCase):
def setUp(self):
if kerberos.requests_kerberos is None:
self.skipTest("Kerberos support isn't available.")
super(TestMappedAuth, self).setUp()
super().setUp()
self.protocol = uuid.uuid4().hex
self.identity_provider = uuid.uuid4().hex
@ -32,18 +31,18 @@ class TestMappedAuth(base.TestCase):
@property
def token_url(self):
fmt = '%s/OS-FEDERATION/identity_providers/%s/protocols/%s/auth'
return fmt % (
self.TEST_V3_URL,
self.identity_provider,
self.protocol)
return fmt % (self.TEST_V3_URL, self.identity_provider, self.protocol)
def test_unscoped_mapped_auth(self):
token_id, _ = self.kerberos_mock.mock_auth_success(
url=self.token_url, method='GET')
url=self.token_url, method='GET'
)
plugin = kerberos.MappedKerberos(
auth_url=self.TEST_V3_URL, protocol=self.protocol,
identity_provider=self.identity_provider)
auth_url=self.TEST_V3_URL,
protocol=self.protocol,
identity_provider=self.identity_provider,
)
sess = session.Session()
tok = plugin.get_token(sess)
@ -51,23 +50,27 @@ class TestMappedAuth(base.TestCase):
self.assertEqual(token_id, tok)
def test_project_scoped_mapped_auth(self):
self.kerberos_mock.mock_auth_success(url=self.token_url,
method='GET')
self.kerberos_mock.mock_auth_success(url=self.token_url, method='GET')
scoped_id = uuid.uuid4().hex
scoped_body = ks_fixture.V3Token()
scoped_body.set_project_scope()
self.requests_mock.post(
'%s/auth/tokens' % self.TEST_V3_URL,
f'{self.TEST_V3_URL}/auth/tokens',
json=scoped_body,
headers={'X-Subject-Token': scoped_id,
'Content-Type': 'application/json'})
headers={
'X-Subject-Token': scoped_id,
'Content-Type': 'application/json',
},
)
plugin = kerberos.MappedKerberos(
auth_url=self.TEST_V3_URL, protocol=self.protocol,
auth_url=self.TEST_V3_URL,
protocol=self.protocol,
identity_provider=self.identity_provider,
project_id=scoped_body.project_id)
project_id=scoped_body.project_id,
)
sess = session.Session()
tok = plugin.get_token(sess)
@ -77,24 +80,28 @@ class TestMappedAuth(base.TestCase):
self.assertEqual(scoped_body.project_id, proj)
def test_authenticate_with_mutual_authentication_required(self):
self.kerberos_mock.mock_auth_success(url=self.token_url,
method='GET')
self.kerberos_mock.mock_auth_success(url=self.token_url, method='GET')
scoped_id = uuid.uuid4().hex
scoped_body = ks_fixture.V3Token()
scoped_body.set_project_scope()
self.requests_mock.post(
'%s/auth/tokens' % self.TEST_V3_URL,
f'{self.TEST_V3_URL}/auth/tokens',
json=scoped_body,
headers={'X-Subject-Token': scoped_id,
'Content-Type': 'application/json'})
headers={
'X-Subject-Token': scoped_id,
'Content-Type': 'application/json',
},
)
plugin = kerberos.MappedKerberos(
auth_url=self.TEST_V3_URL, protocol=self.protocol,
auth_url=self.TEST_V3_URL,
protocol=self.protocol,
identity_provider=self.identity_provider,
project_id=scoped_body.project_id,
mutual_auth='required')
mutual_auth='required',
)
sess = session.Session()
tok = plugin.get_token(sess)
@ -105,24 +112,28 @@ class TestMappedAuth(base.TestCase):
self.assertEqual(self.kerberos_mock.called_auth_server, True)
def test_authenticate_with_mutual_authentication_disabled(self):
self.kerberos_mock.mock_auth_success(url=self.token_url,
method='GET')
self.kerberos_mock.mock_auth_success(url=self.token_url, method='GET')
scoped_id = uuid.uuid4().hex
scoped_body = ks_fixture.V3Token()
scoped_body.set_project_scope()
self.requests_mock.post(
'%s/auth/tokens' % self.TEST_V3_URL,
f'{self.TEST_V3_URL}/auth/tokens',
json=scoped_body,
headers={'X-Subject-Token': scoped_id,
'Content-Type': 'application/json'})
headers={
'X-Subject-Token': scoped_id,
'Content-Type': 'application/json',
},
)
plugin = kerberos.MappedKerberos(
auth_url=self.TEST_V3_URL, protocol=self.protocol,
auth_url=self.TEST_V3_URL,
protocol=self.protocol,
identity_provider=self.identity_provider,
project_id=scoped_body.project_id,
mutual_auth='disabled')
mutual_auth='disabled',
)
sess = session.Session()
tok = plugin.get_token(sess)

View File

@ -16,12 +16,11 @@ from keystoneauth1.tests.unit.extras.kerberos import base
class TestKerberosAuth(base.TestCase):
def setUp(self):
if kerberos.requests_kerberos is None:
self.skipTest("Kerberos support isn't available.")
super(TestKerberosAuth, self).setUp()
super().setUp()
def test_authenticate_with_kerberos_domain_scoped(self):
token_id, token_body = self.kerberos_mock.mock_auth_success()
@ -33,22 +32,25 @@ class TestKerberosAuth(base.TestCase):
self.assertRequestBody()
self.assertEqual(
self.kerberos_mock.challenge_header,
self.requests_mock.last_request.headers['Authorization'])
self.requests_mock.last_request.headers['Authorization'],
)
self.assertEqual(token_id, a.auth_ref.auth_token)
self.assertEqual(token_id, token)
def test_authenticate_with_kerberos_mutual_authentication_required(self):
token_id, token_body = self.kerberos_mock.mock_auth_success()
a = kerberos.Kerberos(self.TEST_ROOT_URL + 'v3',
mutual_auth='required')
a = kerberos.Kerberos(
self.TEST_ROOT_URL + 'v3', mutual_auth='required'
)
s = session.Session(a)
token = a.get_token(s)
self.assertRequestBody()
self.assertEqual(
self.kerberos_mock.challenge_header,
self.requests_mock.last_request.headers['Authorization'])
self.requests_mock.last_request.headers['Authorization'],
)
self.assertEqual(token_id, a.auth_ref.auth_token)
self.assertEqual(token_id, token)
self.assertEqual(self.kerberos_mock.called_auth_server, True)
@ -56,15 +58,17 @@ class TestKerberosAuth(base.TestCase):
def test_authenticate_with_kerberos_mutual_authentication_disabled(self):
token_id, token_body = self.kerberos_mock.mock_auth_success()
a = kerberos.Kerberos(self.TEST_ROOT_URL + 'v3',
mutual_auth='disabled')
a = kerberos.Kerberos(
self.TEST_ROOT_URL + 'v3', mutual_auth='disabled'
)
s = session.Session(a)
token = a.get_token(s)
self.assertRequestBody()
self.assertEqual(
self.kerberos_mock.challenge_header,
self.requests_mock.last_request.headers['Authorization'])
self.requests_mock.last_request.headers['Authorization'],
)
self.assertEqual(token_id, a.auth_ref.auth_token)
self.assertEqual(token_id, token)
self.assertEqual(self.kerberos_mock.called_auth_server, False)

View File

@ -13,6 +13,7 @@
import uuid
import fixtures
try:
# requests_kerberos won't be available on py3, it doesn't work with py3.
import requests_kerberos
@ -24,29 +25,32 @@ from keystoneauth1.tests.unit import utils as test_utils
class KerberosMock(fixtures.Fixture):
def __init__(self, requests_mock):
super(KerberosMock, self).__init__()
super().__init__()
self.challenge_header = 'Negotiate %s' % uuid.uuid4().hex
self.pass_header = 'Negotiate %s' % uuid.uuid4().hex
self.challenge_header = f'Negotiate {uuid.uuid4().hex}'
self.pass_header = f'Negotiate {uuid.uuid4().hex}'
self.requests_mock = requests_mock
def setUp(self):
super(KerberosMock, self).setUp()
super().setUp()
if requests_kerberos is None:
return
m = fixtures.MockPatchObject(requests_kerberos.HTTPKerberosAuth,
'generate_request_header',
self._generate_request_header)
m = fixtures.MockPatchObject(
requests_kerberos.HTTPKerberosAuth,
'generate_request_header',
self._generate_request_header,
)
self.header_fixture = self.useFixture(m)
m = fixtures.MockPatchObject(requests_kerberos.HTTPKerberosAuth,
'authenticate_server',
self._authenticate_server)
m = fixtures.MockPatchObject(
requests_kerberos.HTTPKerberosAuth,
'authenticate_server',
self._authenticate_server,
)
self.authenticate_fixture = self.useFixture(m)
@ -58,11 +62,12 @@ class KerberosMock(fixtures.Fixture):
return response.headers.get('www-authenticate') == self.pass_header
def mock_auth_success(
self,
token_id=None,
token_body=None,
method='POST',
url=test_utils.TestCase.TEST_ROOT_URL + 'v3/auth/tokens'):
self,
token_id=None,
token_body=None,
method='POST',
url=test_utils.TestCase.TEST_ROOT_URL + 'v3/auth/tokens',
):
if not token_id:
token_id = uuid.uuid4().hex
if not token_body:
@ -70,17 +75,25 @@ class KerberosMock(fixtures.Fixture):
self.called_auth_server = False
response_list = [{'text': 'Fail',
'status_code': 401,
'headers': {'WWW-Authenticate': 'Negotiate'}},
{'headers': {'X-Subject-Token': token_id,
'Content-Type': 'application/json',
'WWW-Authenticate': self.pass_header},
'status_code': 200,
'json': token_body}]
response_list = [
{
'text': 'Fail',
'status_code': 401,
'headers': {'WWW-Authenticate': 'Negotiate'},
},
{
'headers': {
'X-Subject-Token': token_id,
'Content-Type': 'application/json',
'WWW-Authenticate': self.pass_header,
},
'status_code': 200,
'json': token_body,
},
]
self.requests_mock.register_uri(method,
url,
response_list=response_list)
self.requests_mock.register_uri(
method, url, response_list=response_list
)
return token_id, token_body

View File

@ -22,17 +22,20 @@ from keystoneauth1.tests.unit import utils as test_utils
class OAuth1AuthTests(test_utils.TestCase):
TEST_ROOT_URL = 'http://127.0.0.1:5000/'
TEST_URL = '%s%s' % (TEST_ROOT_URL, 'v3')
TEST_URL = '{}{}'.format(TEST_ROOT_URL, 'v3')
TEST_TOKEN = uuid.uuid4().hex
def stub_auth(self, subject_token=None, **kwargs):
if not subject_token:
subject_token = self.TEST_TOKEN
self.stub_url('POST', ['auth', 'tokens'],
headers={'X-Subject-Token': subject_token}, **kwargs)
self.stub_url(
'POST',
['auth', 'tokens'],
headers={'X-Subject-Token': subject_token},
**kwargs,
)
def _validate_oauth_headers(self, auth_header, oauth_client):
"""Validate data in the headers.
@ -42,22 +45,27 @@ class OAuth1AuthTests(test_utils.TestCase):
"""
self.assertThat(auth_header, matchers.StartsWith('OAuth '))
parameters = dict(
oauth1.rfc5849.utils.parse_authorization_header(auth_header))
oauth1.rfc5849.utils.parse_authorization_header(auth_header)
)
self.assertEqual('HMAC-SHA1', parameters['oauth_signature_method'])
self.assertEqual('1.0', parameters['oauth_version'])
self.assertIsInstance(parameters['oauth_nonce'], str)
self.assertEqual(oauth_client.client_key,
parameters['oauth_consumer_key'])
self.assertEqual(
oauth_client.client_key, parameters['oauth_consumer_key']
)
if oauth_client.resource_owner_key:
self.assertEqual(oauth_client.resource_owner_key,
parameters['oauth_token'],)
self.assertEqual(
oauth_client.resource_owner_key, parameters['oauth_token']
)
if oauth_client.verifier:
self.assertEqual(oauth_client.verifier,
parameters['oauth_verifier'])
self.assertEqual(
oauth_client.verifier, parameters['oauth_verifier']
)
if oauth_client.callback_uri:
self.assertEqual(oauth_client.callback_uri,
parameters['oauth_callback'])
self.assertEqual(
oauth_client.callback_uri, parameters['oauth_callback']
)
return parameters
def test_oauth_authenticate_success(self):
@ -66,18 +74,22 @@ class OAuth1AuthTests(test_utils.TestCase):
access_key = uuid.uuid4().hex
access_secret = uuid.uuid4().hex
oauth_token = fixture.V3Token(methods=['oauth1'],
oauth_consumer_id=consumer_key,
oauth_access_token_id=access_key)
oauth_token = fixture.V3Token(
methods=['oauth1'],
oauth_consumer_id=consumer_key,
oauth_access_token_id=access_key,
)
oauth_token.set_project_scope()
self.stub_auth(json=oauth_token)
a = ksa_oauth1.V3OAuth1(self.TEST_URL,
consumer_key=consumer_key,
consumer_secret=consumer_secret,
access_key=access_key,
access_secret=access_secret)
a = ksa_oauth1.V3OAuth1(
self.TEST_URL,
consumer_key=consumer_key,
consumer_secret=consumer_secret,
access_key=access_key,
access_secret=access_secret,
)
s = session.Session(auth=a)
t = s.get_token()
@ -85,32 +97,32 @@ class OAuth1AuthTests(test_utils.TestCase):
self.assertEqual(self.TEST_TOKEN, t)
OAUTH_REQUEST_BODY = {
"auth": {
"identity": {
"methods": ["oauth1"],
"oauth1": {}
}
}
"auth": {"identity": {"methods": ["oauth1"], "oauth1": {}}}
}
self.assertRequestBodyIs(json=OAUTH_REQUEST_BODY)
# Assert that the headers have the same oauthlib data
req_headers = self.requests_mock.last_request.headers
oauth_client = oauth1.Client(consumer_key,
client_secret=consumer_secret,
resource_owner_key=access_key,
resource_owner_secret=access_secret,
signature_method=oauth1.SIGNATURE_HMAC)
self._validate_oauth_headers(req_headers['Authorization'],
oauth_client)
oauth_client = oauth1.Client(
consumer_key,
client_secret=consumer_secret,
resource_owner_key=access_key,
resource_owner_secret=access_secret,
signature_method=oauth1.SIGNATURE_HMAC,
)
self._validate_oauth_headers(
req_headers['Authorization'], oauth_client
)
def test_warning_dual_scope(self):
ksa_oauth1.V3OAuth1(self.TEST_URL,
consumer_key=uuid.uuid4().hex,
consumer_secret=uuid.uuid4().hex,
access_key=uuid.uuid4().hex,
access_secret=uuid.uuid4().hex,
project_id=uuid.uuid4().hex)
ksa_oauth1.V3OAuth1(
self.TEST_URL,
consumer_key=uuid.uuid4().hex,
consumer_secret=uuid.uuid4().hex,
access_key=uuid.uuid4().hex,
access_secret=uuid.uuid4().hex,
project_id=uuid.uuid4().hex,
)
self.assertIn('ignored by the identity server', self.logger.output)

View File

@ -17,9 +17,8 @@ from keystoneauth1.tests.unit import utils as test_utils
class OAuth1LoadingTests(test_utils.TestCase):
def setUp(self):
super(OAuth1LoadingTests, self).setUp()
super().setUp()
self.auth_url = uuid.uuid4().hex
def create(self, **kwargs):
@ -33,10 +32,12 @@ class OAuth1LoadingTests(test_utils.TestCase):
consumer_key = uuid.uuid4().hex
consumer_secret = uuid.uuid4().hex
p = self.create(access_key=access_key,
access_secret=access_secret,
consumer_key=consumer_key,
consumer_secret=consumer_secret)
p = self.create(
access_key=access_key,
access_secret=access_secret,
consumer_key=consumer_key,
consumer_secret=consumer_secret,
)
oauth_method = p.auth_methods[0]
@ -49,9 +50,13 @@ class OAuth1LoadingTests(test_utils.TestCase):
def test_options(self):
options = loading.get_plugin_loader('v3oauth1').get_options()
self.assertEqual(set([o.name for o in options]),
set(['auth-url',
'access-key',
'access-secret',
'consumer-key',
'consumer-secret']))
self.assertEqual(
{o.name for o in options},
{
'auth-url',
'access-key',
'access-secret',
'consumer-key',
'consumer-secret',
},
)

View File

@ -23,22 +23,25 @@ def template(f, **kwargs):
def soap_response(**kwargs):
kwargs.setdefault('provider', 'https://idp.testshib.org/idp/shibboleth')
kwargs.setdefault('consumer',
'https://openstack4.local/Shibboleth.sso/SAML2/ECP')
kwargs.setdefault(
'consumer', 'https://openstack4.local/Shibboleth.sso/SAML2/ECP'
)
kwargs.setdefault('issuer', 'https://openstack4.local/shibboleth')
return template('soap_response.xml', **kwargs).encode('utf-8')
def saml_assertion(**kwargs):
kwargs.setdefault('issuer', 'https://idp.testshib.org/idp/shibboleth')
kwargs.setdefault('destination',
'https://openstack4.local/Shibboleth.sso/SAML2/ECP')
kwargs.setdefault(
'destination', 'https://openstack4.local/Shibboleth.sso/SAML2/ECP'
)
return template('saml_assertion.xml', **kwargs).encode('utf-8')
def authn_request(**kwargs):
kwargs.setdefault('issuer',
'https://openstack4.local/Shibboleth.sso/SAML2/ECP')
kwargs.setdefault(
'issuer', 'https://openstack4.local/Shibboleth.sso/SAML2/ECP'
)
return template('authn_request.xml', **kwargs).encode('utf-8')
@ -56,19 +59,13 @@ UNSCOPED_TOKEN = {
"expires_at": "2014-06-09T10:48:59.643375Z",
"user": {
"OS-FEDERATION": {
"identity_provider": {
"id": "testshib"
},
"protocol": {
"id": "saml2"
},
"groups": [
{"id": "1764fa5cf69a49a4918131de5ce4af9a"}
]
"identity_provider": {"id": "testshib"},
"protocol": {"id": "saml2"},
"groups": [{"id": "1764fa5cf69a49a4918131de5ce4af9a"}],
},
"id": "testhib%20user",
"name": "testhib user"
}
"name": "testhib user",
},
}
}
@ -78,26 +75,22 @@ PROJECTS = {
"domain_id": "37ef61",
"enabled": 'true',
"id": "12d706",
"links": {
"self": "http://identity:35357/v3/projects/12d706"
},
"name": "a project name"
"links": {"self": "http://identity:35357/v3/projects/12d706"},
"name": "a project name",
},
{
"domain_id": "37ef61",
"enabled": 'true',
"id": "9ca0eb",
"links": {
"self": "http://identity:35357/v3/projects/9ca0eb"
},
"name": "another project"
}
"links": {"self": "http://identity:35357/v3/projects/9ca0eb"},
"name": "another project",
},
],
"links": {
"self": "http://identity:35357/v3/OS-FEDERATION/projects",
"previous": 'null',
"next": 'null'
}
"next": 'null',
},
}
DOMAINS = {
@ -106,15 +99,13 @@ DOMAINS = {
"description": "desc of domain",
"enabled": 'true',
"id": "37ef61",
"links": {
"self": "http://identity:35357/v3/domains/37ef61"
},
"name": "my domain"
"links": {"self": "http://identity:35357/v3/domains/37ef61"},
"name": "my domain",
}
],
"links": {
"self": "http://identity:35357/v3/OS-FEDERATION/domains",
"previous": 'null',
"next": 'null'
}
"next": 'null',
},
}

View File

@ -24,7 +24,6 @@ from keystoneauth1.tests.unit import matchers
class AuthenticateviaADFSTests(utils.TestCase):
GROUP = 'auth'
NAMESPACES = {
@ -33,24 +32,23 @@ class AuthenticateviaADFSTests(utils.TestCase):
'wsa': 'http://www.w3.org/2005/08/addressing',
'wsp': 'http://schemas.xmlsoap.org/ws/2004/09/policy',
'a': 'http://www.w3.org/2005/08/addressing',
'o': ('http://docs.oasis-open.org/wss/2004/01/oasis'
'-200401-wss-wssecurity-secext-1.0.xsd')
'o': (
'http://docs.oasis-open.org/wss/2004/01/oasis'
'-200401-wss-wssecurity-secext-1.0.xsd'
),
}
USER_XPATH = ('/s:Envelope/s:Header'
'/o:Security'
'/o:UsernameToken'
'/o:Username')
PASSWORD_XPATH = ('/s:Envelope/s:Header'
'/o:Security'
'/o:UsernameToken'
'/o:Password')
ADDRESS_XPATH = ('/s:Envelope/s:Body'
'/trust:RequestSecurityToken'
'/wsp:AppliesTo/wsa:EndpointReference'
'/wsa:Address')
TO_XPATH = ('/s:Envelope/s:Header'
'/a:To')
USER_XPATH = '/s:Envelope/s:Header/o:Security/o:UsernameToken/o:Username'
PASSWORD_XPATH = (
'/s:Envelope/s:Header/o:Security/o:UsernameToken/o:Password'
)
ADDRESS_XPATH = (
'/s:Envelope/s:Body'
'/trust:RequestSecurityToken'
'/wsp:AppliesTo/wsa:EndpointReference'
'/wsa:Address'
)
TO_XPATH = '/s:Envelope/s:Header/a:To'
TEST_TOKEN = uuid.uuid4().hex
@ -61,24 +59,32 @@ class AuthenticateviaADFSTests(utils.TestCase):
return '4b911420-4982-4009-8afc-5c596cd487f5'
def setUp(self):
super(AuthenticateviaADFSTests, self).setUp()
super().setUp()
self.IDENTITY_PROVIDER = 'adfs'
self.IDENTITY_PROVIDER_URL = ('http://adfs.local/adfs/service/trust/13'
'/usernamemixed')
self.FEDERATION_AUTH_URL = '%s/%s' % (
self.IDENTITY_PROVIDER_URL = (
'http://adfs.local/adfs/service/trust/13/usernamemixed'
)
self.FEDERATION_AUTH_URL = '{}/{}'.format(
self.TEST_URL,
'OS-FEDERATION/identity_providers/adfs/protocols/saml2/auth')
'OS-FEDERATION/identity_providers/adfs/protocols/saml2/auth',
)
self.SP_ENDPOINT = 'https://openstack4.local/Shibboleth.sso/ADFS'
self.SP_ENTITYID = 'https://openstack4.local'
self.adfsplugin = saml2.V3ADFSPassword(
self.TEST_URL, self.IDENTITY_PROVIDER,
self.IDENTITY_PROVIDER_URL, self.SP_ENDPOINT,
self.TEST_USER, self.TEST_TOKEN, self.PROTOCOL)
self.TEST_URL,
self.IDENTITY_PROVIDER,
self.IDENTITY_PROVIDER_URL,
self.SP_ENDPOINT,
self.TEST_USER,
self.TEST_TOKEN,
self.PROTOCOL,
)
self.ADFS_SECURITY_TOKEN_RESPONSE = utils._load_xml(
'ADFS_RequestSecurityTokenResponse.xml')
'ADFS_RequestSecurityTokenResponse.xml'
)
self.ADFS_FAULT = utils._load_xml('ADFS_fault.xml')
def test_get_adfs_security_token(self):
@ -86,7 +92,8 @@ class AuthenticateviaADFSTests(utils.TestCase):
self.requests_mock.post(
self.IDENTITY_PROVIDER_URL,
content=utils.make_oneline(self.ADFS_SECURITY_TOKEN_RESPONSE),
status_code=200)
status_code=200,
)
self.adfsplugin._prepare_adfs_request()
self.adfsplugin._get_adfs_security_token(self.session)
@ -94,59 +101,72 @@ class AuthenticateviaADFSTests(utils.TestCase):
adfs_response = etree.tostring(self.adfsplugin.adfs_token)
fixture_response = self.ADFS_SECURITY_TOKEN_RESPONSE
self.assertThat(fixture_response,
matchers.XMLEquals(adfs_response))
self.assertThat(fixture_response, matchers.XMLEquals(adfs_response))
def test_adfs_request_user(self):
self.adfsplugin._prepare_adfs_request()
user = self.adfsplugin.prepared_request.xpath(
self.USER_XPATH, namespaces=self.NAMESPACES)[0]
self.USER_XPATH, namespaces=self.NAMESPACES
)[0]
self.assertEqual(self.TEST_USER, user.text)
def test_adfs_request_password(self):
self.adfsplugin._prepare_adfs_request()
password = self.adfsplugin.prepared_request.xpath(
self.PASSWORD_XPATH, namespaces=self.NAMESPACES)[0]
self.PASSWORD_XPATH, namespaces=self.NAMESPACES
)[0]
self.assertEqual(self.TEST_TOKEN, password.text)
def test_adfs_request_to(self):
self.adfsplugin._prepare_adfs_request()
to = self.adfsplugin.prepared_request.xpath(
self.TO_XPATH, namespaces=self.NAMESPACES)[0]
self.TO_XPATH, namespaces=self.NAMESPACES
)[0]
self.assertEqual(self.IDENTITY_PROVIDER_URL, to.text)
def test_prepare_adfs_request_address(self):
self.adfsplugin._prepare_adfs_request()
address = self.adfsplugin.prepared_request.xpath(
self.ADDRESS_XPATH, namespaces=self.NAMESPACES)[0]
self.ADDRESS_XPATH, namespaces=self.NAMESPACES
)[0]
self.assertEqual(self.SP_ENDPOINT, address.text)
def test_prepare_adfs_request_custom_endpointreference(self):
self.adfsplugin = saml2.V3ADFSPassword(
self.TEST_URL, self.IDENTITY_PROVIDER,
self.IDENTITY_PROVIDER_URL, self.SP_ENDPOINT,
self.TEST_USER, self.TEST_TOKEN, self.PROTOCOL, self.SP_ENTITYID)
self.TEST_URL,
self.IDENTITY_PROVIDER,
self.IDENTITY_PROVIDER_URL,
self.SP_ENDPOINT,
self.TEST_USER,
self.TEST_TOKEN,
self.PROTOCOL,
self.SP_ENTITYID,
)
self.adfsplugin._prepare_adfs_request()
address = self.adfsplugin.prepared_request.xpath(
self.ADDRESS_XPATH, namespaces=self.NAMESPACES)[0]
self.ADDRESS_XPATH, namespaces=self.NAMESPACES
)[0]
self.assertEqual(self.SP_ENTITYID, address.text)
def test_prepare_sp_request(self):
assertion = etree.XML(self.ADFS_SECURITY_TOKEN_RESPONSE)
assertion = assertion.xpath(
saml2.V3ADFSPassword.ADFS_ASSERTION_XPATH,
namespaces=saml2.V3ADFSPassword.ADFS_TOKEN_NAMESPACES)
namespaces=saml2.V3ADFSPassword.ADFS_TOKEN_NAMESPACES,
)
assertion = assertion[0]
assertion = etree.tostring(assertion)
assertion = assertion.replace(
b'http://docs.oasis-open.org/ws-sx/ws-trust/200512',
b'http://schemas.xmlsoap.org/ws/2005/02/trust')
b'http://schemas.xmlsoap.org/ws/2005/02/trust',
)
assertion = urllib.parse.quote(assertion)
assertion = 'wa=wsignin1.0&wresult=' + assertion
self.adfsplugin.adfs_token = etree.XML(
self.ADFS_SECURITY_TOKEN_RESPONSE)
self.ADFS_SECURITY_TOKEN_RESPONSE
)
self.adfsplugin._prepare_sp_request()
self.assertEqual(assertion, self.adfsplugin.encoded_assertion)
@ -158,15 +178,19 @@ class AuthenticateviaADFSTests(utils.TestCase):
error message from the XML message indicating where was the problem.
"""
content = utils.make_oneline(self.ADFS_FAULT)
self.requests_mock.register_uri('POST',
self.IDENTITY_PROVIDER_URL,
content=content,
status_code=500)
self.requests_mock.register_uri(
'POST',
self.IDENTITY_PROVIDER_URL,
content=content,
status_code=500,
)
self.adfsplugin._prepare_adfs_request()
self.assertRaises(exceptions.AuthorizationFailure,
self.adfsplugin._get_adfs_security_token,
self.session)
self.assertRaises(
exceptions.AuthorizationFailure,
self.adfsplugin._get_adfs_security_token,
self.session,
)
# TODO(marek-denis): Python3 tests complain about missing 'message'
# attributes
# self.assertEqual('a:FailedAuthentication', e.message)
@ -178,14 +202,18 @@ class AuthenticateviaADFSTests(utils.TestCase):
and correctly raise exceptions.InternalServerError once it cannot
parse XML fault message
"""
self.requests_mock.register_uri('POST',
self.IDENTITY_PROVIDER_URL,
content=b'NOT XML',
status_code=500)
self.requests_mock.register_uri(
'POST',
self.IDENTITY_PROVIDER_URL,
content=b'NOT XML',
status_code=500,
)
self.adfsplugin._prepare_adfs_request()
self.assertRaises(exceptions.InternalServerError,
self.adfsplugin._get_adfs_security_token,
self.session)
self.assertRaises(
exceptions.InternalServerError,
self.adfsplugin._get_adfs_security_token,
self.session,
)
# TODO(marek-denis): Need to figure out how to properly send cookies
# from the request_mock methods.
@ -193,9 +221,9 @@ class AuthenticateviaADFSTests(utils.TestCase):
"""Test whether SP issues a cookie."""
cookie = uuid.uuid4().hex
self.requests_mock.post(self.SP_ENDPOINT,
headers={"set-cookie": cookie},
status_code=302)
self.requests_mock.post(
self.SP_ENDPOINT, headers={"set-cookie": cookie}, status_code=302
)
self.adfsplugin.adfs_token = self._build_adfs_request()
self.adfsplugin._prepare_sp_request()
@ -204,55 +232,70 @@ class AuthenticateviaADFSTests(utils.TestCase):
self.assertEqual(1, len(self.session.session.cookies))
def test_send_assertion_to_service_provider_bad_status(self):
self.requests_mock.register_uri('POST', self.SP_ENDPOINT,
status_code=500)
self.requests_mock.register_uri(
'POST', self.SP_ENDPOINT, status_code=500
)
self.adfsplugin.adfs_token = etree.XML(
self.ADFS_SECURITY_TOKEN_RESPONSE)
self.ADFS_SECURITY_TOKEN_RESPONSE
)
self.adfsplugin._prepare_sp_request()
self.assertRaises(
exceptions.InternalServerError,
self.adfsplugin._send_assertion_to_service_provider,
self.session)
self.session,
)
def test_access_sp_no_cookies_fail(self):
# clean cookie jar
self.session.session.cookies = []
self.assertRaises(exceptions.AuthorizationFailure,
self.adfsplugin._access_service_provider,
self.session)
self.assertRaises(
exceptions.AuthorizationFailure,
self.adfsplugin._access_service_provider,
self.session,
)
def test_check_valid_token_when_authenticated(self):
self.requests_mock.register_uri(
'GET', self.FEDERATION_AUTH_URL,
'GET',
self.FEDERATION_AUTH_URL,
json=saml2_fixtures.UNSCOPED_TOKEN,
headers=client_fixtures.AUTH_RESPONSE_HEADERS)
headers=client_fixtures.AUTH_RESPONSE_HEADERS,
)
self.session.session.cookies = [object()]
self.adfsplugin._access_service_provider(self.session)
response = self.adfsplugin.authenticated_response
self.assertEqual(client_fixtures.AUTH_RESPONSE_HEADERS,
response.headers)
self.assertEqual(
client_fixtures.AUTH_RESPONSE_HEADERS, response.headers
)
self.assertEqual(saml2_fixtures.UNSCOPED_TOKEN['token'],
response.json()['token'])
self.assertEqual(
saml2_fixtures.UNSCOPED_TOKEN['token'], response.json()['token']
)
def test_end_to_end_workflow(self):
self.requests_mock.register_uri(
'POST', self.IDENTITY_PROVIDER_URL,
'POST',
self.IDENTITY_PROVIDER_URL,
content=self.ADFS_SECURITY_TOKEN_RESPONSE,
status_code=200)
status_code=200,
)
self.requests_mock.register_uri(
'POST', self.SP_ENDPOINT,
'POST',
self.SP_ENDPOINT,
headers={"set-cookie": 'x'},
status_code=302)
status_code=302,
)
self.requests_mock.register_uri(
'GET', self.FEDERATION_AUTH_URL,
'GET',
self.FEDERATION_AUTH_URL,
json=saml2_fixtures.UNSCOPED_TOKEN,
headers=client_fixtures.AUTH_RESPONSE_HEADERS)
headers=client_fixtures.AUTH_RESPONSE_HEADERS,
)
# NOTE(marek-denis): We need to mimic this until self.requests_mock can
# issue cookies properly.

View File

@ -53,8 +53,8 @@ class SamlAuth2PluginTests(utils.TestCase):
return [r.url.strip('/') for r in self.requests_mock.request_history]
def basic_header(self, username=TEST_USER, password=TEST_PASS):
user_pass = ('%s:%s' % (username, password)).encode('utf-8')
return 'Basic %s' % base64.b64encode(user_pass).decode('utf-8')
user_pass = (f'{username}:{password}').encode()
return 'Basic {}'.format(base64.b64encode(user_pass).decode('utf-8'))
def test_request_accept_headers(self):
# Include some random Accept header
@ -70,18 +70,23 @@ class SamlAuth2PluginTests(utils.TestCase):
# added the PAOS_HEADER to it using the correct media type separator
accept_header = plugin_headers['Accept']
self.assertIn(self.HEADER_MEDIA_TYPE_SEPARATOR, accept_header)
self.assertIn(random_header,
accept_header.split(self.HEADER_MEDIA_TYPE_SEPARATOR))
self.assertIn(PAOS_HEADER,
accept_header.split(self.HEADER_MEDIA_TYPE_SEPARATOR))
self.assertIn(
random_header,
accept_header.split(self.HEADER_MEDIA_TYPE_SEPARATOR),
)
self.assertIn(
PAOS_HEADER, accept_header.split(self.HEADER_MEDIA_TYPE_SEPARATOR)
)
def test_passed_when_not_200(self):
text = uuid.uuid4().hex
test_url = 'http://another.test'
self.requests_mock.get(test_url,
status_code=201,
headers=CONTENT_TYPE_PAOS_HEADER,
text=text)
self.requests_mock.get(
test_url,
status_code=201,
headers=CONTENT_TYPE_PAOS_HEADER,
text=text,
)
resp = requests.get(test_url, auth=self.get_plugin())
self.assertEqual(201, resp.status_code)
@ -99,82 +104,115 @@ class SamlAuth2PluginTests(utils.TestCase):
def test_standard_workflow_302_redirect(self):
text = uuid.uuid4().hex
self.requests_mock.get(self.TEST_SP_URL, response_list=[
dict(headers=CONTENT_TYPE_PAOS_HEADER,
content=utils.make_oneline(saml2_fixtures.SP_SOAP_RESPONSE)),
dict(text=text)
])
self.requests_mock.get(
self.TEST_SP_URL,
response_list=[
{
'headers': CONTENT_TYPE_PAOS_HEADER,
'content': utils.make_oneline(
saml2_fixtures.SP_SOAP_RESPONSE
),
},
{'text': text},
],
)
authm = self.requests_mock.post(self.TEST_IDP_URL,
content=saml2_fixtures.SAML2_ASSERTION)
authm = self.requests_mock.post(
self.TEST_IDP_URL, content=saml2_fixtures.SAML2_ASSERTION
)
self.requests_mock.post(
self.TEST_CONSUMER_URL,
status_code=302,
headers={'Location': self.TEST_SP_URL})
headers={'Location': self.TEST_SP_URL},
)
resp = requests.get(self.TEST_SP_URL, auth=self.get_plugin())
self.assertEqual(200, resp.status_code)
self.assertEqual(text, resp.text)
self.assertEqual(self.calls, [self.TEST_SP_URL,
self.TEST_IDP_URL,
self.TEST_CONSUMER_URL,
self.TEST_SP_URL])
self.assertEqual(
self.calls,
[
self.TEST_SP_URL,
self.TEST_IDP_URL,
self.TEST_CONSUMER_URL,
self.TEST_SP_URL,
],
)
self.assertEqual(self.basic_header(),
authm.last_request.headers['Authorization'])
self.assertEqual(
self.basic_header(), authm.last_request.headers['Authorization']
)
authn_request = self.requests_mock.request_history[1].text
self.assertThat(saml2_fixtures.AUTHN_REQUEST,
matchers.XMLEquals(authn_request))
self.assertThat(
saml2_fixtures.AUTHN_REQUEST, matchers.XMLEquals(authn_request)
)
def test_standard_workflow_303_redirect(self):
text = uuid.uuid4().hex
self.requests_mock.get(self.TEST_SP_URL, response_list=[
dict(headers=CONTENT_TYPE_PAOS_HEADER,
content=utils.make_oneline(saml2_fixtures.SP_SOAP_RESPONSE)),
dict(text=text)
])
self.requests_mock.get(
self.TEST_SP_URL,
response_list=[
{
'headers': CONTENT_TYPE_PAOS_HEADER,
'content': utils.make_oneline(
saml2_fixtures.SP_SOAP_RESPONSE
),
},
{'text': text},
],
)
authm = self.requests_mock.post(self.TEST_IDP_URL,
content=saml2_fixtures.SAML2_ASSERTION)
authm = self.requests_mock.post(
self.TEST_IDP_URL, content=saml2_fixtures.SAML2_ASSERTION
)
self.requests_mock.post(
self.TEST_CONSUMER_URL,
status_code=303,
headers={'Location': self.TEST_SP_URL})
headers={'Location': self.TEST_SP_URL},
)
resp = requests.get(self.TEST_SP_URL, auth=self.get_plugin())
self.assertEqual(200, resp.status_code)
self.assertEqual(text, resp.text)
url_flow = [self.TEST_SP_URL,
self.TEST_IDP_URL,
self.TEST_CONSUMER_URL,
self.TEST_SP_URL]
url_flow = [
self.TEST_SP_URL,
self.TEST_IDP_URL,
self.TEST_CONSUMER_URL,
self.TEST_SP_URL,
]
self.assertEqual(url_flow, [r.url.rstrip('/') for r in resp.history])
self.assertEqual(url_flow, self.calls)
self.assertEqual(self.basic_header(),
authm.last_request.headers['Authorization'])
self.assertEqual(
self.basic_header(), authm.last_request.headers['Authorization']
)
authn_request = self.requests_mock.request_history[1].text
self.assertThat(saml2_fixtures.AUTHN_REQUEST,
matchers.XMLEquals(authn_request))
self.assertThat(
saml2_fixtures.AUTHN_REQUEST, matchers.XMLEquals(authn_request)
)
def test_initial_sp_call_invalid_response(self):
"""Send initial SP HTTP request and receive wrong server response."""
self.requests_mock.get(self.TEST_SP_URL,
headers=CONTENT_TYPE_PAOS_HEADER,
text='NON XML RESPONSE')
self.requests_mock.get(
self.TEST_SP_URL,
headers=CONTENT_TYPE_PAOS_HEADER,
text='NON XML RESPONSE',
)
self.assertRaises(InvalidResponse,
requests.get,
self.TEST_SP_URL,
auth=self.get_plugin())
self.assertRaises(
InvalidResponse,
requests.get,
self.TEST_SP_URL,
auth=self.get_plugin(),
)
self.assertEqual(self.calls, [self.TEST_SP_URL])
@ -184,25 +222,28 @@ class SamlAuth2PluginTests(utils.TestCase):
soap_response = saml2_fixtures.soap_response(consumer=consumer1)
saml_assertion = saml2_fixtures.saml_assertion(destination=consumer2)
self.requests_mock.get(self.TEST_SP_URL,
headers=CONTENT_TYPE_PAOS_HEADER,
content=soap_response)
self.requests_mock.get(
self.TEST_SP_URL,
headers=CONTENT_TYPE_PAOS_HEADER,
content=soap_response,
)
self.requests_mock.post(self.TEST_IDP_URL, content=saml_assertion)
# receive the SAML error, body unchecked
saml_error = self.requests_mock.post(consumer1)
self.assertRaises(saml2.v3.saml2.ConsumerMismatch,
requests.get,
self.TEST_SP_URL,
auth=self.get_plugin())
self.assertRaises(
saml2.v3.saml2.ConsumerMismatch,
requests.get,
self.TEST_SP_URL,
auth=self.get_plugin(),
)
self.assertTrue(saml_error.called)
class AuthenticateviaSAML2Tests(utils.TestCase):
TEST_USER = 'user'
TEST_PASS = 'pass'
TEST_IDP = 'tester'
@ -226,8 +267,10 @@ class AuthenticateviaSAML2Tests(utils.TestCase):
kwargs.setdefault('identity_provider', self.TEST_IDP)
kwargs.setdefault('protocol', self.TEST_PROTOCOL)
templ = ('%(base)s/OS-FEDERATION/identity_providers/'
'%(identity_provider)s/protocols/%(protocol)s/auth')
templ = (
'%(base)s/OS-FEDERATION/identity_providers/'
'%(identity_provider)s/protocols/%(protocol)s/auth'
)
return templ % kwargs
@property
@ -235,11 +278,11 @@ class AuthenticateviaSAML2Tests(utils.TestCase):
return [r.url.strip('/') for r in self.requests_mock.request_history]
def basic_header(self, username=TEST_USER, password=TEST_PASS):
user_pass = ('%s:%s' % (username, password)).encode('utf-8')
return 'Basic %s' % base64.b64encode(user_pass).decode('utf-8')
user_pass = (f'{username}:{password}').encode()
return 'Basic {}'.format(base64.b64encode(user_pass).decode('utf-8'))
def setUp(self):
super(AuthenticateviaSAML2Tests, self).setUp()
super().setUp()
self.session = session.Session()
self.default_sp_url = self.sp_url()
@ -247,35 +290,51 @@ class AuthenticateviaSAML2Tests(utils.TestCase):
token_id = uuid.uuid4().hex
token = ksa_fixtures.V3Token()
self.requests_mock.get(self.default_sp_url, response_list=[
dict(headers=CONTENT_TYPE_PAOS_HEADER,
content=utils.make_oneline(saml2_fixtures.SP_SOAP_RESPONSE)),
dict(headers={'X-Subject-Token': token_id}, json=token)
])
self.requests_mock.get(
self.default_sp_url,
response_list=[
{
'headers': CONTENT_TYPE_PAOS_HEADER,
'content': utils.make_oneline(
saml2_fixtures.SP_SOAP_RESPONSE
),
},
{'headers': {'X-Subject-Token': token_id}, 'json': token},
],
)
authm = self.requests_mock.post(self.TEST_IDP_URL,
content=saml2_fixtures.SAML2_ASSERTION)
authm = self.requests_mock.post(
self.TEST_IDP_URL, content=saml2_fixtures.SAML2_ASSERTION
)
self.requests_mock.post(
self.TEST_CONSUMER_URL,
status_code=302,
headers={'Location': self.sp_url()})
headers={'Location': self.sp_url()},
)
auth_ref = self.get_plugin().get_auth_ref(self.session)
self.assertEqual(token_id, auth_ref.auth_token)
self.assertEqual(self.calls, [self.default_sp_url,
self.TEST_IDP_URL,
self.TEST_CONSUMER_URL,
self.default_sp_url])
self.assertEqual(
self.calls,
[
self.default_sp_url,
self.TEST_IDP_URL,
self.TEST_CONSUMER_URL,
self.default_sp_url,
],
)
self.assertEqual(self.basic_header(),
authm.last_request.headers['Authorization'])
self.assertEqual(
self.basic_header(), authm.last_request.headers['Authorization']
)
authn_request = self.requests_mock.request_history[1].text
self.assertThat(saml2_fixtures.AUTHN_REQUEST,
matchers.XMLEquals(authn_request))
self.assertThat(
saml2_fixtures.AUTHN_REQUEST, matchers.XMLEquals(authn_request)
)
def test_consumer_mismatch_error_workflow(self):
consumer1 = 'http://keystone.test/Shibboleth.sso/SAML2/ECP'
@ -284,29 +343,37 @@ class AuthenticateviaSAML2Tests(utils.TestCase):
soap_response = saml2_fixtures.soap_response(consumer=consumer1)
saml_assertion = saml2_fixtures.saml_assertion(destination=consumer2)
self.requests_mock.get(self.default_sp_url,
headers=CONTENT_TYPE_PAOS_HEADER,
content=soap_response)
self.requests_mock.get(
self.default_sp_url,
headers=CONTENT_TYPE_PAOS_HEADER,
content=soap_response,
)
self.requests_mock.post(self.TEST_IDP_URL, content=saml_assertion)
# receive the SAML error, body unchecked
saml_error = self.requests_mock.post(consumer1)
self.assertRaises(exceptions.AuthorizationFailure,
self.get_plugin().get_auth_ref,
self.session)
self.assertRaises(
exceptions.AuthorizationFailure,
self.get_plugin().get_auth_ref,
self.session,
)
self.assertTrue(saml_error.called)
def test_initial_sp_call_invalid_response(self):
"""Send initial SP HTTP request and receive wrong server response."""
self.requests_mock.get(self.default_sp_url,
headers=CONTENT_TYPE_PAOS_HEADER,
text='NON XML RESPONSE')
self.requests_mock.get(
self.default_sp_url,
headers=CONTENT_TYPE_PAOS_HEADER,
text='NON XML RESPONSE',
)
self.assertRaises(exceptions.AuthorizationFailure,
self.get_plugin().get_auth_ref,
self.session)
self.assertRaises(
exceptions.AuthorizationFailure,
self.get_plugin().get_auth_ref,
self.session,
)
self.assertEqual(self.calls, [self.default_sp_url])

View File

@ -31,9 +31,8 @@ def _load_xml(filename):
class TestCase(utils.TestCase):
TEST_URL = 'https://keystone:5000/v3'
def setUp(self):
super(TestCase, self).setUp()
super().setUp()
self.session = session.Session()

View File

@ -21,9 +21,8 @@ from keystoneauth1.tests.unit import utils
class AccessInfoPluginTests(utils.TestCase):
def setUp(self):
super(AccessInfoPluginTests, self).setUp()
super().setUp()
self.session = session.Session()
self.auth_token = uuid.uuid4().hex
@ -37,19 +36,22 @@ class AccessInfoPluginTests(utils.TestCase):
def test_auth_ref(self):
plugin_obj = self._plugin()
self.assertEqual(self.TEST_ROOT_URL,
plugin_obj.get_endpoint(self.session,
service_type='identity',
interface='public'))
self.assertEqual(
self.TEST_ROOT_URL,
plugin_obj.get_endpoint(
self.session, service_type='identity', interface='public'
),
)
self.assertEqual(self.auth_token, plugin_obj.get_token(session))
def test_auth_url(self):
auth_url = 'http://keystone.test.url'
obj = self._plugin(auth_url=auth_url)
self.assertEqual(auth_url,
obj.get_endpoint(self.session,
interface=plugin.AUTH_INTERFACE))
self.assertEqual(
auth_url,
obj.get_endpoint(self.session, interface=plugin.AUTH_INTERFACE),
)
def test_invalidate(self):
plugin = self._plugin()

File diff suppressed because it is too large Load Diff

View File

@ -25,78 +25,89 @@ from keystoneauth1.tests.unit import utils
class V2IdentityPlugin(utils.TestCase):
TEST_ROOT_URL = 'http://127.0.0.1:5000/'
TEST_URL = '%s%s' % (TEST_ROOT_URL, 'v2.0')
TEST_URL = '{}{}'.format(TEST_ROOT_URL, 'v2.0')
TEST_ROOT_ADMIN_URL = 'http://127.0.0.1:35357/'
TEST_ADMIN_URL = '%s%s' % (TEST_ROOT_ADMIN_URL, 'v2.0')
TEST_ADMIN_URL = '{}{}'.format(TEST_ROOT_ADMIN_URL, 'v2.0')
TEST_PASS = 'password'
TEST_SERVICE_CATALOG = [{
"endpoints": [{
"adminURL": "http://cdn.admin-nets.local:8774/v1.0",
"region": "RegionOne",
"internalURL": "http://127.0.0.1:8774/v1.0",
"publicURL": "http://cdn.admin-nets.local:8774/v1.0/"
}],
"type": "nova_compat",
"name": "nova_compat"
}, {
"endpoints": [{
"adminURL": "http://nova/novapi/admin",
"region": "RegionOne",
"internalURL": "http://nova/novapi/internal",
"publicURL": "http://nova/novapi/public"
}],
"type": "compute",
"name": "nova"
}, {
"endpoints": [{
"adminURL": "http://glance/glanceapi/admin",
"region": "RegionOne",
"internalURL": "http://glance/glanceapi/internal",
"publicURL": "http://glance/glanceapi/public"
}],
"type": "image",
"name": "glance"
}, {
"endpoints": [{
"adminURL": TEST_ADMIN_URL,
"region": "RegionOne",
"internalURL": "http://127.0.0.1:5000/v2.0",
"publicURL": "http://127.0.0.1:5000/v2.0"
}],
"type": "identity",
"name": "keystone"
}, {
"endpoints": [{
"adminURL": "http://swift/swiftapi/admin",
"region": "RegionOne",
"internalURL": "http://swift/swiftapi/internal",
"publicURL": "http://swift/swiftapi/public"
}],
"type": "object-store",
"name": "swift"
}]
TEST_SERVICE_CATALOG = [
{
"endpoints": [
{
"adminURL": "http://cdn.admin-nets.local:8774/v1.0",
"region": "RegionOne",
"internalURL": "http://127.0.0.1:8774/v1.0",
"publicURL": "http://cdn.admin-nets.local:8774/v1.0/",
}
],
"type": "nova_compat",
"name": "nova_compat",
},
{
"endpoints": [
{
"adminURL": "http://nova/novapi/admin",
"region": "RegionOne",
"internalURL": "http://nova/novapi/internal",
"publicURL": "http://nova/novapi/public",
}
],
"type": "compute",
"name": "nova",
},
{
"endpoints": [
{
"adminURL": "http://glance/glanceapi/admin",
"region": "RegionOne",
"internalURL": "http://glance/glanceapi/internal",
"publicURL": "http://glance/glanceapi/public",
}
],
"type": "image",
"name": "glance",
},
{
"endpoints": [
{
"adminURL": TEST_ADMIN_URL,
"region": "RegionOne",
"internalURL": "http://127.0.0.1:5000/v2.0",
"publicURL": "http://127.0.0.1:5000/v2.0",
}
],
"type": "identity",
"name": "keystone",
},
{
"endpoints": [
{
"adminURL": "http://swift/swiftapi/admin",
"region": "RegionOne",
"internalURL": "http://swift/swiftapi/internal",
"publicURL": "http://swift/swiftapi/public",
}
],
"type": "object-store",
"name": "swift",
},
]
def setUp(self):
super(V2IdentityPlugin, self).setUp()
super().setUp()
self.TEST_RESPONSE_DICT = {
"access": {
"token": {
"expires": "%i-02-01T00:00:10.000123Z" %
(1 + time.gmtime().tm_year),
"expires": "%i-02-01T00:00:10.000123Z"
% (1 + time.gmtime().tm_year),
"id": self.TEST_TOKEN,
"tenant": {
"id": self.TEST_TENANT_ID
},
},
"user": {
"id": self.TEST_USER
"tenant": {"id": self.TEST_TENANT_ID},
},
"user": {"id": self.TEST_USER},
"serviceCatalog": self.TEST_SERVICE_CATALOG,
},
}
}
def stub_auth(self, **kwargs):
@ -104,16 +115,24 @@ class V2IdentityPlugin(utils.TestCase):
def test_authenticate_with_username_password(self):
self.stub_auth(json=self.TEST_RESPONSE_DICT)
a = v2.Password(self.TEST_URL, username=self.TEST_USER,
password=self.TEST_PASS)
a = v2.Password(
self.TEST_URL, username=self.TEST_USER, password=self.TEST_PASS
)
self.assertIsNone(a.user_id)
self.assertFalse(a.has_scope_parameters)
s = session.Session(a)
self.assertEqual({'X-Auth-Token': self.TEST_TOKEN},
s.get_auth_headers())
self.assertEqual(
{'X-Auth-Token': self.TEST_TOKEN}, s.get_auth_headers()
)
req = {'auth': {'passwordCredentials': {'username': self.TEST_USER,
'password': self.TEST_PASS}}}
req = {
'auth': {
'passwordCredentials': {
'username': self.TEST_USER,
'password': self.TEST_PASS,
}
}
}
self.assertRequestBodyIs(json=req)
self.assertRequestHeaderEqual('Content-Type', 'application/json')
self.assertRequestHeaderEqual('Accept', 'application/json')
@ -121,16 +140,24 @@ class V2IdentityPlugin(utils.TestCase):
def test_authenticate_with_user_id_password(self):
self.stub_auth(json=self.TEST_RESPONSE_DICT)
a = v2.Password(self.TEST_URL, user_id=self.TEST_USER,
password=self.TEST_PASS)
a = v2.Password(
self.TEST_URL, user_id=self.TEST_USER, password=self.TEST_PASS
)
self.assertIsNone(a.username)
self.assertFalse(a.has_scope_parameters)
s = session.Session(a)
self.assertEqual({'X-Auth-Token': self.TEST_TOKEN},
s.get_auth_headers())
self.assertEqual(
{'X-Auth-Token': self.TEST_TOKEN}, s.get_auth_headers()
)
req = {'auth': {'passwordCredentials': {'userId': self.TEST_USER,
'password': self.TEST_PASS}}}
req = {
'auth': {
'passwordCredentials': {
'userId': self.TEST_USER,
'password': self.TEST_PASS,
}
}
}
self.assertRequestBodyIs(json=req)
self.assertRequestHeaderEqual('Content-Type', 'application/json')
self.assertRequestHeaderEqual('Accept', 'application/json')
@ -138,33 +165,55 @@ class V2IdentityPlugin(utils.TestCase):
def test_authenticate_with_username_password_scoped(self):
self.stub_auth(json=self.TEST_RESPONSE_DICT)
a = v2.Password(self.TEST_URL, username=self.TEST_USER,
password=self.TEST_PASS, tenant_id=self.TEST_TENANT_ID)
a = v2.Password(
self.TEST_URL,
username=self.TEST_USER,
password=self.TEST_PASS,
tenant_id=self.TEST_TENANT_ID,
)
self.assertTrue(a.has_scope_parameters)
self.assertIsNone(a.user_id)
s = session.Session(a)
self.assertEqual({'X-Auth-Token': self.TEST_TOKEN},
s.get_auth_headers())
self.assertEqual(
{'X-Auth-Token': self.TEST_TOKEN}, s.get_auth_headers()
)
req = {'auth': {'passwordCredentials': {'username': self.TEST_USER,
'password': self.TEST_PASS},
'tenantId': self.TEST_TENANT_ID}}
req = {
'auth': {
'passwordCredentials': {
'username': self.TEST_USER,
'password': self.TEST_PASS,
},
'tenantId': self.TEST_TENANT_ID,
}
}
self.assertRequestBodyIs(json=req)
self.assertEqual(s.auth.auth_ref.auth_token, self.TEST_TOKEN)
def test_authenticate_with_user_id_password_scoped(self):
self.stub_auth(json=self.TEST_RESPONSE_DICT)
a = v2.Password(self.TEST_URL, user_id=self.TEST_USER,
password=self.TEST_PASS, tenant_id=self.TEST_TENANT_ID)
a = v2.Password(
self.TEST_URL,
user_id=self.TEST_USER,
password=self.TEST_PASS,
tenant_id=self.TEST_TENANT_ID,
)
self.assertIsNone(a.username)
self.assertTrue(a.has_scope_parameters)
s = session.Session(a)
self.assertEqual({'X-Auth-Token': self.TEST_TOKEN},
s.get_auth_headers())
self.assertEqual(
{'X-Auth-Token': self.TEST_TOKEN}, s.get_auth_headers()
)
req = {'auth': {'passwordCredentials': {'userId': self.TEST_USER,
'password': self.TEST_PASS},
'tenantId': self.TEST_TENANT_ID}}
req = {
'auth': {
'passwordCredentials': {
'userId': self.TEST_USER,
'password': self.TEST_PASS,
},
'tenantId': self.TEST_TENANT_ID,
}
}
self.assertRequestBodyIs(json=req)
self.assertEqual(s.auth.auth_ref.auth_token, self.TEST_TOKEN)
@ -172,8 +221,9 @@ class V2IdentityPlugin(utils.TestCase):
self.stub_auth(json=self.TEST_RESPONSE_DICT)
a = v2.Token(self.TEST_URL, 'foo')
s = session.Session(a)
self.assertEqual({'X-Auth-Token': self.TEST_TOKEN},
s.get_auth_headers())
self.assertEqual(
{'X-Auth-Token': self.TEST_TOKEN}, s.get_auth_headers()
)
req = {'auth': {'token': {'id': 'foo'}}}
self.assertRequestBodyIs(json=req)
@ -184,40 +234,55 @@ class V2IdentityPlugin(utils.TestCase):
def test_with_trust_id(self):
self.stub_auth(json=self.TEST_RESPONSE_DICT)
a = v2.Password(self.TEST_URL, username=self.TEST_USER,
password=self.TEST_PASS, trust_id='trust')
a = v2.Password(
self.TEST_URL,
username=self.TEST_USER,
password=self.TEST_PASS,
trust_id='trust',
)
self.assertTrue(a.has_scope_parameters)
s = session.Session(a)
self.assertEqual({'X-Auth-Token': self.TEST_TOKEN},
s.get_auth_headers())
self.assertEqual(
{'X-Auth-Token': self.TEST_TOKEN}, s.get_auth_headers()
)
req = {'auth': {'passwordCredentials': {'username': self.TEST_USER,
'password': self.TEST_PASS},
'trust_id': 'trust'}}
req = {
'auth': {
'passwordCredentials': {
'username': self.TEST_USER,
'password': self.TEST_PASS,
},
'trust_id': 'trust',
}
}
self.assertRequestBodyIs(json=req)
self.assertEqual(s.auth.auth_ref.auth_token, self.TEST_TOKEN)
def _do_service_url_test(self, base_url, endpoint_filter):
self.stub_auth(json=self.TEST_RESPONSE_DICT)
self.stub_url('GET', ['path'],
base_url=base_url,
text='SUCCESS', status_code=200)
self.stub_url(
'GET', ['path'], base_url=base_url, text='SUCCESS', status_code=200
)
a = v2.Password(self.TEST_URL, username=self.TEST_USER,
password=self.TEST_PASS)
a = v2.Password(
self.TEST_URL, username=self.TEST_USER, password=self.TEST_PASS
)
s = session.Session(auth=a)
resp = s.get('/path', endpoint_filter=endpoint_filter)
self.assertEqual(resp.status_code, 200)
self.assertEqual(self.requests_mock.last_request.url,
base_url + '/path')
self.assertEqual(
self.requests_mock.last_request.url, base_url + '/path'
)
def test_service_url(self):
endpoint_filter = {'service_type': 'compute',
'interface': 'admin',
'service_name': 'nova'}
endpoint_filter = {
'service_type': 'compute',
'interface': 'admin',
'service_name': 'nova',
}
self._do_service_url_test('http://nova/novapi/admin', endpoint_filter)
def test_service_url_defaults_to_public(self):
@ -227,47 +292,62 @@ class V2IdentityPlugin(utils.TestCase):
def test_endpoint_filter_without_service_type_fails(self):
self.stub_auth(json=self.TEST_RESPONSE_DICT)
a = v2.Password(self.TEST_URL, username=self.TEST_USER,
password=self.TEST_PASS)
a = v2.Password(
self.TEST_URL, username=self.TEST_USER, password=self.TEST_PASS
)
s = session.Session(auth=a)
self.assertRaises(exceptions.EndpointNotFound, s.get, '/path',
endpoint_filter={'interface': 'admin'})
self.assertRaises(
exceptions.EndpointNotFound,
s.get,
'/path',
endpoint_filter={'interface': 'admin'},
)
def test_full_url_overrides_endpoint_filter(self):
self.stub_auth(json=self.TEST_RESPONSE_DICT)
self.stub_url('GET', [],
base_url='http://testurl/',
text='SUCCESS', status_code=200)
self.stub_url(
'GET',
[],
base_url='http://testurl/',
text='SUCCESS',
status_code=200,
)
a = v2.Password(self.TEST_URL, username=self.TEST_USER,
password=self.TEST_PASS)
a = v2.Password(
self.TEST_URL, username=self.TEST_USER, password=self.TEST_PASS
)
s = session.Session(auth=a)
resp = s.get('http://testurl/',
endpoint_filter={'service_type': 'compute'})
resp = s.get(
'http://testurl/', endpoint_filter={'service_type': 'compute'}
)
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.text, 'SUCCESS')
def test_invalid_auth_response_dict(self):
self.stub_auth(json={'hello': 'world'})
a = v2.Password(self.TEST_URL, username=self.TEST_USER,
password=self.TEST_PASS)
a = v2.Password(
self.TEST_URL, username=self.TEST_USER, password=self.TEST_PASS
)
s = session.Session(auth=a)
self.assertRaises(exceptions.InvalidResponse, s.get, 'http://any',
authenticated=True)
self.assertRaises(
exceptions.InvalidResponse, s.get, 'http://any', authenticated=True
)
def test_invalid_auth_response_type(self):
self.stub_url('POST', ['tokens'], text='testdata')
a = v2.Password(self.TEST_URL, username=self.TEST_USER,
password=self.TEST_PASS)
a = v2.Password(
self.TEST_URL, username=self.TEST_USER, password=self.TEST_PASS
)
s = session.Session(auth=a)
self.assertRaises(exceptions.InvalidResponse, s.get, 'http://any',
authenticated=True)
self.assertRaises(
exceptions.InvalidResponse, s.get, 'http://any', authenticated=True
)
def test_invalidate_response(self):
resp_data1 = copy.deepcopy(self.TEST_RESPONSE_DICT)
@ -279,8 +359,9 @@ class V2IdentityPlugin(utils.TestCase):
auth_responses = [{'json': resp_data1}, {'json': resp_data2}]
self.stub_auth(response_list=auth_responses)
a = v2.Password(self.TEST_URL, username=self.TEST_USER,
password=self.TEST_PASS)
a = v2.Password(
self.TEST_URL, username=self.TEST_USER, password=self.TEST_PASS
)
s = session.Session(auth=a)
self.assertEqual('token1', s.get_token())
@ -294,41 +375,50 @@ class V2IdentityPlugin(utils.TestCase):
self.stub_auth(json=self.TEST_RESPONSE_DICT)
password = uuid.uuid4().hex
a = v2.Password(self.TEST_URL, username=self.TEST_USER,
password=password)
a = v2.Password(
self.TEST_URL, username=self.TEST_USER, password=password
)
s = session.Session(auth=a)
self.assertEqual(self.TEST_TOKEN, s.get_token())
self.assertEqual({'X-Auth-Token': self.TEST_TOKEN},
s.get_auth_headers())
self.assertEqual(
{'X-Auth-Token': self.TEST_TOKEN}, s.get_auth_headers()
)
self.assertNotIn(password, self.logger.output)
def test_password_with_no_user_id_or_name(self):
self.assertRaises(TypeError,
v2.Password, self.TEST_URL, password=self.TEST_PASS)
self.assertRaises(
TypeError, v2.Password, self.TEST_URL, password=self.TEST_PASS
)
def test_password_cache_id(self):
self.stub_auth(json=self.TEST_RESPONSE_DICT)
trust_id = uuid.uuid4().hex
a = v2.Password(self.TEST_URL,
username=self.TEST_USER,
password=self.TEST_PASS,
trust_id=trust_id)
a = v2.Password(
self.TEST_URL,
username=self.TEST_USER,
password=self.TEST_PASS,
trust_id=trust_id,
)
b = v2.Password(self.TEST_URL,
username=self.TEST_USER,
password=self.TEST_PASS,
trust_id=trust_id)
b = v2.Password(
self.TEST_URL,
username=self.TEST_USER,
password=self.TEST_PASS,
trust_id=trust_id,
)
a_id = a.get_cache_id()
b_id = b.get_cache_id()
self.assertEqual(a_id, b_id)
c = v2.Password(self.TEST_URL,
username=self.TEST_USER,
password=self.TEST_PASS,
tenant_id=trust_id) # same value different param
c = v2.Password(
self.TEST_URL,
username=self.TEST_USER,
password=self.TEST_PASS,
tenant_id=trust_id,
) # same value different param
c_id = c.get_cache_id()
@ -350,18 +440,21 @@ class V2IdentityPlugin(utils.TestCase):
auth_ref = access.create(body=token)
a = v2.Password(self.TEST_URL,
username=self.TEST_USER,
password=self.TEST_PASS,
tenant_id=uuid.uuid4().hex)
a = v2.Password(
self.TEST_URL,
username=self.TEST_USER,
password=self.TEST_PASS,
tenant_id=uuid.uuid4().hex,
)
initial_cache_id = a.get_cache_id()
state = a.get_auth_state()
self.assertIsNone(state)
state = json.dumps({'auth_token': auth_ref.auth_token,
'body': auth_ref._data})
state = json.dumps(
{'auth_token': auth_ref.auth_token, 'body': auth_ref._data}
)
a.set_auth_state(state)
self.assertEqual(token.token_id, a.auth_ref.auth_token)

File diff suppressed because it is too large Load Diff

View File

@ -24,7 +24,6 @@ from keystoneauth1.tests.unit import utils
class TesterFederationPlugin(v3.FederationBaseAuth):
def get_unscoped_auth_ref(self, sess, **kwargs):
# This would go and talk to an idp or something
resp = sess.post(self.federated_token_url, authenticated=False)
@ -32,11 +31,10 @@ class TesterFederationPlugin(v3.FederationBaseAuth):
class V3FederatedPlugin(utils.TestCase):
AUTH_URL = 'http://keystone/v3'
def setUp(self):
super(V3FederatedPlugin, self).setUp()
super().setUp()
self.unscoped_token = fixture.V3Token()
self.unscoped_token_id = uuid.uuid4().hex
@ -46,26 +44,30 @@ class V3FederatedPlugin(utils.TestCase):
self.scoped_token_id = uuid.uuid4().hex
s = self.scoped_token.add_service('compute', name='nova')
s.add_standard_endpoints(public='http://nova/public',
admin='http://nova/admin',
internal='http://nova/internal')
s.add_standard_endpoints(
public='http://nova/public',
admin='http://nova/admin',
internal='http://nova/internal',
)
self.idp = uuid.uuid4().hex
self.protocol = uuid.uuid4().hex
self.token_url = ('%s/OS-FEDERATION/identity_providers/%s/protocols/%s'
'/auth' % (self.AUTH_URL, self.idp, self.protocol))
self.token_url = (
f'{self.AUTH_URL}/OS-FEDERATION/identity_providers/{self.idp}/protocols/{self.protocol}'
'/auth'
)
headers = {'X-Subject-Token': self.unscoped_token_id}
self.unscoped_mock = self.requests_mock.post(self.token_url,
json=self.unscoped_token,
headers=headers)
self.unscoped_mock = self.requests_mock.post(
self.token_url, json=self.unscoped_token, headers=headers
)
headers = {'X-Subject-Token': self.scoped_token_id}
auth_url = self.AUTH_URL + '/auth/tokens'
self.scoped_mock = self.requests_mock.post(auth_url,
json=self.scoped_token,
headers=headers)
self.scoped_mock = self.requests_mock.post(
auth_url, json=self.scoped_token, headers=headers
)
def get_plugin(self, **kwargs):
kwargs.setdefault('auth_url', self.AUTH_URL)
@ -98,9 +100,8 @@ class V3FederatedPlugin(utils.TestCase):
class K2KAuthPluginTest(utils.TestCase):
TEST_ROOT_URL = 'http://127.0.0.1:5000/'
TEST_URL = '%s%s' % (TEST_ROOT_URL, 'v3')
TEST_URL = '{}{}'.format(TEST_ROOT_URL, 'v3')
TEST_PASS = 'password'
REQUEST_ECP_URL = TEST_URL + '/auth/OS-FEDERATION/saml2/ecp'
@ -108,39 +109,45 @@ class K2KAuthPluginTest(utils.TestCase):
SP_ROOT_URL = 'https://sp1.com/v3'
SP_ID = 'sp1'
SP_URL = 'https://sp1.com/Shibboleth.sso/SAML2/ECP'
SP_AUTH_URL = (SP_ROOT_URL +
'/OS-FEDERATION/identity_providers'
'/testidp/protocols/saml2/auth')
SP_AUTH_URL = (
SP_ROOT_URL + '/OS-FEDERATION/identity_providers'
'/testidp/protocols/saml2/auth'
)
SERVICE_PROVIDER_DICT = {
'id': SP_ID,
'auth_url': SP_AUTH_URL,
'sp_url': SP_URL
'sp_url': SP_URL,
}
def setUp(self):
super(K2KAuthPluginTest, self).setUp()
super().setUp()
self.token_v3 = fixture.V3Token()
self.token_v3.add_service_provider(
self.SP_ID, self.SP_AUTH_URL, self.SP_URL)
self.SP_ID, self.SP_AUTH_URL, self.SP_URL
)
self.session = session.Session()
self.k2kplugin = self.get_plugin()
def _get_base_plugin(self):
self.stub_url('POST', ['auth', 'tokens'],
headers={'X-Subject-Token': uuid.uuid4().hex},
json=self.token_v3)
return v3.Password(self.TEST_URL,
username=self.TEST_USER,
password=self.TEST_PASS)
self.stub_url(
'POST',
['auth', 'tokens'],
headers={'X-Subject-Token': uuid.uuid4().hex},
json=self.token_v3,
)
return v3.Password(
self.TEST_URL, username=self.TEST_USER, password=self.TEST_PASS
)
def _mock_k2k_flow_urls(self, redirect_code=302):
# List versions available for auth
self.requests_mock.get(
self.TEST_URL,
json={'version': fixture.V3Discovery(self.TEST_URL)},
headers={'Content-Type': 'application/json'})
headers={'Content-Type': 'application/json'},
)
# The IdP should return a ECP wrapped assertion when requested
self.requests_mock.register_uri(
@ -148,7 +155,8 @@ class K2KAuthPluginTest(utils.TestCase):
self.REQUEST_ECP_URL,
content=bytes(k2k_fixtures.ECP_ENVELOPE, 'latin-1'),
headers={'Content-Type': 'application/vnd.paos+xml'},
status_code=200)
status_code=200,
)
# The SP should respond with a redirect (302 or 303)
self.requests_mock.register_uri(
@ -156,14 +164,16 @@ class K2KAuthPluginTest(utils.TestCase):
self.SP_URL,
content=bytes(k2k_fixtures.TOKEN_BASED_ECP, 'latin-1'),
headers={'Content-Type': 'application/vnd.paos+xml'},
status_code=redirect_code)
status_code=redirect_code,
)
# Should not follow the redirect URL, but use the auth_url attribute
self.requests_mock.register_uri(
'GET',
self.SP_AUTH_URL,
json=k2k_fixtures.UNSCOPED_TOKEN,
headers={'X-Subject-Token': k2k_fixtures.UNSCOPED_TOKEN_HEADER})
headers={'X-Subject-Token': k2k_fixtures.UNSCOPED_TOKEN_HEADER},
)
def get_plugin(self, **kwargs):
kwargs.setdefault('base_plugin', self._get_base_plugin())
@ -178,84 +188,108 @@ class K2KAuthPluginTest(utils.TestCase):
self.requests_mock.get(
self.TEST_URL,
json={'version': fixture.V3Discovery(self.TEST_URL)},
headers={'Content-Type': 'application/json'})
headers={'Content-Type': 'application/json'},
)
self.requests_mock.register_uri(
'POST', self.REQUEST_ECP_URL,
status_code=401)
'POST', self.REQUEST_ECP_URL, status_code=401
)
self.assertRaises(exceptions.AuthorizationFailure,
self.k2kplugin._get_ecp_assertion,
self.session)
self.assertRaises(
exceptions.AuthorizationFailure,
self.k2kplugin._get_ecp_assertion,
self.session,
)
def test_get_ecp_assertion_empty_response(self):
self.requests_mock.get(
self.TEST_URL,
json={'version': fixture.V3Discovery(self.TEST_URL)},
headers={'Content-Type': 'application/json'})
headers={'Content-Type': 'application/json'},
)
self.requests_mock.register_uri(
'POST', self.REQUEST_ECP_URL,
'POST',
self.REQUEST_ECP_URL,
headers={'Content-Type': 'application/vnd.paos+xml'},
content=b'', status_code=200)
content=b'',
status_code=200,
)
self.assertRaises(exceptions.InvalidResponse,
self.k2kplugin._get_ecp_assertion,
self.session)
self.assertRaises(
exceptions.InvalidResponse,
self.k2kplugin._get_ecp_assertion,
self.session,
)
def test_get_ecp_assertion_wrong_headers(self):
self.requests_mock.get(
self.TEST_URL,
json={'version': fixture.V3Discovery(self.TEST_URL)},
headers={'Content-Type': 'application/json'})
headers={'Content-Type': 'application/json'},
)
self.requests_mock.register_uri(
'POST', self.REQUEST_ECP_URL,
'POST',
self.REQUEST_ECP_URL,
headers={'Content-Type': uuid.uuid4().hex},
content=b'', status_code=200)
content=b'',
status_code=200,
)
self.assertRaises(exceptions.InvalidResponse,
self.k2kplugin._get_ecp_assertion,
self.session)
self.assertRaises(
exceptions.InvalidResponse,
self.k2kplugin._get_ecp_assertion,
self.session,
)
def test_send_ecp_authn_response(self):
self._mock_k2k_flow_urls()
# Perform the request
response = self.k2kplugin._send_service_provider_ecp_authn_response(
self.session, self.SP_URL, self.SP_AUTH_URL)
self.session, self.SP_URL, self.SP_AUTH_URL
)
# Check the response
self.assertEqual(k2k_fixtures.UNSCOPED_TOKEN_HEADER,
response.headers['X-Subject-Token'])
self.assertEqual(
k2k_fixtures.UNSCOPED_TOKEN_HEADER,
response.headers['X-Subject-Token'],
)
def test_send_ecp_authn_response_303_redirect(self):
self._mock_k2k_flow_urls(redirect_code=303)
# Perform the request
response = self.k2kplugin._send_service_provider_ecp_authn_response(
self.session, self.SP_URL, self.SP_AUTH_URL)
self.session, self.SP_URL, self.SP_AUTH_URL
)
# Check the response
self.assertEqual(k2k_fixtures.UNSCOPED_TOKEN_HEADER,
response.headers['X-Subject-Token'])
self.assertEqual(
k2k_fixtures.UNSCOPED_TOKEN_HEADER,
response.headers['X-Subject-Token'],
)
def test_end_to_end_workflow(self):
self._mock_k2k_flow_urls()
auth_ref = self.k2kplugin.get_auth_ref(self.session)
self.assertEqual(k2k_fixtures.UNSCOPED_TOKEN_HEADER,
auth_ref.auth_token)
self.assertEqual(
k2k_fixtures.UNSCOPED_TOKEN_HEADER, auth_ref.auth_token
)
def test_end_to_end_workflow_303_redirect(self):
self._mock_k2k_flow_urls(redirect_code=303)
auth_ref = self.k2kplugin.get_auth_ref(self.session)
self.assertEqual(k2k_fixtures.UNSCOPED_TOKEN_HEADER,
auth_ref.auth_token)
self.assertEqual(
k2k_fixtures.UNSCOPED_TOKEN_HEADER, auth_ref.auth_token
)
def test_end_to_end_with_generic_password(self):
# List versions available for auth
self.requests_mock.get(
self.TEST_ROOT_URL,
json=fixture.DiscoveryList(self.TEST_ROOT_URL),
headers={'Content-Type': 'application/json'})
headers={'Content-Type': 'application/json'},
)
# The IdP should return a ECP wrapped assertion when requested
self.requests_mock.register_uri(
@ -263,7 +297,8 @@ class K2KAuthPluginTest(utils.TestCase):
self.REQUEST_ECP_URL,
content=bytes(k2k_fixtures.ECP_ENVELOPE, 'latin-1'),
headers={'Content-Type': 'application/vnd.paos+xml'},
status_code=200)
status_code=200,
)
# The SP should respond with a redirect (302 or 303)
self.requests_mock.register_uri(
@ -271,24 +306,33 @@ class K2KAuthPluginTest(utils.TestCase):
self.SP_URL,
content=bytes(k2k_fixtures.TOKEN_BASED_ECP, 'latin-1'),
headers={'Content-Type': 'application/vnd.paos+xml'},
status_code=302)
status_code=302,
)
# Should not follow the redirect URL, but use the auth_url attribute
self.requests_mock.register_uri(
'GET',
self.SP_AUTH_URL,
json=k2k_fixtures.UNSCOPED_TOKEN,
headers={'X-Subject-Token': k2k_fixtures.UNSCOPED_TOKEN_HEADER})
headers={'X-Subject-Token': k2k_fixtures.UNSCOPED_TOKEN_HEADER},
)
self.stub_url('POST', ['auth', 'tokens'],
headers={'X-Subject-Token': uuid.uuid4().hex},
json=self.token_v3)
self.stub_url(
'POST',
['auth', 'tokens'],
headers={'X-Subject-Token': uuid.uuid4().hex},
json=self.token_v3,
)
plugin = identity.Password(self.TEST_ROOT_URL,
username=self.TEST_USER,
password=self.TEST_PASS,
user_domain_id='default')
plugin = identity.Password(
self.TEST_ROOT_URL,
username=self.TEST_USER,
password=self.TEST_PASS,
user_domain_id='default',
)
k2kplugin = self.get_plugin(base_plugin=plugin)
self.assertEqual(k2k_fixtures.UNSCOPED_TOKEN_HEADER,
k2kplugin.get_token(self.session))
self.assertEqual(
k2k_fixtures.UNSCOPED_TOKEN_HEADER,
k2kplugin.get_token(self.session),
)

Some files were not shown because too many files have changed in this diff Show More