"""In-memory store for SAML2 metadata loaded from files, callables or URLs."""
import logging
import sys
import json

from hashlib import sha1
from importlib import import_module

from saml2.httpbase import HTTPBase
from saml2.extension.idpdisc import BINDING_DISCO
from saml2.extension.idpdisc import DiscoveryResponse

from saml2.mdie import to_dict

from saml2 import md
from saml2 import samlp
from saml2 import SAMLError
from saml2 import BINDING_HTTP_REDIRECT
from saml2 import BINDING_HTTP_POST
from saml2 import BINDING_SOAP
from saml2.s_utils import UnsupportedBinding, UnknownPrincipal
from saml2.sigver import split_len
from saml2.validate import valid_instance
from saml2.time_util import valid
from saml2.validate import NotValid
from saml2.sigver import security_context

__author__ = 'rolandh'

logger = logging.getLogger(__name__)


class ToOld(Exception):
    """Raised by MetaData.parse when an EntitiesDescriptor has expired."""
    pass


# Maps a request type to the name of the service (endpoint list) handling it.
REQ2SRV = {
    # IDP
    "authn_request": "single_sign_on_service",
    "name_id_mapping_request": "name_id_mapping_service",
    # AuthnAuthority
    "authn_query": "authn_query_service",
    # AttributeAuthority
    "attribute_query": "attribute_service",
    # PDP
    "authz_decision_query": "authz_service",
    # AuthnAuthority + IDP + PDP + AttributeAuthority
    "assertion_id_request": "assertion_id_request_service",
    # IDP + SP
    "logout_request": "single_logout_service",
    "manage_name_id_request": "manage_name_id_service",
    "artifact_query": "artifact_resolution_service",
    # SP
    "assertion_response": "assertion_consumer_service",
    "attribute_response": "attribute_consuming_service",
    "discovery_service_request": "discovery_response"
}

# Value of the "__class__" key that marks an EntityAttributes extension.
ENTITYATTRIBUTES = "urn:oasis:names:tc:SAML:metadata:attribute&EntityAttributes"

# ---------------------------------------------------


def destinations(srvs):
    """Return the "location" value of every service description in *srvs*."""
    return [s["location"] for s in srvs]


def attribute_requirement(entity):
    """
    Split the requested attributes of one descriptor into required/optional.

    :param entity: dictionary with an "attribute_consuming_service" list
    :return: dictionary with the keys "required" and "optional", each a list
        of requested-attribute dictionaries
    :raise KeyError: if "attribute_consuming_service" or
        "requested_attribute" is missing
    """
    res = {"required": [], "optional": []}
    for acs in entity["attribute_consuming_service"]:
        for attr in acs["requested_attribute"]:
            # "is_required" is compared against the string "true"
            if attr.get("is_required") == "true":
                res["required"].append(attr)
            else:
                res["optional"].append(attr)
    return res


def name(ent, langpref="en"):
    """
    Find a human readable organization name for an entity.

    The display name, the name and the URL of the organization are tried
    in that order; the first item whose "lang" equals *langpref* wins.

    :param ent: entity dictionary
    :param langpref: wanted language code
    :return: the matching text or None if there is none
    """
    try:
        org = ent["organization"]
    except KeyError:
        return None

    for info in ["organization_display_name",
                 "organization_name",
                 "organization_url"]:
        try:
            for item in org[info]:
                if item["lang"] == langpref:
                    return item["text"]
        except KeyError:
            pass
    return None


def repack_cert(cert):
    """
    Normalize the line layout of a certificate string.

    A single-line certificate is stripped and split with
    split_len(..., 64); a multi-line one only has each line stripped.

    :param cert: the certificate text
    :return: the certificate as newline separated lines
    """
    part = cert.split("\n")
    if len(part) == 1:
        return "\n".join(split_len(part[0].strip(), 64))
    return "\n".join([s.strip() for s in part])


class MetaData(object):
    """Metadata kept as a dictionary of entity dictionaries keyed by
    entity id. This base class parses metadata given as a string."""

    def __init__(self, onts, attrc, metadata="", node_name=None, **kwargs):
        """
        :param onts: passed on to to_dict() when converting descriptors
        :param attrc: stored as-is, not used by this class
        :param metadata: the metadata text that load() will parse
        :param node_name: name of the signed node, used by subclasses when
            verifying signatures
        :param kwargs: accepted and ignored
        """
        self.onts = onts
        self.attrc = attrc
        self.entity = {}
        self.metadata = metadata
        self.security = None
        self.node_name = node_name
        # Set by parse()
        self.entities_descr = None
        self.entity_descr = None

    def items(self):
        return self.entity.items()

    def keys(self):
        return self.entity.keys()

    def values(self):
        return self.entity.values()

    def __contains__(self, item):
        return item in self.entity

    def __getitem__(self, item):
        return self.entity[item]

    def do_entity_descriptor(self, entity_descr):
        """
        Convert one entity descriptor to a dictionary and store it.

        Expired and already seen entities are skipped. Descriptor types
        where no item supports the SAML2 protocol are removed from the
        entity, and the entity is only stored if at least one descriptor
        type remains.

        :param entity_descr: the entity descriptor instance
        """
        try:
            if not valid(entity_descr.valid_until):
                logger.info("Entity descriptor (entity id:%s) to old",
                            entity_descr.entity_id)
                return
        except AttributeError:
            pass

        # Have I seen this entity_id before? If so, report it on stderr
        # and ignore the duplicate.
        if entity_descr.entity_id in self.entity:
            sys.stderr.write(
                "Duplicated Entity descriptor (entity id: '%s')\n" %
                entity_descr.entity_id)
            return

        _ent = to_dict(entity_descr, self.onts)
        flag = 0
        # verify support for SAML2
        for descr in ["spsso", "idpsso", "role", "authn_authority",
                      "attribute_authority", "pdp", "affiliation"]:
            _key = "%s_descriptor" % descr
            try:
                _items = _ent[_key]
            except KeyError:
                continue

            if descr == "affiliation":  # Not protocol specific
                flag += 1
                continue

            _res = []
            for item in _items:
                for prot in item["protocol_support_enumeration"].split(" "):
                    if prot == samlp.NAMESPACE:
                        # Keep only the SAML2 namespace in the enumeration
                        item["protocol_support_enumeration"] = prot
                        _res.append(item)
                        break
            if not _res:
                del _ent[_key]
            else:
                flag += 1

        if flag:
            self.entity[entity_descr.entity_id] = _ent

    def parse(self, xmlstr):
        """
        Parse metadata text and store the entities found in it.

        The text is first read as an EntitiesDescriptor; if that gives
        nothing it is read as a single EntityDescriptor.

        :param xmlstr: the metadata as a string
        :raise ToOld: if the EntitiesDescriptor is no longer valid
        """
        self.entities_descr = md.entities_descriptor_from_string(xmlstr)

        if not self.entities_descr:
            self.entity_descr = md.entity_descriptor_from_string(xmlstr)
            if self.entity_descr:
                self.do_entity_descriptor(self.entity_descr)
            return

        try:
            valid_instance(self.entities_descr)
        except NotValid as exc:
            logger.error(exc.args[0])
            return

        try:
            if not valid(self.entities_descr.valid_until):
                raise ToOld("Metadata not valid anymore, it's after %s" % (
                    self.entities_descr.valid_until,))
        except AttributeError:
            pass

        for entity_descr in self.entities_descr.entity_descriptor:
            self.do_entity_descriptor(entity_descr)

    def load(self):
        """Parse the metadata text given to the constructor."""
        self.parse(self.metadata)

    def _service(self, entity_id, typ, service, binding=None):
        """ Get me all services with a specified
        entity ID and type, that supports the specified version of binding.

        :param entity_id: The EntityId
        :param typ: Type of service (idp, attribute_authority, ...)
        :param service: which service that is sought for
        :param binding: A binding identifier
        :return: None if the entity or the type is unknown, an empty list
            if the service is not present, else a list of service
            descriptions. If no binding was specified a dictionary with
            bindings as keys and lists of service descriptions as values.
        """

        logger.debug("_service(%s, %s, %s, %s)", entity_id, typ, service,
                     binding)
        try:
            srvs = []
            for t in self[entity_id][typ]:
                try:
                    srvs.extend(t[service])
                except KeyError:
                    pass
        except KeyError:
            return None

        if not srvs:
            return srvs

        if binding:
            res = [srv for srv in srvs if srv["binding"] == binding]
        else:
            res = {}
            for srv in srvs:
                res.setdefault(srv["binding"], []).append(srv)
        logger.debug("_service => %s", res)
        return res

    def _ext_service(self, entity_id, typ, service, binding):
        """
        Find services that are described as extension elements.

        :param entity_id: The EntityId
        :param typ: descriptor key, e.g. "spsso_descriptor"
        :param service: the "__class__" value of the wanted extension
        :param binding: A binding identifier
        :return: None if the entity or the type is unknown, else a list of
            matching extension elements
        """
        try:
            srvs = self[entity_id][typ]
        except KeyError:
            return None

        if not srvs:
            return srvs

        res = []
        for srv in srvs:
            if "extensions" not in srv:
                continue
            for elem in srv["extensions"]["extension_elements"]:
                if elem["__class__"] == service and \
                        elem["binding"] == binding:
                    res.append(elem)

        return res

    def any(self, typ, service, binding=None):
        """
        Return any entity that matches the specification

        :param typ: descriptor key
        :param service: service name
        :param binding: A binding identifier
        :return: dictionary with entity ids as keys and the non-empty
            results of _service() as values
        """
        res = {}
        for ent in self.keys():
            bind = self._service(ent, typ, service, binding)
            if bind:
                res[ent] = bind

        return res

    def bindings(self, entity_id, typ, service):
        """
        Get me all the bindings that are registered for a
        service entity

        :param entity_id: The EntityId
        :param typ: descriptor key
        :param service: service name
        :return: same as _service() called without a binding
        """

        return self._service(entity_id, typ, service)

    def attribute_requirement(self, entity_id, index=0):
        """ Returns what attributes the SP requires and which are optional
        if any such demands are registered in the Metadata.

        :param entity_id: The entity id of the SP
        :param index: which of the attribute consumer services its all about
            (currently not used)
        :return: dictionary with the keys "required" and "optional", or
            None if a needed key is missing in the metadata
        """

        res = {"required": [], "optional": []}

        try:
            for sp in self[entity_id]["spsso_descriptor"]:
                _res = attribute_requirement(sp)
                res["required"].extend(_res["required"])
                res["optional"].extend(_res["optional"])
        except KeyError:
            return None

        return res

    def dumps(self):
        """Return the entities as a JSON list of [entity_id, entity]."""
        # list() since a dictionary view can not be serialized
        return json.dumps(list(self.items()), indent=2)

    def with_descriptor(self, descriptor):
        """
        :param descriptor: descriptor type without the "_descriptor" suffix
        :return: dictionary with the entities that have such a descriptor
        """
        desc = "%s_descriptor" % descriptor
        return dict((eid, ent) for eid, ent in self.items() if desc in ent)

    def __str__(self):
        return "%s" % self.items()

    def construct_source_id(self):
        """
        :return: dictionary mapping the SHA-1 digest of the entity id to
            the entity, for every SP/IdP entity with an
            artifact_resolution_service
        """
        res = {}
        for eid, ent in self.items():
            for desc in ["spsso_descriptor", "idpsso_descriptor"]:
                try:
                    for srv in ent[desc]:
                        if "artifact_resolution_service" in srv:
                            # the hash function wants bytes, not text
                            if isinstance(eid, bytes):
                                _eid = eid
                            else:
                                _eid = eid.encode("utf-8")
                            res[sha1(_eid).digest()] = ent
                except KeyError:
                    pass

        return res

    def entity_categories(self, entity_id):
        """
        :param entity_id: The EntityId
        :return: list with the "text" of every attribute in the entity's
            EntityAttributes extensions
        """
        res = []
        if "extensions" in self[entity_id]:
            for elem in self[entity_id]["extensions"]["extension_elements"]:
                if elem["__class__"] == ENTITYATTRIBUTES:
                    for attr in elem["attribute"]:
                        res.append(attr["text"])

        return res

    def __eq__(self, other):
        # Explicit tests instead of assert, which is removed by python -O
        if not isinstance(other, MetaData):
            return False

        if len(self.entity) != len(other.entity):
            return False

        if set(self.entity.keys()) != set(other.entity.keys()):
            return False

        for key, item in self.entity.items():
            if not item == other[key]:
                return False

        return True


class MetaDataFile(MetaData):
    """
    Handles Metadata file on the same machine. The format of the file is
    the SAML Metadata format.
    """
    def __init__(self, onts, attrc, filename, cert=None, **kwargs):
        """
        :param filename: path of the metadata file
        :param cert: certificate file used to verify the signature, if any
        """
        MetaData.__init__(self, onts, attrc, **kwargs)
        self.filename = filename
        self.cert = cert

    def get_metadata_content(self):
        """Return the content of the metadata file."""
        with open(self.filename) as fp:
            return fp.read()

    def load(self):
        """
        Read the metadata and parse it; if a certificate is set the
        signature must verify first.

        :return: True if the metadata was parsed, otherwise False
        """
        _txt = self.get_metadata_content()
        if not self.cert:
            self.parse(_txt)
            return True

        node_name = self.node_name \
            or "%s:%s" % (md.EntitiesDescriptor.c_namespace,
                          md.EntitiesDescriptor.c_tag)

        # NOTE(review): self.security is None after __init__; it has to be
        # assigned before load() is called with a certificate set.
        if self.security.verify_signature(_txt,
                                          node_name=node_name,
                                          cert_file=self.cert):
            self.parse(_txt)
            return True
        return False


class MetaDataLoader(MetaDataFile):
    """
    Handles Metadata file loaded by a passed in function.
    The format of the file is the SAML Metadata format.
    """
    def __init__(self, onts, attrc, loader_callable, cert=None, **kwargs):
        """
        :param loader_callable: a callable returning the metadata text, or
            the dotted path ("package.module.function") of such a callable
        :param cert: certificate file used to verify the signature, if any
        """
        # MetaDataFile.__init__ is skipped on purpose: there is no filename
        MetaData.__init__(self, onts, attrc, **kwargs)
        self.metadata_provider_callable = self.get_metadata_loader(
            loader_callable)
        self.cert = cert

    def get_metadata_loader(self, func):
        """
        Resolve *func* to a callable.

        :param func: a callable or the dotted path of one
        :return: the callable
        :raise RuntimeError: if the module can not be imported, the
            attribute is missing or it is not callable
        """
        if callable(func):
            return func

        module, _, attr = func.rpartition('.')
        try:
            mod = import_module(module)
        except Exception as e:
            raise RuntimeError(
                'Cannot find metadata provider function %s: "%s"' % (func, e))

        try:
            metadata_loader = getattr(mod, attr)
        except AttributeError:
            raise RuntimeError(
                'Module "%s" does not define a "%s" metadata loader' % (
                    module, attr))

        if not callable(metadata_loader):
            raise RuntimeError(
                'Metadata loader %s.%s must be callable' % (module, attr))

        return metadata_loader

    def get_metadata_content(self):
        """Return whatever the loader callable returns."""
        return self.metadata_provider_callable()


class MetaDataExtern(MetaData):
    """
    Class that handles metadata store somewhere on the net.
    Accessible by HTTP GET.
    """

    def __init__(self, onts, attrc, url, security, cert, http, **kwargs):
        """
        :param onts: passed on to MetaData
        :param attrc: passed on to MetaData
        :param url: where the metadata is fetched from
        :param security: SecurityContext()
        :param cert: certificate file used to verify the signature, if any
        :param http: object with a send(url) method
        """
        MetaData.__init__(self, onts, attrc, **kwargs)
        self.url = url
        self.security = security
        self.cert = cert
        self.http = http

    def load(self):
        """ Imports metadata by the use of HTTP GET.
        If a certificate is set the signature of the fetched document is
        verified before it is imported.

        :return: True if the metadata was parsed, otherwise False
        """
        response = self.http.send(self.url)
        if response.status_code != 200:
            logger.info("Response status: %s", response.status_code)
            return False

        node_name = self.node_name \
            or "%s:%s" % (md.EntitiesDescriptor.c_namespace,
                          md.EntitiesDescriptor.c_tag)

        _txt = response.text.encode("utf-8")
        if self.cert:
            if self.security.verify_signature(_txt,
                                              node_name=node_name,
                                              cert_file=self.cert):
                self.parse(_txt)
                return True
            return False

        self.parse(_txt)
        return True


class MetaDataMD(MetaData):
    """
    Handles locally stored metadata, the file format is the JSON list of
    (entity id, entity) pairs that MetaData.dumps() produces.
    """

    def __init__(self, onts, attrc, filename, **kwargs):
        """
        :param filename: path of the JSON file
        """
        MetaData.__init__(self, onts, attrc, **kwargs)
        self.filename = filename

    def load(self):
        """Read the file and store every entity found in it."""
        with open(self.filename) as fp:
            pairs = json.loads(fp.read())
        for key, item in pairs:
            self.entity[key] = item


class MetadataStore(object):
    """A collection of MetaData instances that can be queried as one."""

    def __init__(self, onts, attrc, config, ca_certs=None,
                 disable_ssl_certificate_validation=False):
        """
        :param onts: passed on to the MetaData instances
        :param attrc: passed on to the MetaData instances
        :param config: Config()
        :param ca_certs: passed to HTTPBase as ca_bundle
        :param disable_ssl_certificate_validation: if True the HTTP client
            is told not to verify server certificates
        """
        self.onts = onts
        self.attrc = attrc
        # "verify" is the opposite of "disable validation". The flag used to
        # be passed on unchanged.
        # NOTE(review): assumes HTTPBase(verify=True) means "do verify".
        self.http = HTTPBase(verify=not disable_ssl_certificate_validation,
                             ca_bundle=ca_certs)
        self.security = security_context(config)
        self.ii = 0  # counter used as key for inline metadata
        self.metadata = {}

    def load(self, typ, *args, **kwargs):
        """
        Create a metadata instance of the given type, load it and add it
        to the store.

        :param typ: "local", "inline", "remote", "mdfile" or "loader"
        :param args: for all types but "remote" the first argument is the
            file name, the metadata text or the loader
        :param kwargs: "remote" needs "url" and "cert" and accepts
            "node_name"; "inline" passes them on to MetaData
        :raise SAMLError: if the type is unknown
        """
        if typ == "local":
            key = args[0]
            _md = MetaDataFile(self.onts, self.attrc, args[0])
        elif typ == "inline":
            self.ii += 1
            key = self.ii
            _md = MetaData(self.onts, self.attrc, args[0], **kwargs)
        elif typ == "remote":
            key = kwargs["url"]
            _md = MetaDataExtern(self.onts, self.attrc,
                                 kwargs["url"], self.security,
                                 kwargs["cert"], self.http,
                                 node_name=kwargs.get('node_name'))
        elif typ == "mdfile":
            key = args[0]
            _md = MetaDataMD(self.onts, self.attrc, args[0])
        elif typ == "loader":
            key = args[0]
            _md = MetaDataLoader(self.onts, self.attrc, args[0])
        else:
            raise SAMLError("Unknown metadata type '%s'" % typ)

        _md.load()
        self.metadata[key] = _md

    def imp(self, spec):
        """
        Load all the metadata given in a specification.

        :param spec: dictionary with metadata types as keys and lists as
            values. A list item that is a dictionary is used as keyword
            arguments to load(), anything else as the single argument.
        """
        for key, vals in spec.items():
            for val in vals:
                if isinstance(val, dict):
                    self.load(key, **val)
                else:
                    self.load(key, val)

    def _service(self, entity_id, typ, service, binding=None):
        """
        Return the first non-empty answer any metadata instance has.

        :raise UnsupportedBinding: if the entity is known but nothing
            matched
        :raise UnknownPrincipal: if no instance knows the entity/type
        """
        known_principal = False
        for _md in self.metadata.values():
            srvs = _md._service(entity_id, typ, service, binding)
            if srvs:
                return srvs
            elif srvs is None:  # entity or type unknown to this instance
                pass
            else:
                known_principal = True

        if known_principal:
            logger.error("Unsupported binding: %s (%s)", binding, entity_id)
            raise UnsupportedBinding(binding)
        else:
            logger.error("Unknown principal: %s", entity_id)
            raise UnknownPrincipal(entity_id)

    def _ext_service(self, entity_id, typ, service, binding=None):
        """
        As _service() but for services described as extension elements.

        :raise UnsupportedBinding: if the entity is known but nothing
            matched
        :raise UnknownPrincipal: if no instance knows the entity/type
        """
        known_principal = False
        for _md in self.metadata.values():
            srvs = _md._ext_service(entity_id, typ, service, binding)
            if srvs:
                return srvs
            elif srvs is None:
                pass
            else:
                known_principal = True

        if known_principal:
            raise UnsupportedBinding(binding)
        else:
            raise UnknownPrincipal(entity_id)

    # In the service methods below "typ" (or "_") is only used where the
    # descriptor key is built from it; otherwise it is ignored.

    def single_sign_on_service(self, entity_id, binding=None, typ="idpsso"):
        # IDP

        if binding is None:
            binding = BINDING_HTTP_REDIRECT
        return self._service(entity_id, "idpsso_descriptor",
                             "single_sign_on_service", binding)

    def name_id_mapping_service(self, entity_id, binding=None, typ="idpsso"):
        # IDP
        if binding is None:
            binding = BINDING_HTTP_REDIRECT
        return self._service(entity_id, "idpsso_descriptor",
                             "name_id_mapping_service", binding)

    def authn_query_service(self, entity_id, binding=None,
                            typ="authn_authority"):
        # AuthnAuthority
        if binding is None:
            binding = BINDING_SOAP
        return self._service(entity_id, "authn_authority_descriptor",
                             "authn_query_service", binding)

    def attribute_service(self, entity_id, binding=None,
                          typ="attribute_authority"):
        # AttributeAuthority
        if binding is None:
            binding = BINDING_HTTP_REDIRECT
        return self._service(entity_id, "attribute_authority_descriptor",
                             "attribute_service", binding)

    def authz_service(self, entity_id, binding=None, typ="pdp"):
        # PDP
        if binding is None:
            binding = BINDING_SOAP
        return self._service(entity_id, "pdp_descriptor",
                             "authz_service", binding)

    def assertion_id_request_service(self, entity_id, binding=None, typ=None):
        # AuthnAuthority + IDP + PDP + AttributeAuthority
        if typ is None:
            raise AttributeError("Missing type specification")
        if binding is None:
            binding = BINDING_SOAP
        return self._service(entity_id, "%s_descriptor" % typ,
                             "assertion_id_request_service", binding)

    def single_logout_service(self, entity_id, binding=None, typ=None):
        # IDP + SP
        if typ is None:
            raise AttributeError("Missing type specification")
        if binding is None:
            binding = BINDING_HTTP_REDIRECT
        return self._service(entity_id, "%s_descriptor" % typ,
                             "single_logout_service", binding)

    def manage_name_id_service(self, entity_id, binding=None, typ=None):
        # IDP + SP
        if binding is None:
            binding = BINDING_HTTP_REDIRECT
        return self._service(entity_id, "%s_descriptor" % typ,
                             "manage_name_id_service", binding)

    def artifact_resolution_service(self, entity_id, binding=None, typ=None):
        # IDP + SP
        if binding is None:
            binding = BINDING_HTTP_REDIRECT
        return self._service(entity_id, "%s_descriptor" % typ,
                             "artifact_resolution_service", binding)

    def assertion_consumer_service(self, entity_id, binding=None, _="spsso"):
        # SP
        if binding is None:
            binding = BINDING_HTTP_POST
        return self._service(entity_id, "spsso_descriptor",
                             "assertion_consumer_service", binding)

    def attribute_consuming_service(self, entity_id, binding=None, _="spsso"):
        # SP
        if binding is None:
            binding = BINDING_HTTP_REDIRECT
        return self._service(entity_id, "spsso_descriptor",
                             "attribute_consuming_service", binding)

    def discovery_response(self, entity_id, binding=None, _="spsso"):
        if binding is None:
            binding = BINDING_DISCO
        return self._ext_service(entity_id, "spsso_descriptor",
                                 "%s&%s" % (DiscoveryResponse.c_namespace,
                                            DiscoveryResponse.c_tag),
                                 binding)

    def attribute_requirement(self, entity_id, index=0):
        """
        :return: the attribute requirement from the first metadata instance
            that contains the entity, None if no instance does
        """
        for _md in self.metadata.values():
            if entity_id in _md:
                return _md.attribute_requirement(entity_id, index)

    def keys(self):
        """Return the entity ids of all the metadata instances."""
        res = []
        for _md in self.metadata.values():
            res.extend(_md.keys())
        return res

    def __getitem__(self, item):
        for _md in self.metadata.values():
            try:
                return _md[item]
            except KeyError:
                pass

        raise KeyError(item)

    def __setitem__(self, key, value):
        self.metadata[key] = value

    def entities(self):
        """Return the number of entities summed over all instances."""
        return sum(len(_md.items()) for _md in self.metadata.values())

    def __len__(self):
        # Number of metadata instances, not number of entities
        return len(self.metadata)

    def with_descriptor(self, descriptor):
        """
        :param descriptor: descriptor type without the "_descriptor" suffix
        :return: dictionary with all the entities that have such a
            descriptor
        """
        res = {}
        for _md in self.metadata.values():
            res.update(_md.with_descriptor(descriptor))
        return res

    def name(self, entity_id, langpref="en"):
        """
        :return: the organization name of the entity in the wanted
            language, None if the entity or a name is not found
        """
        for _md in self.metadata.values():
            # Test against the entity ids; testing against items() compared
            # the id with (id, entity) pairs and never matched.
            if entity_id in _md:
                return name(_md[entity_id], langpref)
        return None

    @staticmethod
    def _cert_texts(srvs, use):
        """Yield the certificate text of every key in *srvs* that has the
        given use, or no use specified at all."""
        for srv in srvs:
            for key in srv["key_descriptor"]:
                if "use" not in key or key["use"] == use:
                    for dat in key["key_info"]["x509_data"]:
                        yield dat["x509_certificate"]["text"]

    def certs(self, entity_id, descriptor, use="signing"):
        """
        Collect the certificates of an entity.

        :param entity_id: The EntityId
        :param descriptor: descriptor type without the "_descriptor"
            suffix, or "any". With "any" the certificates are repacked
            with repack_cert() and duplicates are dropped; with a specific
            type the texts are returned untouched.
        :param use: wanted key use; keys without a use always match
        :return: list of certificate texts
        :raise KeyError: if the entity, the descriptor or an expected key
            is missing
        """
        ent = self.__getitem__(entity_id)
        if descriptor != "any":
            return list(
                self._cert_texts(ent["%s_descriptor" % descriptor], use))

        res = []
        for descr in ["spsso", "idpsso", "role", "authn_authority",
                      "attribute_authority", "pdp"]:
            try:
                srvs = ent["%s_descriptor" % descr]
            except KeyError:
                continue

            for text in self._cert_texts(srvs, use):
                cert = repack_cert(text)
                if cert not in res:
                    res.append(cert)
        return res

    def vo_members(self, entity_id):
        """Return the "text" of every affiliate member of the entity."""
        ad = self.__getitem__(entity_id)["affiliation_descriptor"]
        return [m["text"] for m in ad["affiliate_member"]]

    def entity_categories(self, entity_id):
        """
        :return: list with the values of the
            "http://macedir.org/entity-category" attributes in the entity's
            EntityAttributes extensions
        :raise KeyError: if the entity is unknown or has no extensions
        """
        ext = self.__getitem__(entity_id)["extensions"]
        res = []
        for elem in ext["extension_elements"]:
            if elem["__class__"] == ENTITYATTRIBUTES:
                for attr in elem["attribute"]:
                    if attr["name"] == "http://macedir.org/entity-category":
                        res.extend([v["text"] for v in
                                    attr["attribute_value"]])

        return res

    def bindings(self, entity_id, typ, service):
        """
        :return: the bindings registered for the service in the first
            metadata instance that contains the entity, None if no
            instance does
        """
        for _md in self.metadata.values():
            # See name() for why items() is not used here
            if entity_id in _md:
                return _md.bindings(entity_id, typ, service)

        return None

    def __str__(self):
        _str = ["{"]
        for key, val in self.metadata.items():
            _str.append("%s: %s" % (key, val))
        _str.append("}")
        return "\n".join(_str)

    def construct_source_id(self):
        """Return the merged construct_source_id() of all instances."""
        res = {}
        for _md in self.metadata.values():
            res.update(_md.construct_source_id())
        return res

    def items(self):
        """Return the (entity id, entity) pairs of all instances; an
        entity id that occurs more than once is only returned once."""
        res = {}
        for _md in self.metadata.values():
            res.update(_md.items())
        return res.items()

    def _providers(self, descriptor):
        """Return the ids of all entities that have the given key."""
        res = []
        for _md in self.metadata.values():
            for ent_id, ent_desc in _md.items():
                if descriptor in ent_desc:
                    res.append(ent_id)
        return res

    def service_providers(self):
        return self._providers("spsso_descriptor")

    def identity_providers(self):
        return self._providers("idpsso_descriptor")

    def attribute_authorities(self):
        # Entities are stored with "<type>_descriptor" keys
        return self._providers("attribute_authority_descriptor")