# deb-python-pysaml2/src/saml2/mongo_store.py
#
# 420 lines
# 13 KiB
# Python

from hashlib import sha1
import logging
from pymongo import MongoClient
from pymongo.mongo_replica_set_client import MongoReplicaSetClient
import pymongo.uri_parser
import pymongo.errors
from saml2.eptid import Eptid
from saml2.mdstore import InMemoryMetaData
from saml2.mdstore import metadata_modules
from saml2.mdstore import load_metadata_modules
from saml2.s_utils import PolicyError
from saml2.ident import code_binary
from saml2.ident import IdentDB
from saml2.ident import Unknown
from saml2.mdie import to_dict
from saml2.mdie import from_dict
import six
__author__ = 'rolandh'
# Module-level logger for this storage backend.
logger = logging.getLogger(__name__)
# Computed once at import time: ONTS is handed to from_dict() when stored
# documents are turned back into SAML objects, MMODS to to_dict() when SAML
# objects are serialized for storage.
ONTS = load_metadata_modules()
MMODS = metadata_modules()
class CorruptDatabase(Exception):
    """Signal that a lookup expected to hit at most one document hit several.

    Raised by the key-based lookups of the MongoDB-backed stores when the
    same key occurs in more than one document.
    """
def context_match(cfilter, cntx):
    """Tell whether the authn context *cntx* satisfies the filter *cfilter*.

    Matching is not implemented yet, so every context is accepted.

    :param cfilter: the requested authn context used as filter
    :param cntx: the authn context being checked
    :return: always True for now
    """
    # TODO: implement real matching; until then nothing is filtered out.
    accepted = True
    return accepted
class SessionStorageMDB(object):
    """Session information (assertions) stored in a MongoDB collection.

    Each document holds the SHA-1 hex digest of the subject's NameID
    (``name_id_key``), the assertion id (``assertion_id``), the assertion
    serialized with ``to_dict`` (``assertion``) and the ``to_sign`` value
    supplied when it was stored.
    """

    def __init__(self, database="", collection="assertion", **kwargs):
        """
        :param database: database name or MongoDB connection URI
        :param collection: name of the collection holding the assertions
        :param kwargs: passed on to the database connection helper
        """
        db = _mdb_get_database(database, **kwargs)
        self.assertion = db[collection]

    def store_assertion(self, assertion, to_sign):
        """Insert one assertion document.

        :param assertion: assertion instance; its subject's NameID is hashed
            to build the lookup key
        :param to_sign: stored unchanged next to the assertion
        """
        name_id = assertion.subject.name_id
        nkey = sha1(code_binary(name_id)).hexdigest()
        doc = {
            "name_id_key": nkey,
            "assertion_id": assertion.id,
            "assertion": to_dict(assertion, MMODS, True),
            "to_sign": to_sign
        }
        self.assertion.insert(doc)

    def get_assertion(self, cid):
        """Return the stored assertion whose id is *cid*.

        :param cid: assertion id
        :return: dict with keys ``assertion`` (deserialized assertion) and
            ``to_sign``, or None when no assertion has that id
        :raises SystemError: when more than one assertion has that id
        """
        res = [{"assertion": from_dict(item["assertion"], ONTS, True),
                "to_sign": item["to_sign"]}
               for item in self.assertion.find({"assertion_id": cid})]
        if not res:
            # An empty result simply means the id is unknown.
            return None
        if len(res) > 1:
            raise SystemError("More then one assertion with the same ID")
        return res[0]

    def get_assertions_by_subject(self, name_id=None, session_index=None,
                                  requested_context=None):
        """Return the stored assertions about a subject.

        :param name_id: NameID of the subject; its hashed value selects the
            stored documents
        :param session_index: if given, an assertion is kept when one of its
            authn statements carries this session index
        :param requested_context: if given, an assertion is kept when
            ``context_match`` accepts the authn context of one of its authn
            statements
        :return: list of assertion instances, each included at most once
        """
        # NOTE(review): when both filters are given, either one is enough to
        # include an assertion -- confirm that OR (not AND) is intended.
        result = []
        key = sha1(code_binary(name_id)).hexdigest()
        for item in self.assertion.find({"name_id_key": key}):
            assertion = from_dict(item["assertion"], ONTS, True)
            if session_index or requested_context:
                for statement in assertion.authn_statement:
                    if session_index:
                        if statement.session_index == session_index:
                            result.append(assertion)
                            break
                    if requested_context:
                        if context_match(requested_context,
                                         statement.authn_context):
                            result.append(assertion)
                            break
            else:
                # No filter: every assertion about the subject qualifies.
                result.append(assertion)
        return result

    def remove_authn_statements(self, name_id):
        """Delete every stored assertion about the subject *name_id*."""
        logger.debug("remove authn about: %s", name_id)
        key = sha1(code_binary(name_id)).hexdigest()
        for item in self.assertion.find({"name_id_key": key}):
            self.assertion.remove(item["_id"])

    def get_authn_statements(self, name_id, session_index=None,
                             requested_context=None):
        """Return the ``authn_statement`` of each matching assertion.

        :param name_id: NameID of the subject
        :param session_index: optional session index filter
        :param requested_context: optional authn context filter
        :return: list with one ``authn_statement`` value per assertion
            selected by ``get_assertions_by_subject``
        """
        return [k.authn_statement for k in self.get_assertions_by_subject(
            name_id, session_index, requested_context)]
class IdentMDB(IdentDB):
    """Identifier database kept in MongoDB.

    Each document maps a local user id (field ``user_id``) to a NameID
    serialized with ``to_dict`` (field ``name_id``).
    """

    def __init__(self, database="", collection="ident", domain="",
                 name_qualifier=""):
        IdentDB.__init__(self, None, domain, name_qualifier)
        self.mdb = MDB(database=database, collection=collection)
        self.mdb.primary_key = "user_id"

    def in_store(self, _id):
        """Return True when some document has ``ident_id`` equal to *_id*."""
        # NOTE(review): store() never writes an ``ident_id`` field, so this
        # lookup looks like it can never match -- confirm the intended field.
        for _unused in self.mdb.get(ident_id=_id):
            return True
        return False

    def create_id(self, nformat, name_qualifier="", sp_name_qualifier=""):
        """Generate identifiers until one is found that is not in store."""
        while True:
            candidate = self._create_id(nformat, name_qualifier,
                                        sp_name_qualifier)
            if not self.in_store(candidate):
                return candidate

    def store(self, ident, name_id):
        """Record that local user *ident* is known under *name_id*."""
        serialized = to_dict(name_id, MMODS, True)
        self.mdb.store(ident, name_id=serialized)

    def find_nameid(self, userid, nformat=None, sp_name_qualifier=None,
                    name_qualifier=None, sp_provided_id=None, **kwargs):
        """Return the NameIDs stored for *userid*.

        Each non-empty optional argument adds an equality filter on the
        corresponding field.  Extra keyword arguments are accepted but
        ignored.
        """
        # NOTE(review): these filters are matched against top-level fields,
        # while store() nests the NameID under ``name_id`` -- verify they can
        # ever match.
        optional_filters = (
            ("name_format", nformat),
            ("sp_name_qualifier", sp_name_qualifier),
            ("name_qualifier", name_qualifier),
            ("sp_provided_id", sp_provided_id),
        )
        query = {field: wanted for field, wanted in optional_filters if wanted}
        return [from_dict(doc["name_id"], ONTS, True)
                for doc in self.mdb.get(userid, **query)]

    def find_local_id(self, name_id):
        """Return the local user id stored for *name_id*, or None."""
        matches = iter(self.mdb.get(name_id=to_dict(name_id, MMODS, True)))
        first = next(matches, None)
        if first is None:
            return None
        return first[self.mdb.primary_key]

    def remove_remote(self, name_id):
        """Delete the documents whose stored NameID equals *name_id*."""
        serialized = to_dict(name_id, MMODS, True)
        self.mdb.remove(name_id=serialized)

    def handle_name_id_mapping_request(self, name_id, name_id_policy):
        """Answer a NameID mapping request.

        :raises Unknown: when *name_id* has no local id
        :raises PolicyError: when the policy's ``allow_create`` is "false"
        :return: the result of ``construct_nameid`` for the local id
        """
        local_id = self.find_local_id(name_id)
        if not local_id:
            raise Unknown("Unknown entity")
        if name_id_policy.allow_create == "false":
            raise PolicyError("Not allowed to create new identifier")
        # Otherwise build and return a fresh identifier.
        return self.construct_nameid(local_id, name_id_policy=name_id_policy)

    def close(self):
        """Nothing to close here; kept as a no-op."""
        return None
#------------------------------------------------------------------------------
class MDB(object):
    """Thin wrapper around one MongoDB collection.

    Documents are plain dicts.  A document's key value, when given, is
    stored under the field named by ``primary_key`` (class default
    ``"mdb"``; users of this class usually override it per instance).
    """
    primary_key = "mdb"

    def __init__(self, database, collection, **kwargs):
        """
        :param database: database name or MongoDB connection URI
        :param collection: name of the wrapped collection
        :param kwargs: passed on to the database connection helper
        """
        _db = _mdb_get_database(database, **kwargs)
        self.db = _db[collection]

    def _filter(self, value, fields):
        """Build a query from a key value and field filters; None if empty."""
        if value is not None:
            query = {self.primary_key: value}
            query.update(fields)
            return query
        if fields:
            return dict(fields)
        return None

    def store(self, value, **kwargs):
        """Insert one document.

        :param value: key value, stored under ``primary_key``; left out of
            the document only when it is None (so falsy keys such as 0 or ""
            stay reachable through get()/remove(), which also test for None)
        :param kwargs: the other fields of the document
        """
        doc = {} if value is None else {self.primary_key: value}
        doc.update(kwargs)
        self.db.insert(doc)

    def get(self, value=None, **kwargs):
        """Return every document matching the key value and field filters.

        :return: list of documents, or None when neither a key value nor a
            field filter is given
        """
        query = self._filter(value, kwargs)
        if query is None:
            return None
        return list(self.db.find(query))

    def remove(self, key=None, **kwargs):
        """Delete every document matching the key value and field filters.

        Does nothing when neither is given, so an empty call never wipes the
        collection.
        """
        query = self._filter(key, kwargs)
        if query is None:
            return
        for item in self.db.find(query):
            self.db.remove(item["_id"])

    def keys(self):
        """Yield the key value of every document."""
        for item in self.db.find():
            yield item[self.primary_key]

    def items(self):
        """Yield ``(key, rest_of_document)`` pairs; ``_id`` is dropped."""
        for item in self.db.find():
            _key = item[self.primary_key]
            del item[self.primary_key]
            del item["_id"]
            yield _key, item

    def __contains__(self, key):
        # One document is enough to answer; no need to fetch every match.
        return self.db.find_one({self.primary_key: key}) is not None

    def reset(self):
        """Drop the whole collection."""
        self.db.drop()
def _mdb_get_database(uri, **kwargs):
    """
    Helper-function to connect to MongoDB and return a database object.

    The `uri' argument should be either a full MongoDB connection URI string,
    or just a database name in which case a connection to the default mongo
    instance at mongodb://localhost:27017 will be made.

    Performs explicit authentication if a username is provided in a connection
    string URI, since PyMongo does not always seem to do that as promised.

    :param uri: MongoDB connection URI, or a bare database name
    :param kwargs: extra client options, only forwarded to the client when
        *uri* is a connection URI; ``tz_aware`` defaults to True
    :returns: pymongo database object
    """
    if "tz_aware" not in kwargs:
        # default, but not forced
        kwargs["tz_aware"] = True

    try:
        _parsed_uri = pymongo.uri_parser.parse_uri(uri)
    except pymongo.errors.InvalidURI:
        # Not a URI: treat it as a database name on the default server.
        # NOTE(review): kwargs (including the tz_aware default) are not
        # passed to the client here -- confirm that is intended.
        return MongoClient()[uri]

    connection_factory = MongoClient
    if "replicaset" in _parsed_uri["options"]:
        connection_factory = MongoReplicaSetClient

    # parse_uri() always includes a "database" entry, set to None when the
    # URI names no database, so a .get() default would never kick in.
    db_name = _parsed_uri.get("database") or "pysaml2"
    _db = connection_factory(uri, **kwargs)[db_name]

    # "username" is likewise always present (None without credentials);
    # only authenticate when credentials were actually supplied.
    username = _parsed_uri.get("username")
    if username:
        _db.authenticate(username, _parsed_uri.get("password"))
    return _db
#------------------------------------------------------------------------------
class EptidMDB(Eptid):
    """EPTID store backed by a MongoDB collection.

    Each document keeps the lookup key under ``eptid_key`` and the stored
    value under ``eptid``.
    """

    def __init__(self, secret, database="", collection="eptid"):
        Eptid.__init__(self, secret)
        self.mdb = MDB(database, collection)
        self.mdb.primary_key = "eptid_key"

    def __getitem__(self, key):
        docs = self.mdb.get(key)
        if not docs:
            raise KeyError(key)
        if len(docs) > 1:
            raise CorruptDatabase("Found more than one EPTID document")
        return docs[0]["eptid"]

    def __setitem__(self, key, value):
        self.mdb.store(key, eptid=value)
#------------------------------------------------------------------------------
def protect(dic):
    """Return a copy of *dic* whose keys contain no "." characters.

    Every "." in a key is replaced by "__", recursively through nested
    dicts and through dicts sitting directly inside lists.  All other
    values, strings included, are copied over unchanged; lists nested
    directly inside lists are not descended into.
    """
    def _protect_element(element):
        # Only dicts inside a list need their keys rewritten.
        if isinstance(element, dict):
            return protect(element)
        return element

    protected = {}
    for raw_key, value in dic.items():
        safe_key = raw_key.replace(".", "__")
        if isinstance(value, dict):
            value = protect(value)
        elif isinstance(value, list):
            value = [_protect_element(element) for element in value]
        protected[safe_key] = value
    return protected
def unprotect(dic):
    """Reverse protect(): turn "__" in keys back into ".".

    The key "__class__" is kept verbatim, because rewriting it would produce
    ".class.".  Its value is still processed and stored.  Nested dicts, and
    dicts sitting directly inside lists, are handled recursively.  All other
    values, strings included, are copied over unchanged.

    :param dic: dict as previously produced by protect()
    :return: a new dict with the original keys restored
    """
    res = {}
    for key, val in dic.items():
        if key != "__class__":
            key = key.replace("__", ".")
        if isinstance(val, dict):
            val = unprotect(val)
        elif isinstance(val, list):
            # Dicts inside lists were protected too, so restore each of them
            # (the check must be on the element, not on the list).
            val = [unprotect(va) if isinstance(va, dict) else va
                   for va in val]
        res[key] = val
    return res
def export_mdstore_to_mongo_db(mds, database, collection, sub_collection=""):
    """Replace the contents of a MongoDB collection with a metadata store.

    The target collection is dropped first.  Then one document is written
    per entity, with the entity id under ``entity_id`` and the entity
    description, after protect() (dots in keys replaced by "__"), under
    ``entity_description``.

    :param mds: metadata store whose ``items()`` yields
        ``(entity_id, description)`` pairs
    :param database: database name or MongoDB connection URI
    :param collection: name of the target collection
    :param sub_collection: accepted for backward compatibility; not used
    """
    # MDB() forwards extra keyword arguments to the MongoDB client as
    # connection options; sub_collection is not such an option, so it is
    # deliberately not passed on.
    mdb = MDB(database, collection)
    mdb.reset()
    mdb.primary_key = "entity_id"
    for entity_id, description in mds.items():
        mdb.store(entity_id, entity_description=protect(description))
class MetadataMDB(InMemoryMetaData):
    """Metadata store backed by a MongoDB collection.

    One document per entity: the entity id under ``entity_id`` and the
    protected entity description under ``entity_description``.  The
    description is passed through unprotect() on the way out.
    """

    def __init__(self, attrc, database="", collection=""):
        super(MetadataMDB, self).__init__(attrc)
        self.mdb = MDB(database, collection)
        self.mdb.primary_key = "entity_id"

    def _ext_service(self, entity_id, typ, service, binding):
        """Return the extension elements of class *service* with *binding*.

        Returns None when the entity or its *typ* entry is missing, and the
        (falsy) *typ* entry itself when it holds no services.
        """
        try:
            services = self[entity_id][typ]
        except KeyError:
            return None
        if not services:
            return services
        return [element
                for srv in services if "extensions" in srv
                for element in srv["extensions"]["extension_elements"]
                if element["__class__"] == service
                and element["binding"] == binding]

    def load(self):
        """Nothing to load: the data already lives in the database."""
        return None

    def items(self):
        for entity_id, doc in self.mdb.items():
            yield entity_id, unprotect(doc["entity_description"])

    def keys(self):
        return self.mdb.keys()

    def values(self):
        for _entity_id, doc in self.mdb.items():
            yield unprotect(doc["entity_description"])

    def __contains__(self, item):
        return item in self.mdb

    def __getitem__(self, item):
        docs = self.mdb.get(item)
        if not docs:
            raise KeyError(item)
        if len(docs) > 1:
            raise CorruptDatabase("More then one document with key %s" % item)
        return unprotect(docs[0]["entity_description"])

    def bindings(self, entity_id, typ, service):
        """Not implemented; always returns None."""
        return None