diff --git a/src/saml2/eptid.py b/src/saml2/eptid.py index dd66d8f..7e0905e 100644 --- a/src/saml2/eptid.py +++ b/src/saml2/eptid.py @@ -1,6 +1,6 @@ # An eduPersonTargetedID comprises # the entity name of the identity provider, the entity name of the service -# provider, and the opaque string value. +# provider, and a opaque string value. # These strings are separated by "!" symbols. This form is advocated by # Internet2 and may overtake the other form in due course. @@ -33,13 +33,13 @@ class Eptid(object): def __setitem__(self, key, value): self._db[key] = value - def get(self, idp, sp, args): + def get(self, idp, sp, *args): # key is a combination of sp_entity_id and object id - key = (".".join([sp, args[0]])).encode("utf-8") + key = ("__".join([sp, args[0]])).encode("utf-8") try: return self[key] except KeyError: - val = self.make(idp, sp, args[1]) + val = self.make(idp, sp, args) self[key] = val return val diff --git a/src/saml2/mdstore.py b/src/saml2/mdstore.py index 68650d5..30b69a2 100644 --- a/src/saml2/mdstore.py +++ b/src/saml2/mdstore.py @@ -97,6 +97,18 @@ class MetaData(object): self.entity = {} self.metadata = metadata + def items(self): + return self.entity + + def keys(self): + return self.entity.keys() + + def __contains__(self, item): + return item in self.entity + + def __getitem__(self, item): + return self.entity[item] + def do_entity_descriptor(self, entity_descr): try: if not valid(entity_descr.valid_until): @@ -530,19 +542,19 @@ class MetadataStore(object): def attribute_requirement(self, entity_id, index=0): for md in self.metadata.values(): - if entity_id in md.entity: + if entity_id in md.items(): return md.attribute_requirement(entity_id, index) def keys(self): res = [] for md in self.metadata.values(): - res.extend(md.entity.keys()) + res.extend(md.keys()) return res def __getitem__(self, item): for md in self.metadata.values(): try: - return md.entity[item] + return md[item] except KeyError: pass @@ -554,7 +566,7 @@ class MetadataStore(object): def entities(self): num = 0 for md in self.metadata.values(): - num += len(md.entity) + num += len(md.items()) return num @@ -569,8 +581,8 @@ class MetadataStore(object): def name(self, entity_id, langpref="en"): for md in self.metadata.values(): - if entity_id in md.entity: - return name(md.entity[entity_id], langpref) + if entity_id in md.items(): + return name(md[entity_id], langpref) return None def certs(self, entity_id, descriptor, use="signing"): @@ -618,7 +630,7 @@ class MetadataStore(object): def bindings(self, entity_id, typ, service): for md in self.metadata.values(): - if entity_id in md.entity: + if entity_id in md.items(): return md.bindings(entity_id, typ, service) return None @@ -639,5 +651,5 @@ class MetadataStore(object): def items(self): res = {} for md in self.metadata.values(): - res.update(md.entity) + res.update(md.items()) return res.items() diff --git a/src/saml2/mongo_store.py b/src/saml2/mongo_store.py index 25a36b0..c60d888 100644 --- a/src/saml2/mongo_store.py +++ b/src/saml2/mongo_store.py @@ -208,9 +208,9 @@ class MDB(object): doc.update(kwargs) _ = self.db.insert(doc) - def get(self, key=None, **kwargs): - if key: - doc = {self.primary_key: key} + def get(self, value=None, **kwargs): + if value: + doc = {self.primary_key: value} doc.update(kwargs) return [item for item in self.db.find(doc)] elif kwargs: @@ -249,27 +249,23 @@ class MDB(object): #------------------------------------------------------------------------------ class EptidMDB(Eptid): - primary_key = "eptid" - def __init__(self, secret, collection="", sub_collection=""): + def __init__(self, secret, collection="", sub_collection="eptid"): Eptid.__init__(self, secret) self.mdb = MDB(collection, sub_collection) - self.mdb.primary_key = "entity_id" + self.mdb.primary_key = "eptid_key" def __getitem__(self, key): res = self.mdb.get(key) if not res: raise KeyError(key) elif len(res) == 1: - return res[0] + return res[0]["eptid"] else: raise CorruptDatabase("Found more than one EPTID document") def __setitem__(self, key, value): - if key == self.mdb.primary_key: - _ = self.mdb.store(value) - else: - _ = self.mdb.store(**{key: value}) + _ = self.mdb.store(key, **{"eptid": value}) #------------------------------------------------------------------------------ diff --git a/tests/test_72_eptid.py b/tests/test_72_eptid.py new file mode 100644 index 0000000..ea95d89 --- /dev/null +++ b/tests/test_72_eptid.py @@ -0,0 +1,41 @@ +from saml2.eptid import Eptid, EptidShelve + +__author__ = 'rolandh' + + +def test_eptid(): + edb = Eptid("secret") + e1 = edb.get("idp_entity_id", "sp_entity_id", "user_id", "some other data") + print e1 + assert e1.startswith("idp_entity_id!sp_entity_id!") + e2 = edb.get("idp_entity_id", "sp_entity_id", "user_id", "some other data") + assert e1 == e2 + + e3 = edb.get("idp_entity_id", "sp_entity_id", "user_2", "some other data") + print e3 + assert e1 != e3 + + e4 = edb.get("idp_entity_id", "sp_entity_id2", "user_id", "some other data") + assert e4 != e1 + assert e4 != e3 + + +def test_eptid_shelve(): + edb = EptidShelve("secret", "eptid.db") + e1 = edb.get("idp_entity_id", "sp_entity_id", "user_id", "some other data") + print e1 + assert e1.startswith("idp_entity_id!sp_entity_id!") + e2 = edb.get("idp_entity_id", "sp_entity_id", "user_id", "some other data") + assert e1 == e2 + + e3 = edb.get("idp_entity_id", "sp_entity_id", "user_2", "some other data") + print e3 + assert e1 != e3 + + e4 = edb.get("idp_entity_id", "sp_entity_id2", "user_id", "some other data") + assert e4 != e1 + assert e4 != e3 + + +if __name__ == "__main__": + test_eptid_shelve() \ No newline at end of file diff --git a/tests/test_71_mongodb.py b/tests/test_75_mongodb.py similarity index 73% rename from tests/test_71_mongodb.py rename to tests/test_75_mongodb.py index d8b7f05..a3e2a10 100644 --- a/tests/test_71_mongodb.py +++ b/tests/test_75_mongodb.py @@ -2,6 +2,7 @@ from saml2 import BINDING_HTTP_POST from saml2.saml import AUTHN_PASSWORD from saml2.client import Saml2Client from saml2.server import Server +from saml2.mongo_store import EptidMDB __author__ = 'rolandh' @@ -50,5 +51,26 @@ def test_flow(): nids = idp2.ident.find_nameid("jeter") assert len(nids) == 1 + +def test_eptid_mongo_db(): + edb = EptidMDB("secret", "idp") + e1 = edb.get("idp_entity_id", "sp_entity_id", "user_id", + "some other data") + print e1 + assert e1.startswith("idp_entity_id!sp_entity_id!") + e2 = edb.get("idp_entity_id", "sp_entity_id", "user_id", + "some other data") + assert e1 == e2 + + e3 = edb.get("idp_entity_id", "sp_entity_id", "user_2", + "some other data") + print e3 + assert e1 != e3 + + e4 = edb.get("idp_entity_id", "sp_entity_id2", "user_id", + "some other data") + assert e4 != e1 + assert e4 != e3 + if __name__ == "__main__": test_flow()