diff --git a/example/sp/sp.py b/example/sp/sp.py index 5f8147b..ce100a0 100755 --- a/example/sp/sp.py +++ b/example/sp/sp.py @@ -109,7 +109,7 @@ def logout(environ, start_response, user): else: # All was done using SOAP if result: start_response("302 Found", [("Location", target)]) - return ["Successfull Logout"] + return ["Successful Logout"] else: start_response("500 Internal Server Error") return ["Failed to logout from identity services"] diff --git a/src/saml2/assertion.py b/src/saml2/assertion.py index ea8970b..95a18ec 100644 --- a/src/saml2/assertion.py +++ b/src/saml2/assertion.py @@ -17,6 +17,7 @@ import logging import re +from saml2.saml import NAME_FORMAT_URI import xmlenc from saml2 import saml @@ -237,6 +238,18 @@ def filter_attribute_value_assertions(ava, attribute_restrictions=None): del ava[attr] return ava +def restriction_from_attribute_spec(attributes): + restr = {} + for attribute in attributes: + restr[attribute.name] = {} + for val in attribute.attribute_value: + if not val.text: + restr[attribute.name] = None + break + else: + restr[attribute.name] = re.compile(val.text) + return restr + class Policy(object): """ handles restrictions on assertions """ @@ -302,7 +315,7 @@ class Policy(object): :param: The SP entity ID :retur: The format """ - form = "" + form = NAME_FORMAT_URI try: form = self._restrictions[sp_entity_id]["name_form"] @@ -489,9 +502,14 @@ class Assertion(dict): :param sec_context: The security context used when encrypting :return: An Assertion instance """ + + if policy: + _name_format = policy.get_name_form(sp_entity_id) + else: + _name_format = NAME_FORMAT_URI + attr_statement = saml.AttributeStatement(attribute=from_local( - attrconvs, self, - policy.get_name_form(sp_entity_id))) + attrconvs, self, _name_format)) if encrypt == "attributes": for attr in attr_statement.attribute: @@ -505,12 +523,17 @@ class Assertion(dict): # start using now and for some time conds = policy.conditions(sp_entity_id) - + + if authn_auth or authn_class or authn_decl: + _authn_statement = self._authn_statement(authn_class, authn_auth, + authn_decl) + else: + _authn_statement = None + return assertion_factory( issuer=issuer, attribute_statement = attr_statement, - authn_statement = self._authn_statement(authn_class, authn_auth, - authn_decl), + authn_statement = _authn_statement, conditions = conds, subject=factory( saml.Subject, name_id=name_id, diff --git a/src/saml2/client.py b/src/saml2/client.py index cdd4ebc..6eedd73 100644 --- a/src/saml2/client.py +++ b/src/saml2/client.py @@ -18,6 +18,7 @@ """Contains classes and functions that a SAML2.0 Service Provider (SP) may use to conclude its tasks. """ +from saml2.samlp import logout_response_from_string import saml2 try: @@ -78,16 +79,16 @@ class Saml2Client(Base): if binding == saml2.BINDING_HTTP_POST: logger.info("HTTP POST") - response = self.send_using_http_post(_req_str, location, + (header, body) = self.use_http_form_post(_req_str, location, relay_state) elif binding == saml2.BINDING_HTTP_REDIRECT: logger.info("HTTP REDIRECT") - response = self.send_using_http_get(_req_str, location, + (header, body) = self.use_http_get(_req_str, location, relay_state) else: raise Exception("Unknown binding type: %s" % binding) - return req.id, response + return req.id, header, body def global_logout(self, subject_id, reason="", expire=None, sign=None): """ More or less a layer of indirection :-/ @@ -138,7 +139,8 @@ class Saml2Client(Base): for entity_id in entity_ids: response = False - for binding in [BINDING_SOAP, BINDING_HTTP_POST, + for binding in [#BINDING_SOAP, + BINDING_HTTP_POST, BINDING_HTTP_REDIRECT]: srvs = self.metadata.single_logout_service(entity_id, "idpsso", binding=binding) @@ -148,12 +150,13 @@ class Saml2Client(Base): destination = destinations(srvs)[0] logger.info("destination to provider: %s" % destination) - request = self.create_logout_request(subject_id, - destination, entity_id, - reason, expire) + request = self.create_logout_request(destination, entity_id, + subject_id, reason=reason, + expire=expire) to_sign = [] - #if sign and binding != BINDING_HTTP_REDIRECT: + if binding.startswith("http://"): + sign = True if sign is None: sign = self.logout_requests_signed_default @@ -176,10 +179,9 @@ class Saml2Client(Base): if response: not_done.remove(entity_id) logger.info("OK response from %s" % destination) - responses[entity_id] = response + responses[entity_id] = logout_response_from_string(response) else: - logger.info( - "NOT OK response from %s" % destination) + logger.info("NOT OK response from %s" % destination) else: session_id = request.id @@ -195,21 +197,18 @@ class Saml2Client(Base): if binding == BINDING_HTTP_POST: - response = self.send_using_http_post(srequest, - destination, - rstate) + response = self.use_http_form_post(srequest, + destination, + rstate) else: - response = self.send_using_http_get(srequest, - destination, - rstate) + response = self.use_http_get(srequest, destination, + rstate) - if response: - not_done.remove(entity_id) - logger.info("OK response from %s" % destination) - responses[entity_id] = response - else: - logger.info( - "NOT OK response from %s" % destination) + responses[entity_id] = response + not_done.remove(entity_id) + + # only try one binding + break if not_done: # upstream should try later @@ -407,10 +406,9 @@ class Saml2Client(Base): attribute=None, sp_name_qualifier=None, name_qualifier=None, nameid_format=None, real_id=None, consent=None, extensions=None, - sign=False): + sign=False, binding=BINDING_SOAP): """ Does a attribute request to an attribute authority, this is - by default done over SOAP. Other bindings could be used but not - supported right now. + by default done over SOAP. :param entityid: To whom the query should be sent :param subject_id: The identifier of the subject @@ -423,17 +421,36 @@ class Saml2Client(Base): :param nameid_format: The format of the name ID :param real_id: The identifier which is the key to this entity in the identity database + :param binding: Which binding to use :return: The attributes returned """ - location = self._sso_location(entityid, BINDING_SOAP) + srvs = self.metadata.attribute_service(entityid, binding) + if srvs == []: + raise Exception("No attribute service support at entity") - response_args = {"real_id": real_id} + destination = destinations(srvs)[0] - return self.use_soap(location, "attribute_query", consent=consent, - extensions=extensions, sign=sign, - subject_id=subject_id, attribute=attribute, - sp_name_qualifier=sp_name_qualifier, - name_qualifier=name_qualifier, - nameid_format=nameid_format, - response_args=response_args) + if real_id: + response_args = {"real_id": real_id} + else: + response_args = {} + + if binding == BINDING_SOAP: + return self.use_soap(destination, "attribute_query", consent=consent, + extensions=extensions, sign=sign, + subject_id=subject_id, attribute=attribute, + sp_name_qualifier=sp_name_qualifier, + name_qualifier=name_qualifier, + nameid_format=nameid_format, + response_args=response_args) + elif binding == BINDING_HTTP_POST: + return self.use_soap(destination, "attribute_query", consent=consent, + extensions=extensions, sign=sign, + subject_id=subject_id, attribute=attribute, + sp_name_qualifier=sp_name_qualifier, + name_qualifier=name_qualifier, + nameid_format=nameid_format, + response_args=response_args) + else: + raise Exception("Unsupported binding") \ No newline at end of file diff --git a/src/saml2/client_base.py b/src/saml2/client_base.py index 8069279..51549d7 100644 --- a/src/saml2/client_base.py +++ b/src/saml2/client_base.py @@ -648,13 +648,16 @@ class Base(HTTPBase): response = None if xmlstr: - try: - # expected return address - return_addr = self.config.endpoint("single_logout_service", - binding=binding)[0] - except Exception: - logger.info("Not supposed to handle this!") - return None + if binding == BINDING_HTTP_REDIRECT: + try: + # expected return address + return_addr = self.config.endpoint("single_logout_service", + binding=binding)[0] + except Exception: + logger.info("Not supposed to handle this!") + return None + else: + return_addr = None try: response = LogoutResponse(self.sec, return_addr) diff --git a/src/saml2/httpbase.py b/src/saml2/httpbase.py new file mode 100644 index 0000000..2f73187 --- /dev/null +++ b/src/saml2/httpbase.py @@ -0,0 +1,214 @@ +from Cookie import SimpleCookie +import cookielib +import copy +import requests +import time +from saml2 import class_name +from saml2.pack import http_form_post_message +from saml2.pack import make_soap_enveloped_saml_thingy +from saml2.pack import http_redirect_message + +import logging +from saml2.soap import parse_soap_enveloped_saml_response + +logger = logging.getLogger(__name__) + +__author__ = 'rolandh' + +ATTRS = {"version":None, + "name":"", + "value": None, + "port": None, + "port_specified": False, + "domain": "", + "domain_specified": False, + "domain_initial_dot": False, + "path": "", + "path_specified": False, + "secure": False, + "expires": None, + "discard": True, + "comment": None, + "comment_url": None, + "rest": "", + "rfc2109": True} + +PAIRS = { + "port": "port_specified", + "domain": "domain_specified", + "path": "path_specified" +} + +class ConnectionError(Exception): + pass + +def _since_epoch(cdate): + # date format 'Wed, 06-Jun-2012 01:34:34 GMT' + cdate = cdate[5:-4] + try: + t = time.strptime(cdate, "%d-%b-%Y %H:%M:%S") + except ValueError: + t = time.strptime(cdate, "%d-%b-%y %H:%M:%S") + return int(time.mktime(t)) + +class HTTPBase(object): + def __init__(self, verify=True, ca_bundle=None, key_file=None, + cert_file=None): + self.request_args = {"allow_redirects": False,} + self.cookies = {} + self.cookiejar = cookielib.CookieJar() + + self.request_args["verify"] = verify + if ca_bundle: + self.request_args["verify"] = ca_bundle + if key_file: + self.request_args["cert"] = (cert_file, key_file) + + self.sec = None + + def _cookies(self): + cookie_dict = {} + + for _, a in list(self.cookiejar._cookies.items()): + for _, b in list(a.items()): + for cookie in list(b.values()): + # print cookie + cookie_dict[cookie.name] = cookie.value + + return cookie_dict + + def set_cookie(self, kaka, request): + """Returns a cookielib.Cookie based on a set-cookie header line""" + + # default rfc2109=False + # max-age, httponly + for cookie_name, morsel in kaka.items(): + std_attr = ATTRS.copy() + std_attr["name"] = cookie_name + _tmp = morsel.coded_value + if _tmp.startswith('"') and _tmp.endswith('"'): + std_attr["value"] = _tmp[1:-1] + else: + std_attr["value"] = _tmp + + std_attr["version"] = 0 + # copy attributes that have values + for attr in morsel.keys(): + if attr in ATTRS: + if morsel[attr]: + if attr == "expires": + std_attr[attr]=_since_epoch(morsel[attr]) + else: + std_attr[attr]=morsel[attr] + elif attr == "max-age": + if morsel["max-age"]: + std_attr["expires"] = _since_epoch(morsel["max-age"]) + + for att, set in PAIRS.items(): + if std_attr[att]: + std_attr[set] = True + + if std_attr["domain"] and std_attr["domain"].startswith("."): + std_attr["domain_initial_dot"] = True + + if morsel["max-age"] is 0: + try: + self.cookiejar.clear(domain=std_attr["domain"], + path=std_attr["path"], + name=std_attr["name"]) + except ValueError: + pass + else: + new_cookie = cookielib.Cookie(**std_attr) + + self.cookiejar.set_cookie(new_cookie) + + def send(self, url, method="GET", **kwargs): + _kwargs = copy.copy(self.request_args) + if kwargs: + _kwargs.update(kwargs) + + if self.cookiejar: + _kwargs["cookies"] = self._cookies() + #logger.info("SENT COOKIEs: %s" % (_kwargs["cookies"],)) + try: + r = requests.request(method, url, **_kwargs) + except requests.ConnectionError, exc: + raise ConnectionError("%s" % exc) + + try: + #logger.info("RECEIVED COOKIEs: %s" % (r.headers["set-cookie"],)) + self.set_cookie(SimpleCookie(r.headers["set-cookie"]), r) + except AttributeError, err: + pass + + return r + + def use_http_form_post(self, message, destination, relay_state): + """ + Return a form that will automagically execute and POST the message + to the recipient. + + :param message: + :param destination: + :param relay_state: + :return: tuple (header, message) + """ + if not isinstance(message, basestring): + request = "%s" % (message,) + + return http_form_post_message(message, destination, relay_state) + + + def use_http_get(self, message, destination, relay_state): + """ + Send a message using GET, this is the HTTP-Redirect case so + no direct response is expected to this request. + + :param request: + :param destination: + :param relay_state: + :return: tuple (header, None) + """ + if not isinstance(message, basestring): + request = "%s" % (message,) + + return http_redirect_message(message, destination, relay_state) + + def send_using_soap(self, request, destination, headers=None, sign=False): + """ + Send a message using SOAP+POST + + :param request: + :param destination: + :param headers: + :param sign: + :return: + """ + if headers is None: + headers = {"content-type": "application/soap+xml"} + else: + headers.update({"content-type": "application/soap+xml"}) + + soap_message = make_soap_enveloped_saml_thingy(request) + + if sign and self.sec: + _signed = self.sec.sign_statement_using_xmlsec(soap_message, + class_name(request), + nodeid=request.id) + soap_message = _signed + + #_response = self.server.post(soap_message, headers, path=path) + try: + response = self.send(destination, "POST", data=soap_message, + headers=headers) + except Exception, exc: + logger.info("HTTPClient exception: %s" % (exc,)) + return None + + if response: + logger.info("SOAP response: %s" % response) + return parse_soap_enveloped_saml_response(response) + else: + return False + diff --git a/src/saml2/mdstore.py b/src/saml2/mdstore.py new file mode 100644 index 0000000..063fc41 --- /dev/null +++ b/src/saml2/mdstore.py @@ -0,0 +1,467 @@ +import logging +import httplib2 +import sys +import json +from saml2.attribute_converter import ac_factory + +from saml2.mdie import to_dict + +from saml2 import md, samlp +from saml2 import BINDING_HTTP_REDIRECT +from saml2 import BINDING_HTTP_POST +from saml2 import BINDING_SOAP +from saml2.sigver import verify_signature +from saml2.validate import valid_instance +from saml2.time_util import valid +from saml2.validate import NotValid + +__author__ = 'rolandh' + +logger = logging.getLogger(__name__) + +REQ2SRV = { + # IDP + "authn_request": "single_sign_on_service", + "nameid_mapping_request": "name_id_mapping_service", + # AuthnAuthority + "authn_query": "authn_query_service", + # AttributeAuthority + "attribute_query": "attribute_service", + # PDP + "authz_decision_query": "authz_service", + # AuthnAuthority + IDP + PDP + AttributeAuthority + "assertion_id_request": "assertion_id_request_service", + # IDP + SP + "logout_request": "single_logout_service", + "manage_nameid_query": "manage_name_id_service", + "artifact_query": "artifact_resolution_service", + # SP + "assertion_response": "assertion_consumer_service", + "attribute_response": "attribute_consuming_service", + } + +def destinations(srvs): + return [s["location"] for s in srvs] + +def attribute_requirement(entity): + res = {"required": [], "optional": []} + for acs in entity["attribute_consuming_service"]: + for attr in acs["requested_attribute"]: + if "is_required" in attr and attr["is_required"] == "true": + res["required"].append(attr) + else: + res["optional"].append(attr) + return res + +def name(ent, langpref="en"): + try: + org = ent["organization"] + except KeyError: + return None + + for info in ["organization_display_name", + "organization_name", + "organization_url"]: + try: + for item in org[info]: + if item["lang"] == langpref: + return item["text"] + except KeyError: + pass + return None + +class MetaData(object): + def __init__(self, onts, attrc, metadata=""): + self.onts = onts + self.attrc = attrc + self.entity = {} + self.metadata = metadata + + def do_entity_descriptor(self, entity_descr): + try: + if not valid(entity_descr.valid_until): + logger.info("Entity descriptor (entity id:%s) to old" % ( + entity_descr.entity_id,)) + return + except AttributeError: + pass + + # have I seen this entity_id before ? If so if log: ignore it + if entity_descr.entity_id in self.entity: + print >> sys.stderr,\ + "Duplicated Entity descriptor (entity id: '%s')" %\ + entity_descr.entity_id + return + + _ent = to_dict(entity_descr, self.onts) + flag = 0 + # verify support for SAML2 + for descr in ["spsso", "idpsso", "role", "authn_authority", + "attribute_authority", "pdp", "affiliation"]: + _res = [] + try: + _items = _ent["%s_descriptor" % descr] + except KeyError: + continue + + if descr == "affiliation": # Not protocol specific + flag += 1 + continue + + for item in _items: + for prot in item["protocol_support_enumeration"].split(" "): + if prot == samlp.NAMESPACE: + item["protocol_support_enumeration"] = [prot] + _res.append(item) + break + if not _res: + del _ent["%s_descriptor" % descr] + else: + flag += 1 + + if flag: + self.entity[entity_descr.entity_id] = _ent + + def parse(self, xmlstr): + self.entities_descr = md.entities_descriptor_from_string(xmlstr) + + if not self.entities_descr: + self.entity_descr = md.entity_descriptor_from_string(xmlstr) + if self.entity_descr: + self.do_entity_descriptor(self.entity_descr) + else: + try: + valid_instance(self.entities_descr) + except NotValid, exc: + logger.error(exc.args[0]) + return + + try: + valid(self.entities_descr.valid_until) + except AttributeError: + pass + + for entity_descr in self.entities_descr.entity_descriptor: + self.do_entity_descriptor(entity_descr) + + def load(self): + self.parse(self.metadata) + + def _service(self, entity_id, typ, service, binding=None): + """ Get me all services with a specified + entity ID and type, that supports the specified version of binding. + + :param entity_id: The EntityId + :param typ: Type of service (idp, attribute_authority, ...) + :param service: which service that is sought for + :param binding: A binding identifier + :return: list of service descriptions. + Or if no binding was specified a list of 2-tuples (binding, srv) + """ + + try: + srvs = [] + for t in self.entity[entity_id][typ]: + try: + srvs.extend(t[service]) + except KeyError: + pass + except KeyError: + return None + + if not srvs: + return srvs + + if binding: + res = [] + for srv in srvs: + if srv["binding"] == binding: + res.append(srv) + else: + res = {} + for srv in srvs: + try: + res[srv["binding"]].append(srv) + except KeyError: + res[srv["binding"]] = [srv] + return res + + def attribute_requirement(self, entity_id, index=0): + """ Returns what attributes the SP requires and which are optional + if any such demands are registered in the Metadata. + + :param entity_id: The entity id of the SP + :param index: which of the attribute consumer services its all about + :return: 2-tuple, list of required and list of optional attributes + """ + + res = {"required": [], "optional": []} + + try: + for sp in self.entity[entity_id]["spsso_descriptor"]: + _res = attribute_requirement(sp) + res["required"].extend(_res["required"]) + res["optional"].extend(_res["optional"]) + except KeyError: + return None + + return res + + def dumps(self): + return json.dumps(self.entity, indent=2) + + def with_descriptor(self, descriptor): + res = {} + desc = "%s_descriptor" % descriptor + for id, ent in self.entity.items(): + if desc in ent: + res[id] = ent + return res + +class MetaDataFile(MetaData): + def __init__(self, onts, attrc, filename): + MetaData.__init__(self, onts, attrc) + self.filename = filename + + def load(self): + self.parse(open(self.filename).read()) + +class MetaDataExtern(MetaData): + def __init__(self, onts, attrc, url, xmlsec_binary, cert, http): + MetaData.__init__(self, onts, attrc) + self.url = url + self.cert = cert + self.xmlsec_binary = xmlsec_binary + self.http = http + + def load(self): + """ Imports metadata by the use of HTTP GET. + If the fingerprint is known the file will be checked for + compliance before it is imported. + """ + (response, content) = self.http.request(self.url) + if response.status == 200: + if verify_signature(content, self.xmlsec_binary, self.cert, + node_name="%s:%s" % (md.EntitiesDescriptor.c_namespace, + md.EntitiesDescriptor.c_tag)): + self.parse(content) + return True + else: + logger.info("Response status: %s" % response.status) + return False + +class MetaDataMD(MetaData): + def __init__(self, onts, attrc, filename): + MetaData.__init__(self, onts, attrc) + self.filename = filename + + def load(self): + self.entity = eval(open(self.filename).read()) + +class MetadataStore(object): + def __init__(self, onts, attrc, xmlsec_binary=None, ca_certs=None, + disable_ssl_certificate_validation=False): + self.onts = onts + self.attrc = attrc + self.http = httplib2.Http(ca_certs=ca_certs, + disable_ssl_certificate_validation=disable_ssl_certificate_validation) + self.xmlsec_binary = xmlsec_binary + self.ii = 0 + self.metadata = {} + + def load(self, type, *args, **kwargs): + if type == "local": + key = args[0] + md = MetaDataFile(self.onts, self.attrc, args[0]) + elif type == "inline": + self.ii += 1 + key = self.ii + md = MetaData(self.onts, self.attrc) + elif type == "remote": + key = kwargs["url"] + md = MetaDataExtern(self.onts, self.attrc, + kwargs["url"], self.xmlsec_binary, + kwargs["cert"], self.http) + elif type == "mdfile": + key = args[0] + md = MetaDataMD(self.onts, self.attrc, args[0]) + else: + raise Exception("Unknown metadata type '%s'" % type) + + md.load() + self.metadata[key] = md + + def imp(self, spec): + for key, vals in spec.items(): + for val in vals: + if isinstance(val, dict): + self.load(key, **val) + else: + self.load(key, val) + + def _service(self, entity_id, typ, service, binding=None): + for key, md in self.metadata.items(): + srvs = md._service(entity_id, typ, service, binding) + if srvs: + return srvs + return [] + + def single_sign_on_service(self, entity_id, binding=None): + # IDP + + if binding is None: + binding = BINDING_HTTP_REDIRECT + return self._service(entity_id, "idpsso_descriptor", + "single_sign_on_service", binding) + + def name_id_mapping_service(self, entity_id, binding=None): + # IDP + if binding is None: + binding = BINDING_HTTP_REDIRECT + return self._service(entity_id, "idpsso_descriptor", + "name_id_mapping_service", binding) + + def authn_query_service(self, entity_id, binding=None): + # AuthnAuthority + if binding is None: + binding = BINDING_SOAP + return self._service(entity_id, "authn_authority_descriptor", + "authn_query_service", binding) + + def attribute_service(self, entity_id, binding=None): + # AttributeAuthority + if binding is None: + binding = BINDING_HTTP_REDIRECT + return self._service(entity_id, "attribute_authority_descriptor", + "attribute_service", binding) + + def authz_service(self, entity_id, binding=None): + # PDP + if binding is None: + binding = BINDING_SOAP + return self._service(entity_id, "pdp_descriptor", + "authz_service", binding) + + def assertion_id_request_service(self, entity_id, typ, binding=None): + # AuthnAuthority + IDP + PDP + AttributeAuthority + if binding is None: + binding = BINDING_SOAP + return self._service(entity_id, "%s_descriptor" % typ, + "assertion_id_request_service", binding) + + def single_logout_service(self, entity_id, typ, binding=None): + # IDP + SP + if binding is None: + binding = BINDING_HTTP_REDIRECT + return self._service(entity_id, "%s_descriptor" % typ, + "single_logout_service", binding) + + def manage_name_id_service(self, entity_id, typ, binding=None): + # IDP + SP + if binding is None: + binding = BINDING_HTTP_REDIRECT + return self._service(entity_id, "%s_descriptor" % typ, + "manage_name_id_service", binding) + + def artifact_resolution_service(self, entity_id, typ, binding=None): + # IDP + SP + if binding is None: + binding = BINDING_HTTP_REDIRECT + return self._service(entity_id, "%s_descriptor" % typ, + "artifact_resolution_service", binding) + + def assertion_consumer_service(self, entity_id, binding=None): + # SP + if binding is None: + binding = BINDING_HTTP_POST + return self._service(entity_id, "spsso_descriptor", + "assertion_consumer_service", binding) + + def attribute_consuming_service(self, entity_id, binding=None): + # SP + if binding is None: + binding = BINDING_HTTP_REDIRECT + return self._service(entity_id, "spsso_descriptor", + "attribute_consuming_service", binding) + + def attribute_requirement(self, entity_id, index=0): + for md in self.metadata.values(): + if entity_id in md.entity: + return md.attribute_requirement(entity_id, index) + + def keys(self): + res = [] + for md in self.metadata.values(): + res.extend(md.entity.keys()) + return res + + def __getitem__(self, item): + for md in self.metadata.values(): + try: + return md.entity[item] + except KeyError: + pass + + raise KeyError(item) + + def entities(self): + num = 0 + for md in self.metadata.values(): + num += len(md.entity) + + return num + + def __len__(self): + return len(self.metadata) + + def with_descriptor(self, descriptor): + res = {} + for md in self.metadata.values(): + res.update(md.with_descriptor(descriptor)) + return res + + def name(self, entity_id, langpref="en"): + for md in self.metadata.values(): + if entity_id in md.entity: + return name(md.entity[entity_id], langpref) + + def certs(self, entity_id, descriptor, use="signing"): + ent = self.__getitem__(entity_id) + if descriptor == "any": + res = [] + for descr in ["spsso", "idpsso", "role", "authn_authority", + "attribute_authority", "pdp"]: + try: + srvs = ent["%s_descriptor" % descr] + except KeyError: + continue + + for srv in srvs: + for key in srv["key_descriptor"]: + if "use" in key and key["use"] == use: + for dat in key["key_info"]["x509_data"]: + cert = dat["x509_certificate"]["text"] + if cert not in res: + res.append(cert) + elif not "use" in key: + for dat in key["key_info"]["x509_data"]: + cert = dat["x509_certificate"]["text"] + if cert not in res: + res.append(cert) + else: + srvs = ent["%s_descriptor" % descriptor] + + res = [] + for srv in srvs: + for key in srv["key_descriptor"]: + if "use" in key and key["use"] == use: + for dat in key["key_info"]["x509_data"]: + res.append(dat["x509_certificate"]["text"]) + elif not "use" in key: + for dat in key["key_info"]["x509_data"]: + res.append(dat["x509_certificate"]["text"]) + return res + + def vo_members(self, entity_id): + ad = self.__getitem__(entity_id)["affiliation_descriptor"] + return [m["text"] for m in ad["affiliate_member"]] diff --git a/src/saml2/pack.py b/src/saml2/pack.py index 01e4cf4..4e729ac 100644 --- a/src/saml2/pack.py +++ b/src/saml2/pack.py @@ -119,6 +119,9 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"): return headers, body +DUMMY_NAMESPACE = "http://example.org/" +PREFIX = '' + def make_soap_enveloped_saml_thingy(thingy, header_parts=None): """ Returns a soap envelope containing a SAML request as a text string. @@ -140,9 +143,22 @@ def make_soap_enveloped_saml_thingy(thingy, header_parts=None): body.tag = '{%s}Body' % NAMESPACE envelope.append(body) - thingy.become_child_element_of(body) - - return ElementTree.tostring(envelope, encoding="UTF-8") + if isinstance(thingy, basestring): + thingy = thingy.replace(PREFIX, "") + _child = ElementTree.Element('') + _child.tag = '{%s}FuddleMuddle' % DUMMY_NAMESPACE + body.append(_child) + _str = ElementTree.tostring(envelope, encoding="UTF-8") + # find an remove the namespace definition + i = _str.find(DUMMY_NAMESPACE) + j = _str.rfind("xmlns:", 0, i) + cut1 = _str[j:i+len(DUMMY_NAMESPACE)+1] + cut2 = "<%s:FuddleMuddle />" % (cut1[6:9],) + _str = _str.replace(cut1, "") + return _str.replace(cut2,thingy) + else: + thingy.become_child_element_of(body) + return ElementTree.tostring(envelope, encoding="UTF-8") def http_soap_message(message): return ([("Content-type", "application/soap+xml")], @@ -188,28 +204,6 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None): return body, header -# ----------------------------------------------------------------------------- -# def send_using_http_get(request, destination, key_file=None, cert_file=None, -# log=None): -# -# -# http = HTTPClient(destination, key_file, cert_file, log) -# if log: log.info("HTTP client initiated") -# -# try: -# response = http.get() -# except Exception, exc: -# if log: log.info("HTTPClient exception: %s" % (exc,)) -# return None -# -# if log: log.info("HTTP request sent and got response: %s" % response) -# -# return response - - - - - # ----------------------------------------------------------------------------- PACKING = { diff --git a/src/saml2/server.py b/src/saml2/server.py index 502c0f5..ee675c7 100644 --- a/src/saml2/server.py +++ b/src/saml2/server.py @@ -57,7 +57,7 @@ from saml2.sigver import response_factory, logoutresponse_factory from saml2.config import config_factory -from saml2.assertion import Assertion, Policy +from saml2.assertion import Assertion, Policy, restriction_from_attribute_spec, filter_attribute_value_assertions logger = logging.getLogger(__name__) @@ -194,8 +194,10 @@ class Identifier(object): nameid_format = name_id_policy.format elif sp_nid: nameid_format = sp_nid[0] - else: + elif local_policy: nameid_format = local_policy.get_nameid_format(sp_entity_id) + else: + raise Exception("Unknown NameID format") if nameid_format == saml.NAMEID_FORMAT_PERSISTENT: return self.persistent_nameid(sp_entity_id, userid) @@ -470,7 +472,7 @@ class Server(HTTPBase): # ------------------------------------------------------------------------ - def create_response(self, in_response_to, consumer_url, + def _authn_response(self, in_response_to, consumer_url, sp_entity_id, identity=None, name_id=None, status=None, authn=None, authn_decl=None, issuer=None, policy=None, @@ -569,7 +571,7 @@ class Server(HTTPBase): def create_aa_response(self, in_response_to, consumer_url, sp_entity_id, identity=None, userid="", name_id=None, status=None, issuer=None, sign_assertion=False, - sign_response=False): + sign_response=False, attributes=None): """ Create an attribute assertion response. :param in_response_to: The session identifier of the request @@ -585,22 +587,53 @@ class Server(HTTPBase): :param sign_response: Whether the whole response should be signed :return: A response instance """ -# name_id = self.ident.construct_nameid(self.conf.policy, userid, -# sp_entity_id, identity) + if not name_id and userid: + try: + name_id = self.ident.construct_nameid(self.conf.policy, userid, + sp_entity_id, identity) + logger.warning("Unspecified NameID format") + except Exception: + pass - return self.create_response(in_response_to, consumer_url, sp_entity_id, - identity, name_id, status, - issuer=issuer, - policy=self.conf.getattr("policy", "aa"), - sign_assertion=sign_assertion, - sign_response=sign_response) + to_sign = [] + args = {} + if identity: + _issuer = self.issuer(issuer) + ast = Assertion(identity) + policy = self.conf.getattr("policy", "aa") + if policy: + ast.apply_policy(sp_entity_id, policy) + else: + policy = Policy() + + if attributes: + restr = restriction_from_attribute_spec(attributes) + ast = filter_attribute_value_assertions(ast) + + assertion = ast.construct(sp_entity_id, in_response_to, + consumer_url, name_id, + self.conf.attribute_converters, + policy, issuer=_issuer) + + if sign_assertion: + assertion.signature = pre_signature_part(assertion.id, + self.sec.my_cert, 1) + # Just the assertion or the response and the assertion ? + to_sign = [(class_name(assertion), assertion.id)] + + + args["assertion"] = assertion + + return self._response(in_response_to, consumer_url, status, issuer, + sign_response, to_sign, **args) # ------------------------------------------------------------------------ def create_authn_response(self, identity, in_response_to, destination, - sp_entity_id, name_id_policy, userid, - authn=None, authn_decl=None, issuer=None, - sign_response=False, sign_assertion=False): + sp_entity_id, name_id_policy=None, userid=None, + name_id=None, authn=None, authn_decl=None, + issuer=None, sign_response=False, + sign_assertion=False): """ Constructs an AuthenticationResponse :param identity: Information about an user @@ -618,24 +651,28 @@ class Server(HTTPBase): :return: A response instance """ - name_id = None - try: - nid_formats = [] - for _sp in self.metadata[sp_entity_id]["spsso_descriptor"]: - if "name_id_format" in _sp: - nid_formats.extend([n.text for n in _sp["name_id_format"]]) + policy = self.conf.getattr("policy", "idp") - policy = self.conf.getattr("policy", "idp") - name_id = self.ident.construct_nameid(policy, userid, sp_entity_id, - identity, name_id_policy, - nid_formats) - except IOError, exc: - response = self.create_error_response(in_response_to, destination, - sp_entity_id, exc, name_id) - return ("%s" % response).split("\n") + if not name_id: + try: + nid_formats = [] + for _sp in self.metadata[sp_entity_id]["spsso_descriptor"]: + if "name_id_format" in _sp: + nid_formats.extend([n.text for n in _sp["name_id_format"]]) + + name_id = self.ident.construct_nameid(policy, userid, + sp_entity_id, identity, + name_id_policy, + nid_formats) + except IOError, exc: + response = self.create_error_response(in_response_to, + destination, + sp_entity_id, + exc, name_id) + return ("%s" % response).split("\n") try: - return self.create_response(in_response_to, # in_response_to + return self._authn_response(in_response_to, # in_response_to destination, # consumer_url sp_entity_id, # sp_entity_id identity, # identity as dictionary @@ -699,86 +736,51 @@ class Server(HTTPBase): return req - def create_logout_response(self, request, bindings, status=None, + def create_logout_response(self, request, binding, status=None, sign=False, issuer=None): """ Create a LogoutResponse. What is returned depends on which binding is used. :param request: The request this is a response to - :param bindings: Which bindings that can be used to send the response + :param binding: Which binding the request came in over :param status: The return status of the response operation :param issuer: The issuer of the message - :return: A 3-tuple consisting of HTTP return code, HTTP headers and - possibly a message. + :return: A logout message. """ - sp_entity_id = request.issuer.text.strip() - - binding = None - dests = [] - for binding in bindings: - srvs = self.metadata.single_logout_service(sp_entity_id, "spsso", - binding=binding) - if srvs: - dests = destinations(srvs) - break + mid = sid() - if not dests: - logger.error("No way to return a response !!!") - return ("412 Precondition Failed", - [("Content-type", "text/html")], - ["No return way defined"]) - - # Pick the first - destination = dests[0] - - logger.info("Logout Destination: %s, binding: %s" % (dests, binding)) - if not status: + if not status: status = success_status_factory() - mid = sid() - rcode = "200 OK" - # response and packaging differs depending on binding - - if binding == BINDING_SOAP: - response = logoutresponse_factory( - sign=sign, - id = mid, - in_response_to = request.id, - status = status, - ) - if sign: - to_sign = [(class_name(response), mid)] - response = signed_instance_factory(response, self.sec, to_sign) - - (headers, message) = http_soap_message(response) - else: - _issuer = self.issuer(issuer) - response = logoutresponse_factory( - sign=sign, - id = mid, - in_response_to = request.id, - status = status, - issuer = _issuer, - destination = destination, - sp_entity_id = sp_entity_id, - instant=instant(), - ) - if sign: - to_sign = [(class_name(response), mid)] - response = signed_instance_factory(response, self.sec, to_sign) + response = "" + if binding in [BINDING_SOAP, BINDING_HTTP_POST]: + response = logoutresponse_factory(sign=sign, id = mid, + in_response_to = request.id, + status = status) + elif binding == BINDING_HTTP_REDIRECT: + sp_entity_id = request.issuer.text.strip() + srvs = self.metadata.single_logout_service(sp_entity_id, "spsso") + if not srvs: + raise Exception("Nowhere to send the response") - logger.info("Response: %s" % (response,)) - if binding == BINDING_HTTP_REDIRECT: - (headers, message) = http_redirect_message(response, - destination, - typ="SAMLResponse") - rcode = "302 Found" - else: - (headers, message) = http_post_message(response, destination, - typ="SAMLResponse") - - return rcode, headers, message + destination = destinations(srvs)[0] + + _issuer = self.issuer(issuer) + response = logoutresponse_factory(sign=sign, id = mid, + in_response_to = request.id, + status = status, + issuer = _issuer, + destination = destination, + sp_entity_id = sp_entity_id, + instant=instant()) + if sign: + to_sign = [(class_name(response), mid)] + response = signed_instance_factory(response, self.sec, to_sign) + + logger.info("Response: %s" % (response,)) + + return response def parse_authz_decision_query(self, xml_string): """ Parse an attribute query diff --git a/tests/idp_all_conf.py b/tests/idp_all_conf.py new file mode 100644 index 0000000..3ab0872 --- /dev/null +++ b/tests/idp_all_conf.py @@ -0,0 +1,75 @@ +from saml2 import BINDING_SOAP, BINDING_HTTP_REDIRECT, BINDING_HTTP_POST +from saml2.saml import NAMEID_FORMAT_PERSISTENT +from saml2.saml import NAME_FORMAT_URI + +try: + from saml2.sigver import get_xmlsec_binary +except ImportError: + get_xmlsec_binary = None + +if get_xmlsec_binary: + xmlsec_path = get_xmlsec_binary(["/opt/local/bin"]) +else: + xmlsec_path = '/usr/bin/xmlsec1' + +BASE = "http://localhost:8088" + +CONFIG = { + "entityid" : "urn:mace:example.com:saml:roland:idp", + "name" : "Rolands IdP", + "service": { + "aa": { + "endpoints" : { + "attribute_service": [ + ("%s/aap" % BASE, BINDING_HTTP_POST), + ("%s/aas" % BASE, BINDING_SOAP) + ] + }, + }, + "idp": { + "endpoints" : { + "single_sign_on_service" : [ + ("%s/sso" % BASE, BINDING_HTTP_REDIRECT)], + "single_logout_service": [ + ("%s/slo" % BASE, BINDING_SOAP), + ("%s/slop" % BASE, BINDING_HTTP_POST)], + }, + "policy": { + "default": { + "lifetime": {"minutes":15}, + "attribute_restrictions": None, # means all I have + "name_form": NAME_FORMAT_URI, + }, + "urn:mace:example.com:saml:roland:sp": { + "lifetime": {"minutes": 5}, + "nameid_format": NAMEID_FORMAT_PERSISTENT, + # "attribute_restrictions":{ + # "givenName": None, + # "surName": None, + # } + } + }, + "subject_data": "subject_data.db", + }, + }, + "debug" : 1, + "key_file" : "test.key", + "cert_file" : "test.pem", + "xmlsec_binary" : xmlsec_path, + "metadata": { + "local": ["servera.xml", "vo_metadata.xml"], + }, + "attribute_map_dir" : "attributemaps", + "organization": { + "name": "Exempel AB", + "display_name": [("Exempel AB","se"),("Example Co.","en")], + "url":"http://www.example.com/roland", + }, + "contact_person": [{ + "given_name":"John", + "sur_name": "Smith", + "email_address": ["john.smith@example.com"], + "contact_type": "technical", + }, + ], + } diff --git a/tests/test_41_response.py b/tests/test_41_response.py index ce0c058..6cea539 100644 --- a/tests/test_41_response.py +++ b/tests/test_41_response.py @@ -27,34 +27,30 @@ class TestResponse: def setup_class(self): server = Server("idp_conf") name_id = server.ident.transient_nameid( - "urn:mace:example.com:saml:roland:sp", - "id12") + "urn:mace:example.com:saml:roland:sp","id12") - self._resp_ = server.create_response( - "id12", # in_response_to - "http://lingon.catalogix.se:8087/", # consumer_url - "urn:mace:example.com:saml:roland:sp", # sp_entity_id - IDENTITY, - name_id = name_id - ) + self._resp_ = server.create_authn_response(IDENTITY, + "id12", # in_response_to + "http://lingon.catalogix.se:8087/", # consumer_url + "urn:mace:example.com:saml:roland:sp", # sp_entity_id + name_id=name_id) - self._sign_resp_ = server.create_response( - "id12", # in_response_to - "http://lingon.catalogix.se:8087/", # consumer_url - "urn:mace:example.com:saml:roland:sp", # sp_entity_id - IDENTITY, - name_id = name_id, - sign_assertion=True - ) + self._sign_resp_ = server.create_authn_response( + IDENTITY, + "id12", # in_response_to + "http://lingon.catalogix.se:8087/", # consumer_url + "urn:mace:example.com:saml:roland:sp", # sp_entity_id + name_id = name_id, + sign_assertion=True) - self._resp_authn = server.create_response( - "id12", # in_response_to - "http://lingon.catalogix.se:8087/", # consumer_url - "urn:mace:example.com:saml:roland:sp", # sp_entity_id - IDENTITY, - name_id = name_id, - authn=(saml.AUTHN_PASSWORD, "http://www.example.com/login") - ) + self._resp_authn = server.create_authn_response( + IDENTITY, + "id12", # in_response_to + "http://lingon.catalogix.se:8087/", # consumer_url + "urn:mace:example.com:saml:roland:sp", # sp_entity_id + name_id = name_id, + authn=(saml.AUTHN_PASSWORD, + "http://www.example.com/login")) conf = config.SPConfig() conf.load_file("server_conf") @@ -80,33 +76,6 @@ class TestResponse: assert isinstance(resp, StatusResponse) assert isinstance(resp, AuthnResponse) - # def test_3(self): - # xml_response = ("%s" % (self._logout_resp,)).split("\n")[1] - # sec = security_context(self.conf) - # resp = response_factory(xml_response, self.conf, - # return_addr="http://lingon.catalogix.se:8087/", - # outstanding_queries={"id12": "http://localhost:8088/sso"}, - # timeslack=10000, decode=False) - # - # assert isinstance(resp, StatusResponse) - # assert isinstance(resp, LogoutResponse) - -# def test_decrypt(self): -# attr_stat = saml.attribute_statement_from_string( -# open("encrypted_attribute_statement.xml").read()) -# -# assert len(attr_stat.attribute) == 0 -# assert len(attr_stat.encrypted_attribute) == 4 -# -# xmlsec = get_xmlsec_binary() -# sec = SecurityContext(xmlsec, key_file="private_key.pem") -# -# resp = AuthnResponse(sec, None, "entity_id") -# resp.decrypt_attributes(attr_stat) -# -# assert len(attr_stat.attribute) == 4 -# assert len(attr_stat.encrypted_attribute) == 4 - def test_only_use_keys_in_metadata(self): conf = config.SPConfig() diff --git a/tests/test_44_authnresp.py b/tests/test_44_authnresp.py index 3fbd8ac..9d002f3 100644 --- a/tests/test_44_authnresp.py +++ b/tests/test_44_authnresp.py @@ -1,5 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +from saml2.saml import AUTHN_PASSWORD from saml2 import saml from saml2.server import Server @@ -22,28 +23,31 @@ class TestAuthnResponse: server = Server("idp_conf") name_id = server.ident.transient_nameid( "urn:mace:example.com:saml:roland:sp","id12") - policy = server.conf.getattr("policy", "idp") - self._resp_ = server.create_response( - "id12", # in_response_to - "http://lingon.catalogix.se:8087/", # consumer_url - "urn:mace:example.com:saml:roland:sp", # sp_entity_id - IDENTITY, name_id = name_id, policy=policy) - - self._sign_resp_ = server.create_response( - "id12", # in_response_to - "http://lingon.catalogix.se:8087/", # consumer_url - "urn:mace:example.com:saml:roland:sp", # sp_entity_id - IDENTITY, - name_id = name_id, sign_assertion=True, policy=policy) + authn = (AUTHN_PASSWORD, "http://www.example.com/login") - self._resp_authn = server.create_response( - "id12", # in_response_to - "http://lingon.catalogix.se:8087/", # consumer_url - "urn:mace:example.com:saml:roland:sp", # sp_entity_id - IDENTITY, - name_id = name_id, - authn=(saml.AUTHN_PASSWORD, "http://www.example.com/login"), - policy=policy) + self._resp_ = server.create_authn_response( + IDENTITY, + "id12", # in_response_to + "http://lingon.catalogix.se:8087/", # consumer_url + "urn:mace:example.com:saml:roland:sp", # sp_entity_id + name_id = name_id, + authn=authn) + + self._sign_resp_ = server.create_authn_response( + IDENTITY, + "id12", # in_response_to + "http://lingon.catalogix.se:8087/", # consumer_url + "urn:mace:example.com:saml:roland:sp", # sp_entity_id + name_id = name_id, sign_assertion=True, + authn=authn) + + self._resp_authn = server.create_authn_response( + IDENTITY, + "id12", # in_response_to + "http://lingon.catalogix.se:8087/", # consumer_url + "urn:mace:example.com:saml:roland:sp", # sp_entity_id + name_id = name_id, + authn=authn) self.conf = config_factory("sp", "server_conf") self.conf.only_use_keys_in_metadata = False diff --git a/tests/test_50_server.py b/tests/test_50_server.py index 75a76e0..b7836fb 100644 --- a/tests/test_50_server.py +++ b/tests/test_50_server.py @@ -1,5 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +from saml2.saml import AUTHN_PASSWORD from saml2.samlp import response_from_string from saml2.server import Server, Identifier @@ -196,17 +197,17 @@ class TestServer1(): name_id = self.server.ident.transient_nameid( "urn:mace:example.com:saml:roland:sp", "id12") - resp = self.server.create_response( - "id12", # in_response_to - "http://localhost:8087/", # consumer_url - "urn:mace:example.com:saml:roland:sp", # sp_entity_id + resp = self.server.create_authn_response( {"eduPersonEntitlement": "Short stop", "surName": "Jeter", "givenName": "Derek", "mail": "derek.jeter@nyy.mlb.com", "title": "The man"}, - name_id, - policy= self.server.conf.getattr("policy") + "id12", # in_response_to + "http://localhost:8087/", # destination + "urn:mace:example.com:saml:roland:sp", # sp_entity_id + name_id=name_id, + authn=(AUTHN_PASSWORD, "http://www.example.com/login") ) print resp.keyswv() @@ -246,10 +247,13 @@ class TestServer1(): assert confirmation.subject_confirmation_data.in_response_to == "id12" def test_sso_response_without_identity(self): - resp = self.server.create_response( + resp = self.server.create_authn_response( + {}, "id12", # in_response_to "http://localhost:8087/", # consumer_url "urn:mace:example.com:saml:roland:sp", # sp_entity_id + userid="USER1", + authn=(AUTHN_PASSWORD, "http://www.example.com/login") ) print resp.keyswv() @@ -299,7 +303,9 @@ class TestServer1(): ava, "id1", "http://local:8087/", "urn:mace:example.com:saml:roland:sp", npolicy, - "foba0001@example.com") + "foba0001@example.com", + authn=(AUTHN_PASSWORD, + "http://www.example.com/login")) response = samlp.response_from_string(resp_str) print response.keyswv() @@ -308,9 +314,11 @@ class TestServer1(): 'issuer', 'id']) print response.assertion[0].keyswv() assert len(response.assertion) == 1 - assert _eq(response.assertion[0].keyswv(), ['authn_statement', - 'attribute_statement', 'subject', 'issue_instant', - 'version', 'issuer', 'conditions', 'id']) + assert _eq(response.assertion[0].keyswv(), ['attribute_statement', + 'issue_instant', 'version', + 'subject', 'conditions', + 'id', 'issuer', + 'authn_statement']) assertion = response.assertion[0] assert len(assertion.attribute_statement) == 1 astate = assertion.attribute_statement[0] @@ -324,14 +332,14 @@ class TestServer1(): ava = { "givenName": ["Derek"], "surName": ["Jeter"], "mail": ["derek@nyy.mlb.com"], "title": "The man"} - signed_resp = self.server.create_response( - "id12", # in_response_to - "http://lingon.catalogix.se:8087/", # consumer_url - "urn:mace:example.com:saml:roland:sp", # sp_entity_id - ava, - name_id = name_id, - sign_assertion=True - ) + signed_resp = self.server.create_authn_response( + ava, + "id12", # in_response_to + "http://lingon.catalogix.se:8087/", # consumer_url + "urn:mace:example.com:saml:roland:sp", # sp_entity_id + name_id = name_id, + sign_assertion=True + ) print signed_resp assert signed_resp @@ -465,25 +473,10 @@ class TestServerLogout(): server = Server("idp_slo_redirect_conf") request = _logout_request("sp_slo_redirect_conf") print request - bindings = [BINDING_HTTP_REDIRECT] - (resp, headers, message) = server.create_logout_response(request, - bindings) - assert resp == '302 Found' + binding = BINDING_HTTP_REDIRECT + response = server.create_logout_response(request, binding) + headers, message = server.use_http_get(response, response.destination, + "/relay_state") assert len(headers) == 1 assert headers[0][0] == "Location" assert message == [''] - -# class TestSign(): -# def test_1(self): -# IDP = server.Server("restrictive_idp.config", debug=1) -# ava = { "givenName": ["Derek"], "surName": ["Jeter"], -# "mail": ["derek@nyy.mlb.com"]} -# -# authn_resp = IDP.authn_response(ava, -# "id1", "http://local:8087/", -# "urn:mace:example.com:saml:roland:sp", -# samlp.NameIDPolicy(format=saml.NAMEID_FORMAT_TRANSIENT, -# allow_create="true"), -# "foba0001@example.com", sign=True) -# print authn_resp -# assert False diff --git a/tests/test_51_client.py b/tests/test_51_client.py index 4714ec2..0dfc918 100644 --- a/tests/test_51_client.py +++ b/tests/test_51_client.py @@ -2,23 +2,20 @@ # -*- coding: utf-8 -*- import base64 -import urllib -from urlparse import urlparse, parse_qs - -from saml2.client import Saml2Client, LogoutError -from saml2 import samlp, BINDING_HTTP_POST, BINDING_HTTP_REDIRECT -from saml2 import BINDING_SOAP -from saml2 import saml, config, class_name -from saml2.discovery import discovery_service_request_url -from saml2.discovery import discovery_service_response -from saml2.saml import NAMEID_FORMAT_PERSISTENT -from saml2.server import Server from saml2.s_utils import decode_base64_and_inflate +from saml2.samlp import response_from_string, logout_request_from_string + +from saml2.client import Saml2Client +from saml2 import samlp, BINDING_HTTP_POST +from saml2 import saml, config, class_name +from saml2.config import SPConfig +from saml2.saml import NAMEID_FORMAT_PERSISTENT, NAMEID_FORMAT_TRANSIENT, \ + AUTHN_PASSWORD +from saml2.server import Server from saml2.time_util import in_a_while -from saml2.assertion import Assertion -from saml2.assertion import Policy from py.test import raises +from fakeIDP import FakeIDP def for_me(condition, me ): for restriction in condition.audience_restriction: @@ -56,6 +53,8 @@ REQ1 = { "1.2.14": """ "1.2.16":""" urn:mace:example.com:saml:roland:spE8042FB4-4D5B-48C3-8E14-8EDD852790DD"""} +AUTHN = (AUTHN_PASSWORD, "http://www.example.com/login") + class TestClient: def setup_class(self): self.server = Server("idp_conf") @@ -256,7 +255,8 @@ class TestClient: destination="http://lingon.catalogix.se:8087/", sp_entity_id="urn:mace:example.com:saml:roland:sp", name_id_policy=nameid_policy, - userid="foba0001@example.com") + userid="foba0001@example.com", + authn=AUTHN) resp_str = "%s" % resp @@ -298,7 +298,8 @@ class TestClient: destination="http://lingon.catalogix.se:8087/", sp_entity_id="urn:mace:example.com:saml:roland:sp", name_id_policy=nameid_policy, - userid="also0001@example.com") + userid="also0001@example.com", + authn=AUTHN) resp_str = base64.encodestring(resp_str) @@ -328,83 +329,65 @@ class TestClient: assert my_name == "urn:mace:example.com:saml:roland:sp" # Below can only be done with dummy Server -# def test_attribute_query(self): -# resp = self.client.do_attribute_query( -# "urn:mace:example.com:saml:roland:idp", -# "_e7b68a04488f715cda642fbdd90099f5", -# nameid_format=saml.NAMEID_FORMAT_TRANSIENT) -# -# # since no one is answering on the other end -# assert resp is None -# def test_authenticate(self): -# print self.client.metadata.with_descriptor("idpsso") -# id, response = self.client.do_authenticate( -# "urn:mace:example.com:saml:roland:idp", -# "http://www.example.com/relay_state") -# assert response[0] == "Location" -# o = urlparse(response[1]) -# qdict = parse_qs(o.query) -# assert _leq(qdict.keys(), ['SAMLRequest', 'RelayState']) -# saml_request = decode_base64_and_inflate(qdict["SAMLRequest"][0]) -# print saml_request -# authnreq = samlp.authn_request_from_string(saml_request) -# -# def test_authenticate_no_args(self): -# id, response = self.client.do_authenticate(relay_state="http://www.example.com/relay_state") -# assert response[0] == "Location" -# o = urlparse(response[1]) -# qdict = parse_qs(o.query) -# assert _leq(qdict.keys(), ['SAMLRequest', 'RelayState']) -# saml_request = decode_base64_and_inflate(qdict["SAMLRequest"][0]) -# assert qdict["RelayState"][0] == "http://www.example.com/relay_state" -# print saml_request -# authnreq = samlp.authn_request_from_string(saml_request) -# print authnreq.keyswv() -# assert authnreq.destination == "http://localhost:8088/sso" -# assert authnreq.assertion_consumer_service_url == "http://lingon.catalogix.se:8087/" -# assert authnreq.provider_name == "urn:mace:example.com:saml:roland:sp" -# assert authnreq.protocol_binding == BINDING_HTTP_REDIRECT -# name_id_policy = authnreq.name_id_policy -# assert name_id_policy.allow_create == "false" -# assert name_id_policy.format == NAMEID_FORMAT_PERSISTENT -# issuer = authnreq.issuer -# assert issuer.text == "urn:mace:example.com:saml:roland:sp" -# -# -# def test_logout_1(self): -# """ one IdP/AA with BINDING_HTTP_REDIRECT on single_logout_service""" -# -# # information about the user from an IdP -# session_info = { -# "name_id": "123456", -# "issuer": "urn:mace:example.com:saml:roland:idp", -# "not_on_or_after": in_a_while(minutes=15), -# "ava": { -# "givenName": "Anders", -# "surName": "Andersson", -# "mail": "anders.andersson@example.com" -# } -# } -# self.client.users.add_information_about_person(session_info) -# entity_ids = self.client.users.issuers_of_info("123456") -# assert entity_ids == ["urn:mace:example.com:saml:roland:idp"] -# resp = self.client.global_logout("123456", "Tired", -# in_a_while(minutes=5)) -# print resp -# assert resp -# assert resp[0] # a session_id -# assert resp[1] == '200 OK' -# assert resp[2] == [('Content-type', 'text/html')] -# assert resp[3][0] == '' -# assert resp[3][1] == 'SAML 2.0 POST' -# session_info = self.client.state[resp[0]] -# print session_info -# assert session_info["entity_id"] == entity_ids[0] -# assert session_info["subject_id"] == "123456" -# assert session_info["reason"] == "Tired" -# assert session_info["operation"] == "SLO" -# assert session_info["entity_ids"] == entity_ids -# assert session_info["sign"] == True + +IDP = "urn:mace:example.com:saml:roland:idp" +class TestClientWithDummy(): + def setup_class(self): + self.server = FakeIDP("idp_all_conf") + + conf = SPConfig() + conf.load_file("servera_conf") + self.client = Saml2Client(conf) + + self.client.send = self.server.receive + + def test_do_authn(self): + response = self.client.do_authenticate(IDP, + "http://www.example.com/relay_state") + + def test_do_attribute_query(self): + response = self.client.do_attribute_query(IDP, + "_e7b68a04488f715cda642fbdd90099f5", + attribute={"eduPersonAffiliation":None}, + nameid_format=NAMEID_FORMAT_TRANSIENT) + + + def test_logout_1(self): + """ one IdP/AA logout from""" + + # information about the user from an IdP + session_info = { + "name_id": "123456", + "issuer": "urn:mace:example.com:saml:roland:idp", + "not_on_or_after": in_a_while(minutes=15), + "ava": { + "givenName": "Anders", + "surName": "Andersson", + "mail": "anders.andersson@example.com" + } + } + self.client.users.add_information_about_person(session_info) + entity_ids = self.client.users.issuers_of_info("123456") + assert entity_ids == ["urn:mace:example.com:saml:roland:idp"] + resp = self.client.global_logout("123456", "Tired", in_a_while(minutes=5)) + print resp + assert resp + assert len(resp) == 1 + assert resp.keys() == entity_ids + item = resp[entity_ids[0]] + assert isinstance(item, tuple) + assert item[0] == [('Content-type', 'text/html')] + lead = "name=\"SAMLRequest\" value=\"" + body = item[1][3] + i = body.find(lead) + i += len(lead) + j = i + body[i:].find('"') + info = body[i:j] + xml_str = base64.b64decode(info) + #xml_str = decode_base64_and_inflate(info) + req = logout_request_from_string(xml_str) + print req + assert req.reason == "Tired" # # def test_logout_2(self): # """ one IdP/AA with BINDING_SOAP, can't actually send something""" diff --git a/tests/test_60_sp.py b/tests/test_60_sp.py index d0d1101..604882f 100644 --- a/tests/test_60_sp.py +++ b/tests/test_60_sp.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- import base64 -from saml2.saml import NAMEID_FORMAT_TRANSIENT +from saml2.saml import NAMEID_FORMAT_TRANSIENT, AUTHN_PASSWORD from saml2.samlp import NameIDPolicy from s2repoze.plugins.sp import make_plugin from saml2.server import Server @@ -34,6 +34,9 @@ ENV1 = {'SERVER_SOFTWARE': 'CherryPy/3.1.2 WSGI Server', trans_name_policy = NameIDPolicy(format=NAMEID_FORMAT_TRANSIENT, allow_create="true") + +AUTHN = (AUTHN_PASSWORD, "http://www.example.com/login") + class TestSP(): def setup_class(self): self.sp = make_plugin("rem", saml_conf="server_conf") @@ -52,7 +55,8 @@ class TestSP(): "http://lingon.catalogix.se:8087/", "urn:mace:example.com:saml:roland:sp", trans_name_policy, - "foba0001@example.com") + "foba0001@example.com", + authn=AUTHN) resp_str = base64.encodestring(resp_str) self.sp.outstanding_queries = {"id1":"http://www.example.com/service"} diff --git a/tools/make_metadata.py b/tools/make_metadata.py index ef7885c..9a3a64a 100755 --- a/tools/make_metadata.py +++ b/tools/make_metadata.py @@ -1,37 +1,22 @@ #!/usr/bin/env python +import argparse import os -import getopt import sys +from saml2.time_util import in_a_while +from saml2.extension import mdui, idpdisc, shibmd +from saml2.saml import NAME_FORMAT_URI +from saml2.attribute_converter import from_local_name +from saml2 import md, BINDING_HTTP_POST, BINDING_HTTP_REDIRECT, BINDING_SOAP,\ + samlp, class_name +import xmldsig as ds -from saml2.metadata import entity_descriptor, entities_descriptor -from saml2.metadata import sign_entity_descriptor -from saml2.sigver import SecurityContext +from saml2.sigver import SecurityContext, pre_signature_part from saml2.sigver import get_xmlsec_binary from saml2.validate import valid_instance from saml2.config import Config -HELP_MESSAGE = """ -Usage: make_metadata [options] 1*configurationfile - -Valid options: -c:hi:k:np:sv:x: - -c : certificate - -e : Wrap the whole thing in an EntitiesDescriptor - -h : Print this help message - -i id : The ID of the entities descriptor - -k keyfile : A file with a key to sign the metadata with - -n : name - -p : path to the configuration file - -s : sign the metadata - -v : How long, in days, the metadata is valid from the - time of creation - -x : xmlsec1 binaries to be used for the signing - -w : Use wellknown namespace prefixes -""" - -class Usage(Exception): - def __init__(self, msg): - self.msg = msg +from saml2.s_utils import factory +from saml2.s_utils import sid NSPAIR = { "saml2p":"urn:oasis:names:tc:SAML:2.0:protocol", @@ -44,64 +29,531 @@ NSPAIR = { "md":"urn:oasis:names:tc:SAML:2.0:metadata", } -def main(args): - try: - opts, args = getopt.getopt(args, "c:ehi:k:np:sv:wx", - ["help", "name", "id", "keyfile", "sign", - "valid", "xmlsec", "entityid", "path"]) - except getopt.GetoptError, err: - # print help information and exit: - raise Usage(err) # will print something like "option -a not recognized" +DEFAULTS = { + "want_assertions_signed": "true", + "authn_requests_signed": "false", + "want_authn_requests_signed": "true", + } - output = None - verbose = False - valid_for = 0 - name = "" - id = "" - sign = False - xmlsec = "" - keyfile = "" - pubkeyfile = "" - entitiesid = True - path = [] - nspair = None - +ORG_ATTR_TRANSL = { + "organization_name": ("name", md.OrganizationName), + "organization_display_name": ("display_name", md.OrganizationDisplayName), + "organization_url": ("url", md.OrganizationURL) +} + +def _localized_name(val, klass): + """If no language is defined 'en' is the default""" try: - for o, a in opts: - if o in ("-v", "--valid"): - valid_for = int(a) * 24 - elif o in ("-h", "--help"): - raise Usage(HELP_MESSAGE) - elif o in ("-n", "--name"): - name = a - elif o in ("-i", "--id"): - id = a - elif o in ("-s", "--sign"): - sign = True - elif o in ("-x", "--xmlsec"): - xmlsec = a - elif o in ("-k", "--keyfile"): - keyfile = a - elif o in ("-c", "--certfile"): - pubkeyfile = a - elif o in ("-e", "--entityid"): - entitiesid = False - elif o in ("-p", "--path"): - path = [x.strip() for x in a.split(":")] - elif o in ("-w",): - nspair = NSPAIR + (text, lang) = val + return klass(text=text, lang=lang) + except ValueError: + return klass(text=val, lang="en") + +def do_organization_info(ava): + """ decription of an organization in the configuration is + a dictionary of keys and values, where the values might be tuples: + + "organization": { + "name": ("AB Exempel", "se"), + "display_name": ("AB Exempel", "se"), + "url": "http://www.example.org" + } + + """ + + if ava is None: + return None + + org = md.Organization() + for dkey, (ckey, klass) in ORG_ATTR_TRANSL.items(): + if ckey not in ava: + continue + if isinstance(ava[ckey], basestring): + setattr(org, dkey, [_localized_name(ava[ckey], klass)]) + elif isinstance(ava[ckey], list): + setattr(org, dkey, + [_localized_name(n, klass) for n in ava[ckey]]) + else: + setattr(org, dkey, [_localized_name(ava[ckey], klass)]) + return org + +def do_contact_person_info(lava): + """ Creates a ContactPerson instance from configuration information""" + + cps = [] + if lava is None: + return cps + + contact_person = md.ContactPerson + for ava in lava: + cper = md.ContactPerson() + for (key, classpec) in contact_person.c_children.values(): + try: + value = ava[key] + data = [] + if isinstance(classpec, list): + # What if value is not a list ? + if isinstance(value, basestring): + data = [classpec[0](text=value)] + else: + for val in value: + data.append(classpec[0](text=val)) + else: + data = classpec(text=value) + setattr(cper, key, data) + except KeyError: + pass + for (prop, classpec, _) in contact_person.c_attributes.values(): + try: + # should do a check for valid value + setattr(cper, prop, ava[prop]) + except KeyError: + pass + + # ContactType must have a value + typ = getattr(cper, "contact_type") + if not typ: + setattr(cper, "contact_type", "technical") + + cps.append(cper) + + return cps + + +def do_key_descriptor(cert, use="signing"): + return md.KeyDescriptor( + key_info = ds.KeyInfo( + x509_data=ds.X509Data( + x509_certificate=ds.X509Certificate(text=cert) + ) + ), + use=use + ) + +def do_requested_attribute(attributes, acs, is_required="false"): + lista = [] + for attr in attributes: + attr = from_local_name(acs, attr, NAME_FORMAT_URI) + args = {} + for key in attr.keyswv(): + args[key] = getattr(attr, key) + args["is_required"] = is_required + args["name_format"] = NAME_FORMAT_URI + lista.append(md.RequestedAttribute(**args)) + return lista + +def do_uiinfo(_uiinfo): + uii = mdui.UIInfo() + for attr in ['display_name', 'description', "information_url", + 'privacy_statement_url']: + try: + val = _uiinfo[attr] + except KeyError: + continue + + aclass = uii.child_class(attr) + inst = getattr(uii, attr) + if isinstance(val, basestring): + ainst = aclass(text=val) + inst.append(ainst) + elif isinstance(val, dict): + ainst = aclass() + ainst.text = val["text"] + ainst.lang = val["lang"] + inst.append(ainst) + else : + for value in val: + if isinstance(value, basestring): + ainst = aclass(text=value) + inst.append(ainst) + elif isinstance(value, dict): + ainst = aclass() + ainst.text = value["text"] + ainst.lang = value["lang"] + inst.append(ainst) + + try: + _attr = "logo" + val = _uiinfo[_attr] + inst = getattr(uii, _attr) + # dictionary or list of dictionaries + if isinstance(val, dict): + logo = mdui.Logo() + for attr, value in val.items(): + if attr in logo.keys(): + setattr(logo, attr, value) + inst.append(logo) + elif isinstance(val, list): + for logga in val: + if not isinstance(logga, dict): + raise Exception("Configuration error !!") + logo = mdui.Logo() + for attr, value in logga.items(): + if attr in logo.keys(): + setattr(logo, attr, value) + inst.append(logo) + except KeyError: + pass + + try: + _attr = "keywords" + val = _uiinfo[_attr] + inst = getattr(uii, _attr) + # list of basestrings, dictionary or list of dictionaries + if isinstance(val, list): + for value in val: + keyw = mdui.Keywords() + if isinstance(value, basestring): + keyw.text = " ".join(value) + elif isinstance(value, dict): + keyw.text = " ".join(value["text"]) + try: + keyw.lang = value["lang"] + except KeyError: + pass + else: + raise Exception("Configuration error: ui_info logo") + inst.append(keyw) + elif isinstance(val, dict): + keyw = mdui.Keywords() + keyw.text = " ".join(val["text"]) + try: + keyw.lang = val["lang"] + except KeyError: + pass + inst.append(keyw) + else: + raise Exception("Configuration Error: ui_info logo") + except KeyError: + pass + + return uii + +def do_idpdisc(discovery_response): + return idpdisc.DiscoveryResponse(index="0", location=discovery_response, + binding=idpdisc.NAMESPACE) + +ENDPOINTS = { + "sp": { + "artifact_resolution_service": (md.ArtifactResolutionService, True), + "single_logout_service": (md.SingleLogoutService, False), + "manage_name_id_service": (md.ManageNameIDService, False), + "assertion_consumer_service": (md.AssertionConsumerService, True), + }, + "idp":{ + "artifact_resolution_service": (md.ArtifactResolutionService, True), + "single_logout_service": (md.SingleLogoutService, False), + "manage_name_id_service": (md.ManageNameIDService, False), + "single_sign_on_service": (md.SingleSignOnService, False), + "name_id_mapping_service": (md.NameIDMappingService, False), + "assertion_id_request_service": (md.AssertionIDRequestService, False), + }, + "aa":{ + "artifact_resolution_service": (md.ArtifactResolutionService, True), + "single_logout_service": (md.SingleLogoutService, False), + "manage_name_id_service": (md.ManageNameIDService, False), + + "assertion_id_request_service": (md.AssertionIDRequestService, False), + + "attribute_service": (md.AttributeService, False) + }, + "pdp": { + "authz_service": (md.AuthzService, True) + } +} + +DEFAULT_BINDING = { + "assertion_consumer_service": BINDING_HTTP_POST, + "single_sign_on_service": BINDING_HTTP_REDIRECT, + "single_logout_service": BINDING_HTTP_POST, + "attribute_service": BINDING_SOAP, + "artifact_resolution_service": BINDING_SOAP +} + +def do_endpoints(conf, endpoints): + service = {} + + for endpoint, (eclass, indexed) in endpoints.items(): + try: + servs = [] + i = 1 + for args in conf[endpoint]: + if isinstance(args, basestring): # Assume it's the location + args = {"location":args, + "binding": DEFAULT_BINDING[endpoint]} + elif isinstance(args, tuple): # (location, binding) + args = {"location":args[0], "binding": args[1]} + if indexed and "index" not in args: + args["index"] = "%d" % i + servs.append(factory(eclass, **args)) + i += 1 + service[endpoint] = servs + except KeyError: + pass + return service + +DEFAULT = { + "want_assertions_signed": "true", + "authn_requests_signed": "false", + "want_authn_requests_signed": "false", + } + +def do_spsso_descriptor(conf, cert=None): + spsso = md.SPSSODescriptor() + spsso.protocol_support_enumeration = samlp.NAMESPACE + + endps = conf.getattr("endpoints", "sp") + if endps: + for (endpoint, instlist) in do_endpoints(endps, + ENDPOINTS["sp"]).items(): + setattr(spsso, endpoint, instlist) + + if cert: + spsso.key_descriptor = do_key_descriptor(cert) + + for key in ["want_assertions_signed", "authn_requests_signed"]: + try: + val = conf.getattr(key, "sp") + if val is None: + setattr(spsso, key, DEFAULT[key]) #default ?! else: - assert False, "unhandled option %s" % o - except Usage, err: - print >> sys.stderr, sys.argv[0].split("/")[-1] + ": " + str(err.msg) - print >> sys.stderr, "\t for help use --help" - return 2 + strval = "{0:>s}".format(val) + setattr(spsso, key, strval.lower()) + except KeyError: + setattr(spsso, key, DEFAULTS[key]) - if not xmlsec: - xmlsec = get_xmlsec_binary(path) + requested_attributes = [] + acs = conf.attribute_converters + req = conf.getattr("required_attributes", "sp") + if req: + requested_attributes.extend(do_requested_attribute(req, acs, + is_required="true")) + + opt=conf.getattr("optional_attributes", "sp") + if opt: + requested_attributes.extend(do_requested_attribute(opt, acs)) + + if requested_attributes: + spsso.attribute_consuming_service = [md.AttributeConsumingService( + requested_attribute=requested_attributes, + service_name= [md.ServiceName(lang="en",text=conf.name)], + index="1", + )] + try: + if conf.description: + try: + (text, lang) = conf.description + except ValueError: + text = conf.description + lang = "en" + spsso.attribute_consuming_service[0].service_description = [ + md.ServiceDescription(text=text, + lang=lang)] + except KeyError: + pass + + dresp = conf.getattr("discovery_response", "sp") + if dresp: + if spsso.extensions is None: + spsso.extensions = md.Extensions() + spsso.extensions.add_extension_element(do_idpdisc(dresp)) + + return spsso + +def do_idpsso_descriptor(conf, cert=None): + idpsso = md.IDPSSODescriptor() + idpsso.protocol_support_enumeration = samlp.NAMESPACE + + endps = conf.getattr("endpoints", "idp") + if endps: + for (endpoint, instlist) in do_endpoints(endps, + ENDPOINTS["idp"]).items(): + setattr(idpsso, endpoint, instlist) + + scopes = conf.getattr("scope", "idp") + if scopes: + if idpsso.extensions is None: + idpsso.extensions = md.Extensions() + for scope in scopes: + mdscope = shibmd.Scope() + mdscope.text = scope + # unless scope contains '*'/'+'/'?' assume non regexp ? + mdscope.regexp = "false" + idpsso.extensions.add_extension_element(mdscope) + + ui_info = conf.getattr("ui_info", "idp") + if ui_info: + if idpsso.extensions is None: + idpsso.extensions = md.Extensions() + idpsso.extensions.add_extension_element(do_uiinfo(ui_info)) + + if cert: + idpsso.key_descriptor = do_key_descriptor(cert) + + for key in ["want_authn_requests_signed"]: + try: + val = conf.getattr(key, "idp") + if val is None: + setattr(idpsso, key, DEFAULT["want_authn_requests_signed"]) + else: + setattr(idpsso, key, "%s" % val) + except KeyError: + setattr(idpsso, key, DEFAULTS[key]) + + return idpsso + +def do_aa_descriptor(conf, cert): + aad = md.AttributeAuthorityDescriptor() + aad.protocol_support_enumeration = samlp.NAMESPACE + + endps = conf.getattr("endpoints", "aa") + + if endps: + for (endpoint, instlist) in do_endpoints(endps, + ENDPOINTS["aa"]).items(): + setattr(aad, endpoint, instlist) + + if cert: + aad.key_descriptor = do_key_descriptor(cert) + + return aad + +def do_pdp_descriptor(conf, cert): + """ Create a Policy Decision Point descriptor """ + pdp = md.PDPDescriptor() + + pdp.protocol_support_enumeration = samlp.NAMESPACE + + endps = conf.getattr("endpoints", "pdp") + + if endps: + for (endpoint, instlist) in do_endpoints(endps, + ENDPOINTS["pdp"]).items(): + setattr(pdp, endpoint, instlist) + + namef = conf.getattr("name_form", "pdp") + if namef: + if isinstance(namef, basestring): + ids = [md.NameIDFormat(namef)] + else: + ids = [md.NameIDFormat(text=form) for form in namef] + setattr(pdp, "name_id_format", ids) + + if cert: + pdp.key_descriptor = do_key_descriptor(cert) + + return pdp + +def entity_descriptor(confd): + mycert = "".join(open(confd.cert_file).readlines()[1:-1]) + + entd = md.EntityDescriptor() + entd.entity_id = confd.entityid + + if confd.valid_for: + entd.valid_until = in_a_while(hours=int(confd.valid_for)) + + if confd.organization is not None: + entd.organization = do_organization_info(confd.organization) + if confd.contact_person is not None: + entd.contact_person = do_contact_person_info(confd.contact_person) + + serves = confd.serves + if not serves: + raise Exception( + 'No service type ("sp","idp","aa") provided in the configuration') + + if "sp" in serves: + confd.context = "sp" + entd.spsso_descriptor = do_spsso_descriptor(confd, mycert) + if "idp" in serves: + confd.context = "idp" + entd.idpsso_descriptor = do_idpsso_descriptor(confd, mycert) + if "aa" in serves: + confd.context = "aa" + entd.attribute_authority_descriptor = do_aa_descriptor(confd, mycert) + if "pdp" in serves: + confd.context = "pdp" + entd.pdp_descriptor = do_pdp_descriptor(confd, mycert) + + return entd + +def entities_descriptor(eds, valid_for, name, ident, sign, secc): + entities = md.EntitiesDescriptor(entity_descriptor= eds) + if valid_for: + entities.valid_until = in_a_while(hours=valid_for) + if name: + entities.name = name + if ident: + entities.id = ident + + if sign: + if not ident: + ident = sid() + + if not secc.key_file: + raise Exception("If you want to do signing you should define " + + "a key to sign with") + + if not secc.my_cert: + raise Exception("If you want to do signing you should define " + + "where your public key are") + + entities.signature = pre_signature_part(ident, secc.my_cert, 1) + entities.id = ident + xmldoc = secc.sign_statement_using_xmlsec("%s" % entities, + class_name(entities)) + entities = md.entities_descriptor_from_string(xmldoc) + return entities + +def sign_entity_descriptor(edesc, ident, secc): + if not ident: + ident = sid() + + edesc.signature = pre_signature_part(ident, secc.my_cert, 1) + edesc.id = ident + xmldoc = secc.sign_statement_using_xmlsec("%s" % edesc, class_name(edesc)) + return md.entity_descriptor_from_string(xmldoc) + +if __name__ == "__main__": + import sys + + parser = argparse.ArgumentParser() + parser.add_argument('-v', dest='valid', action='store_true', + help="How long, in days, the metadata is valid from the time of creation") + parser.add_argument('-c', dest='cert', help='certificate') + parser.add_argument('-e', dest='ed', action='store_true', + help="Wrap the whole thing in an EntitiesDescriptor") + parser.add_argument('-i', dest='id', + help="The ID of the entities descriptor") + parser.add_argument('-k', dest='keyfile', + help="A file with a key to sign the metadata with") + parser.add_argument('-n', dest='name', default="") + parser.add_argument('-p', dest='path', + help="path to the configuration file") + parser.add_argument('-s', dest='sign', action='store_true', + help="sign the metadata") + parser.add_argument('-x', dest='xmlsec', + help="xmlsec binaries to be used for the signing") + parser.add_argument('-w', dest='wellknown', + help="Use wellknown namespace prefixes") + parser.add_argument(dest="config", nargs="+") + args = parser.parse_args() + + valid_for = 0 + nspair = None + paths = [".", "/opt/local/bin"] + + if args.valid: + # translate into hours + valid_for = int(args.valid) * 24 + if args.xmlsec: + xmlsec = args.xmlsec + else: + xmlsec = get_xmlsec_binary(paths) eds = [] - for filespec in args: + for filespec in args.config: bas, fil = os.path.split(filespec) if bas != "": sys.path.insert(0, bas) @@ -110,21 +562,18 @@ def main(args): cnf = Config().load_file(fil, metadata_construction=True) eds.append(entity_descriptor(cnf)) - secc = SecurityContext(xmlsec, keyfile, cert_file=pubkeyfile) - if entitiesid: - desc = entities_descriptor(eds, valid_for, name, id, sign, secc) + secc = SecurityContext(xmlsec, args.keyfile, cert_file=args.cert) + if args.id: + desc = entities_descriptor(eds, valid_for, args.name, args.id, + args.sign, secc) valid_instance(desc) print desc.to_string(nspair) else: for eid in eds: - if sign: + if args.sign: desc = sign_entity_descriptor(eid, id, secc) else: desc = eid valid_instance(desc) print desc.to_string(nspair) -if __name__ == "__main__": - import sys - - main(sys.argv[1:])