diff --git a/src/saml2/server.py b/src/saml2/server.py index b53d6c6..45db33d 100644 --- a/src/saml2/server.py +++ b/src/saml2/server.py @@ -86,18 +86,24 @@ class Server(Entity): # default database is a shelve database which is OK in some setups dbspec = self.config.getattr("subject_data", "idp") idb = None - if isinstance(dbspec, basestring): + if not dbspec: + pass + elif isinstance(dbspec, basestring): idb = shelve.open(dbspec, writeback=True) - else: # database spec is a a 2-tuple (type, address) - print >> sys.stderr, "DBSPEC: %s" % dbspec + else: # database spec is a a 2-tuple (type, address) + print >> sys.stderr, "DBSPEC: %s" % (dbspec,) (typ, addr) = dbspec if typ == "shelve": idb = shelve.open(addr, writeback=True) elif typ == "memcached": idb = memcache.Client(addr) - elif typ == "dict": # in-memory dictionary + elif typ == "dict": # in-memory dictionary idb = addr - + elif typ == "mongodb": + from mongodict import MongoDict + idb = MongoDict(host='localhost', port=27017, + database=addr, collection='store') + if idb is not None: self.ident = IdentDB(idb) else: @@ -150,7 +156,6 @@ class Server(Entity): return self._parse_request(xml_string, AttributeQuery, "attribute_service", binding) - def parse_authz_decision_query(self, xml_string, binding): """ Parse an attribute query @@ -236,7 +241,8 @@ class Server(Entity): if statement.session_index != session_index: continue if requested_context: - if not context_match(requested_context, statement.authn_context): + if not context_match(requested_context, + statement.authn_context): continue result.append(statement) @@ -286,7 +292,7 @@ class Server(Entity): return self.create_error_response(in_response_to, consumer_url, exc, sign_response) - if authn: # expected to be a 2-tuple class+authority + if authn: # expected to be a 2-tuple class+authority (authn_class, authn_authn) = authn assertion = ast.construct(sp_entity_id, in_response_to, consumer_url, name_id, @@ -387,7 +393,6 @@ class Server(Entity): # Just the assertion or the response and the assertion ? to_sign = [(class_name(assertion), assertion.id)] - args["assertion"] = assertion return self._response(in_response_to, destination, status, issuer, @@ -424,7 +429,8 @@ class Server(Entity): nid_formats = [] for _sp in self.metadata[sp_entity_id]["spsso_descriptor"]: if "name_id_format" in _sp: - nid_formats.extend([n["text"] for n in _sp["name_id_format"]]) + nid_formats.extend([n["text"] for n in + _sp["name_id_format"]]) name_id = self.ident.construct_nameid(userid, policy, sp_entity_id, @@ -438,13 +444,13 @@ class Server(Entity): return ("%s" % response).split("\n") try: - return self._authn_response(in_response_to, # in_response_to - destination, # consumer_url - sp_entity_id, # sp_entity_id - identity, # identity as dictionary + return self._authn_response(in_response_to, # in_response_to + destination, # consumer_url + sp_entity_id, # sp_entity_id + identity, # identity as dictionary name_id, - authn=authn, # Information about the - # authentication + authn=authn, # Information about the + # authentication authn_decl=authn_decl, issuer=issuer, policy=policy, @@ -453,7 +459,7 @@ class Server(Entity): except MissingValue, exc: return self.create_error_response(in_response_to, destination, - sp_entity_id, exc, name_id) + sp_entity_id, exc, name_id) def create_assertion_id_request_response(self, assertion_id, sign=False): """