diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py index d379404..96c4c2a 100644 --- a/src/saml2/sigver.py +++ b/src/saml2/sigver.py @@ -108,6 +108,12 @@ def get_xmlsec_binary(paths=None): raise Exception("Can't find %s" % bin_name) +def get_xmlsec_cryptobackend(path=None, search_paths=None, debug=False): + if path is None: + path=get_xmlsec_binary(paths=search_paths) + return CryptoBackendXmlSec1(path, debug=debug) + + try: XMLSEC_BINARY = get_xmlsec_binary() except Exception: @@ -246,7 +252,7 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None): signed_xml = "%s" % instance for (node_name, nodeid) in elements_to_sign: signed_xml = seccont.sign_statement_using_xmlsec( - signed_xml, klass_namn=node_name, nodeid=nodeid) + signed_xml, class_name=node_name, node_id=nodeid) #print "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" #print "%s" % signed_xml @@ -476,8 +482,6 @@ def parse_xmlsec_output(output): raise XmlsecError(output) raise XmlsecError(output) -__DEBUG = 0 - class BadSignature(Exception): """The signature is invalid.""" @@ -559,7 +563,7 @@ def verify_signature(enctext, xmlsec_binary, cert_file=None, cert_type="pem", """ Verifies the signature of a XML document. :param enctext: The signed XML document - :param xmlsec_binary: The xmlsec1 binaries to be used + :param xmlsec_binary: The xmlsec1 binaries to be used (or CryptoBackend()) :param cert_file: The public key used to decrypt the signature :param cert_type: The cert format :param node_name: The SAML class of the root node in the signed document @@ -572,32 +576,15 @@ def verify_signature(enctext, xmlsec_binary, cert_file=None, cert_type="pem", if not id_attr: id_attr = ID_ATTR - _, fil = make_temp(enctext, decode=False) + crypto = xmlsec_binary + if not isinstance(crypto, CryptoBackend): + # backwards compatibility + crypto = CryptoBackendXmlSec1(xmlsec_binary, debug=debug) - com_list = [xmlsec_binary, "--verify", - "--pubkey-cert-%s" % cert_type, cert_file, - "--id-attr:%s" % id_attr, node_name] - - if debug: - com_list.append("--store-signatures") - - if node_id: - com_list.extend(["--node-id", node_id]) - - if __DEBUG: - try: - print " ".join(com_list) - except TypeError: - print "cert_type", cert_type - print "cert_file", cert_file - print "node_name", node_name - print "fil", fil - raise - print "%s: %s" % (cert_file, os.access(cert_file, os.F_OK)) - print "%s: %s" % (fil, os.access(fil, os.F_OK)) - - (_stdout, stderr, _output) = _run_xmlsec(com_list, [fil], exception=SignatureError) - return parse_xmlsec_output(stderr) + return crypto.validate_signature(enctext, cert_file=cert_file, + cert_type=cert_type, node_name=node_name, + node_id=node_id, id_attr=id_attr, + ) # --------------------------------------------------------------------------- @@ -638,34 +625,146 @@ def read_cert_from_file(cert_file, cert_type): return base64.b64encode(str(data)) -def _run_xmlsec(com_list, extra_args, validate_output=True, exception=XmlsecError): - """ - Common code to invoke xmlsec and parse the output. - :param com_list: Key-value parameter list for xmlsec - :param extra_args: Positional parameters to be appended after all key-value parameters - :param validate_output: Parse and validate the output - :param exception: The exception class to raise on errors - :result: Whatever xmlsec wrote to an --output temporary file - """ - ntf = NamedTemporaryFile() - com_list.append(["--output", ntf.name]) - com_list += extra_args +class CryptoBackend(): - logger.debug("xmlsec command: %s" % " ".join(com_list)) + def __init__(self, debug=False): + self.debug = debug - pof = Popen(com_list, stderr=PIPE, stdout=PIPE) + def encrypt(self, text, recv_key, template, key_type): + raise NotImplementedError() - p_out = pof.stdout.read() - p_err = pof.stderr.read() - try: - if validate_output: - parse_xmlsec_output(p_err) - except XmlsecError, exc: - logger.error(LOG_LINE_2 % (p_out, p_err, exc)) - raise exception("%s" % (exc,)) + def decrypt(self, enctext, key_file): + raise NotImplementedError() - ntf.seek(0) - return (p_out, p_err, ntf.read()) + def sign_statement(self, statement, class_name, key, key_file, nodeid, + id_attr): + raise NotImplementedError() + + def validate_signature(self, enctext, cert_file, cert_type, node_name, + node_id, id_attr): + raise NotImplementedError() + + +class CryptoBackendXmlSec1(CryptoBackend): + + __DEBUG = 0 + + def __init__(self, xmlsec_binary, **kwargs): + CryptoBackend.__init__(self, **kwargs) + assert(isinstance(xmlsec_binary, basestring)) + self.xmlsec = xmlsec_binary + + def encrypt(self, text, recv_key, template, key_type): + logger.info("Encryption input len: %d" % len(text)) + _, fil = make_temp("%s" % text, decode=False) + + com_list = [self.xmlsec, "--encrypt", "--pubkey-pem", recv_key, + "--session-key", key_type, "--xml-data", fil, + ] + + (_stdout, _stderr, output) = self._run_xmlsec(com_list, [template], + exception=DecryptError) + return output + + def decrypt(self, enctext, key_file): + logger.info("Decrypt input len: %d" % len(enctext)) + _, fil = make_temp("%s" % enctext, decode=False) + + com_list = [self.xmlsec, "--decrypt", "--privkey-pem", + key_file, "--id-attr:%s" % ID_ATTR, ENC_KEY_CLASS, + ] + + (_stdout, _stderr, output) = self._run_xmlsec(com_list, [fil], + exception=DecryptError) + return output + + def sign_statement(self, statement, class_name, key, key_file, node_id, + id_attr): + + _, fil = make_temp("%s" % statement, decode=False) + + com_list = [self.xmlsec, "--sign", + "--privkey-pem", key_file, + "--id-attr:%s" % id_attr, class_name, + #"--store-signatures" + ] + if node_id: + com_list.extend(["--node-id", node_id]) + + try: + (stdout, stderr, signed_statement) = \ + self._run_xmlsec(com_list, [fil], validate_output=False) + # this doesn't work if --store-signatures are used + if stdout == "": + if signed_statement: + return signed_statement + logger.error("Signing operation failed :\nstdout : %s\nstderr : %s" \ + % (stdout, stderr)) + raise Exception("Signing failed") + except DecryptError, exc: + raise Exception("Signing failed") + + + def validate_signature(self, enctext, cert_file, cert_type, node_name, + node_id, id_attr): + _, fil = make_temp(enctext, decode=False) + + com_list = [self.xmlsec, "--verify", + "--pubkey-cert-%s" % cert_type, cert_file, + "--id-attr:%s" % id_attr, node_name] + + if self.debug: + com_list.append("--store-signatures") + + if node_id: + com_list.extend(["--node-id", node_id]) + + if self.__DEBUG: + try: + print " ".join(com_list) + except TypeError: + print "cert_type", cert_type + print "cert_file", cert_file + print "node_name", node_name + print "fil", fil + raise + print "%s: %s" % (cert_file, os.access(cert_file, os.F_OK)) + print "%s: %s" % (fil, os.access(fil, os.F_OK)) + + (_stdout, stderr, _output) = self._run_xmlsec(com_list, [fil], + exception=SignatureError) + return parse_xmlsec_output(stderr) + + def _run_xmlsec(self, com_list, extra_args, validate_output=True, + exception=XmlsecError): + """ + Common code to invoke xmlsec and parse the output. + :param com_list: Key-value parameter list for xmlsec + :param extra_args: Positional parameters to be appended after all + key-value parameters + :param validate_output: Parse and validate the output + :param exception: The exception class to raise on errors + :result: Whatever xmlsec wrote to an --output temporary file + """ + ntf = NamedTemporaryFile() + com_list.extend(["--output", ntf.name]) + com_list += extra_args + + logger.debug("xmlsec command: %s" % " ".join(com_list)) + + pof = Popen(com_list, stderr=PIPE, stdout=PIPE) + + p_out = pof.stdout.read() + p_err = pof.stderr.read() + try: + if validate_output: + parse_xmlsec_output(p_err) + except XmlsecError, exc: + logger.error(LOG_LINE_2 % (p_out, p_err, exc)) + raise exception("%s" % (exc,)) + + ntf.seek(0) + return (p_out, p_err, ntf.read()) def security_context(conf, debug=None): @@ -686,18 +785,21 @@ def security_context(conf, debug=None): if _only_md is None: _only_md = False - return SecurityContext(conf.xmlsec_binary, conf.key_file, + crypto = get_xmlsec_cryptobackend(conf.xmlsec_binary, debug=debug) + + return SecurityContext(crypto, conf.key_file, cert_file=conf.cert_file, metadata=metadata, debug=debug, only_use_keys_in_metadata=_only_md) class SecurityContext(object): - def __init__(self, xmlsec_binary, key_file="", key_type="pem", + def __init__(self, crypto, key_file="", key_type="pem", cert_file="", cert_type="pem", metadata=None, debug=False, template="", encrypt_key_type="des-192", only_use_keys_in_metadata=False): - self.xmlsec = xmlsec_binary + self.crypto = crypto + assert (isinstance(self.crypto, CryptoBackend)) # Your private key self.key_file = key_file @@ -741,15 +843,7 @@ class SecurityContext(object): if not template: template = self.template - logger.info("Encryption input len: %d" % len(text)) - _, fil = make_temp("%s" % text, decode=False) - - com_list = [self.xmlsec, "--encrypt", "--pubkey-pem", recv_key, - "--session-key", key_type, "--xml-data", fil, - ] - - (_stdout, _stderr, output) = _run_xmlsec(com_list, [template], exception=DecryptError) - return output + return self.crypto.encrypt(text, recv_key, template, key_type) def decrypt(self, enctext): """ Decrypting an encrypted text by the use of a private key. @@ -757,16 +851,7 @@ class SecurityContext(object): :param enctext: The encrypted text as a string :return: The decrypted text """ - - logger.info("Decrypt input len: %d" % len(enctext)) - _, fil = make_temp("%s" % enctext, decode=False) - - com_list = [self.xmlsec, "--decrypt", "--privkey-pem", - self.key_file, "--id-attr:%s" % ID_ATTR, ENC_KEY_CLASS, - ] - - (_stdout, _stderr, output) = _run_xmlsec(com_list, [fil], exception=DecryptError) - return output + return self.crypto.decrypt(enctext, self.key_file) def verify_signature(self, enctext, cert_file=None, cert_type="pem", node_name=NODE_NAME, node_id=None, id_attr=""): @@ -786,7 +871,7 @@ class SecurityContext(object): cert_file = self.cert_file cert_type = self.cert_type - return verify_signature(enctext, self.xmlsec, cert_file, cert_type, + return verify_signature(enctext, self.crypto, cert_file, cert_type, node_name, self.debug, node_id, id_attr) def _check_signature(self, decoded_xml, item, node_name=NODE_NAME, @@ -1013,9 +1098,13 @@ class SecurityContext(object): #-------------------------------------------------------------------------- # SIGNATURE PART #-------------------------------------------------------------------------- - def sign_statement_using_xmlsec(self, statement, klass_namn, key=None, - key_file=None, nodeid=None, id_attr=""): - """Sign a SAML statement using xmlsec. + def sign_statement_using_xmlsec(self, statement, **kwargs): + """ Deprecated function. See sign_statement(). """ + return self.sign_statement(statement, **kwargs) + + def sign_statement(self, statement, class_name, key=None, + key_file=None, node_id=None, id_attr=""): + """Sign a SAML statement. :param statement: The statement to be signed :param key: The key to be used for the signing, either this or @@ -1024,7 +1113,6 @@ class SecurityContext(object): 'id','Id' or 'ID' :return: The signed statement """ - if not id_attr: id_attr = ID_ATTR @@ -1034,55 +1122,39 @@ class SecurityContext(object): if not key and not key_file: key_file = self.key_file - _, fil = make_temp("%s" % statement, decode=False) + return self.crypto.sign_statement(statement, class_name, key, key_file, + node_id, id_attr) - com_list = [self.xmlsec, "--sign", - "--privkey-pem", key_file, - "--id-attr:%s" % id_attr, klass_namn - #"--store-signatures" - ] - if nodeid: - com_list.extend(["--node-id", nodeid]) + def sign_assertion_using_xmlsec(self, statement, **kwargs): + """ Deprecated function. See sign_assertion(). """ + return self.sign_statement(statement, class_name(saml.Assertion()), + **kwargs) - try: - (stdout, stderr, signed_statement) = _run_xmlsec(com_list, [fil]) - # this doesn't work if --store-signatures are used - if stdout == "": - if signed_statement: - return signed_statement - logger.error("Signing operation failed :\nstdout : %s\nstderr : %s" % (stdout, stderr)) - raise Exception("Signing failed") - except DecryptError, exc: - raise Exception("Signing failed") + def sign_assertion(self, statement, **kwargs): + """Sign a SAML assertion. - def sign_assertion_using_xmlsec(self, statement, key=None, key_file=None, - nodeid=None, id_attr=""): - """Sign a SAML assertion using xmlsec. + See sign_statement() for the kwargs. :param statement: The statement to be signed - :param key: The key to be used for the signing, either this or - :param key_file: The file where the key can be found :return: The signed statement """ + return self.sign_statement(statement, class_name(saml.Assertion()), + **kwargs) - return self.sign_statement_using_xmlsec(statement, - class_name(saml.Assertion()), - key, key_file, nodeid, - id_attr=id_attr) + def sign_attribute_query_using_xmlsec(self, statement, **kwargs): + """ Deprecated function. See sign_attribute_query(). """ + return self.sign_attribute_query(statement, **kwargs); - def sign_attribute_query_using_xmlsec(self, statement, key=None, - key_file=None, nodeid=None, - id_attr=""): - """Sign a SAML assertion using xmlsec. + def sign_attribute_query(self, statement, **kwargs): + """Sign a SAML attribute query. + + See sign_statement() for the kwargs. :param statement: The statement to be signed - :param key: The key to be used for the signing, either this or - :param key_file: The file where the key can be found :return: The signed statement """ - - return self.sign_statement_using_xmlsec(statement, class_name( - samlp.AttributeQuery()), key, key_file, nodeid, id_attr=id_attr) + return self.sign_statement(statement, class_name( + samlp.AttributeQuery()), **kwargs) def multiple_signatures(self, statement, to_sign, key=None, key_file=None): """ @@ -1105,12 +1177,9 @@ class SecurityContext(object): if not item.signature: item.signature = pre_signature_part(sid, self.cert_file) - statement = self.sign_statement_using_xmlsec(statement, - class_name(item), - key=key, - key_file=key_file, - nodeid=id, - id_attr=id_attr) + statement = self.sign_statement(statement, class_name(item), + key=key, key_file=key_file, + node_id=id, id_attr=id_attr) return statement diff --git a/tests/test_40_sigver.py b/tests/test_40_sigver.py index 33ac786..a721a2b 100644 --- a/tests/test_40_sigver.py +++ b/tests/test_40_sigver.py @@ -10,7 +10,7 @@ from saml2 import class_name from saml2 import time_util from saml2 import saml, samlp from saml2.s_utils import factory, do_attribute_statement -from saml2.sigver import xmlsec_version, get_xmlsec_binary +from saml2.sigver import xmlsec_version, get_xmlsec_cryptobackend, get_xmlsec_binary from py.test import raises @@ -78,11 +78,10 @@ def test_cert_from_instance_ssp(): print str(decoder.decode(der)).replace('.',"\n.") assert decoder.decode(der) - class TestSecurity(): def setup_class(self): - xmlexec = get_xmlsec_binary() - self.sec = sigver.SecurityContext(xmlexec, key_file=PRIV_KEY, + crypto = get_xmlsec_cryptobackend() + self.sec = sigver.SecurityContext(crypto, key_file=PRIV_KEY, cert_file=PUB_KEY, debug=1) self._assertion = factory( saml.Assertion, @@ -116,7 +115,7 @@ class TestSecurity(): ass = self._assertion print ass sign_ass = self.sec.sign_assertion_using_xmlsec("%s" % ass, - nodeid=ass.id) + node_id=ass.id) #print sign_ass sass = saml.assertion_from_string(sign_ass) #print sass @@ -137,7 +136,7 @@ class TestSecurity(): assertion=self._assertion, id="22222", signature=sigver.pre_signature_part("22222", self.sec.my_cert)) - + to_sign = [(class_name(self._assertion), self._assertion.id), (class_name(response), response.id)] s_response = sigver.signed_instance_factory(response, self.sec, to_sign) @@ -291,7 +290,7 @@ class TestSecurity(): response2.id = "23456" raises(sigver.SignatureError, self.sec._check_signature, s_response, response2, class_name(response2)) - + class TestSecurityMetadata(): def setup_class(self): @@ -299,7 +298,8 @@ class TestSecurityMetadata(): md = MetadataStore([saml, samlp], None, xmlexec) md.load("local", "metadata_cert.xml") - self.sec = sigver.SecurityContext(xmlexec, key_file=PRIV_KEY, + crypto = get_xmlsec_cryptobackend() + self.sec = sigver.SecurityContext(crypto, key_file=PRIV_KEY, cert_file=PUB_KEY, debug=1, metadata=md) self._assertion = factory( saml.Assertion, @@ -317,7 +317,7 @@ class TestSecurityMetadata(): ass = self._assertion print ass sign_ass = self.sec.sign_assertion_using_xmlsec("%s" % ass, - nodeid=ass.id) + node_id=ass.id) #print sign_ass sass = saml.assertion_from_string(sign_ass) #print sass