Added some extra functionality.

This commit is contained in:
Roland Hedberg
2014-09-11 20:04:37 +02:00
parent be48f27945
commit 7c2fe90b1b

View File

@@ -67,9 +67,12 @@ def destinations(srvs):
return [s["location"] for s in srvs]
def attribute_requirement(entity):
def attribute_requirement(entity, index=None):
res = {"required": [], "optional": []}
for acs in entity["attribute_consuming_service"]:
if index is not None and acs["index"] != index:
continue
for attr in acs["requested_attribute"]:
if "is_required" in attr and attr["is_required"] == "true":
res["required"].append(attr)
@@ -133,6 +136,9 @@ class MetaData(object):
def __getitem__(self, item):
return self.entity[item]
def __setitem__(self, key, value):
self.entity[key] = value
def do_entity_descriptor(self, entity_descr):
if self.check_validity:
try:
@@ -297,12 +303,14 @@ class MetaData(object):
return self.service(entity_id, typ, service)
def attribute_requirement(self, entity_id, index=0):
def attribute_requirement(self, entity_id, index=None):
""" Returns what attributes the SP requires and which are optional
if any such demands are registered in the Metadata.
:param entity_id: The entity id of the SP
:param index: which of the attribute consumer services its all about
if index=None then return all attributes expected by all
attribute_consuming_services.
:return: 2-tuple, list of required and list of optional attributes
"""
@@ -310,7 +318,7 @@ class MetaData(object):
try:
for sp in self[entity_id]["spsso_descriptor"]:
_res = attribute_requirement(sp)
_res = attribute_requirement(sp, index)
res["required"].extend(_res["required"])
res["optional"].extend(_res["optional"])
except KeyError:
@@ -513,6 +521,7 @@ class MetaDataMD(MetaData):
class MetadataStore(object):
def __init__(self, onts, attrc, config, ca_certs=None,
check_validity=True,
disable_ssl_certificate_validation=False):
"""
:params onts:
@@ -523,11 +532,16 @@ class MetadataStore(object):
"""
self.onts = onts
self.attrc = attrc
self.http = HTTPBase(verify=disable_ssl_certificate_validation,
ca_bundle=ca_certs)
if disable_ssl_certificate_validation:
self.http = HTTPBase(verify=False, ca_bundle=ca_certs)
else:
self.http = HTTPBase(verify=True, ca_bundle=ca_certs)
self.security = security_context(config)
self.ii = 0
self.metadata = {}
self.check_validity = check_validity
def load(self, typ, *args, **kwargs):
if typ == "local":
@@ -539,10 +553,16 @@ class MetadataStore(object):
_md = MetaData(self.onts, self.attrc, args[0], **kwargs)
elif typ == "remote":
key = kwargs["url"]
_args = {}
for _key in ["node_name", "check_validity"]:
try:
_args[_key] = kwargs[_key]
except KeyError:
pass
_md = MetaDataExtern(self.onts, self.attrc,
kwargs["url"], self.security,
kwargs["cert"], self.http,
node_name=kwargs.get('node_name'))
kwargs["cert"], self.http, **_args)
elif typ == "mdfile":
key = args[0]
_md = MetaDataMD(self.onts, self.attrc, args[0])
@@ -559,6 +579,8 @@ class MetadataStore(object):
for key, vals in spec.items():
for val in vals:
if isinstance(val, dict):
if not self.check_validity:
val["check_validity"] = False
self.load(key, **val)
else:
self.load(key, val)
@@ -863,6 +885,10 @@ class MetadataStore(object):
for _md in self.metadata.values():
for ent_id, ent_desc in _md.items():
if descriptor in ent_desc:
if ent_id in res:
#print "duplicated entity_id: %s" % res
pass
else:
res.append(ent_id)
return res