_mdb_get_database: don't require tuple for URI

This commit is contained in:
Fredrik Thulin
2013-06-11 15:04:26 +02:00
parent 93a9bf3a6a
commit c16d9792c6

View File

@@ -4,6 +4,7 @@ import logging
from pymongo import MongoClient
from pymongo.mongo_replica_set_client import MongoReplicaSetClient
import pymongo.uri_parser
import pymongo.errors
from saml2.eptid import Eptid
from saml2.mdstore import MetaData
from saml2.s_utils import PolicyError
@@ -251,39 +252,42 @@ class MDB(object):
def _mdb_get_database(database, **kwargs):
def _mdb_get_database(uri, **kwargs):
"""
Helper-function to connect to MongoDB and return a database object.
The `uri' argument should be either a full MongoDB connection URI string,
or just a database name in which case a connection to the default mongo
instance at mongodb://localhost:27017 will be made.
Performs explicit authentication if a username is provided in a connection
string URI, since PyMongo does not always seem to do that as promised.
If database is *not* a connection to mongodb://localhost:27017 will be
made.
:params database: name as string or (uri, name)
:returns: pymongo database object
"""
_conn = None
if isinstance(database, tuple):
uri, collection, = database
if uri.startswith("mongodb://"):
if "replicaSet=" in uri:
connection_factory = MongoReplicaSetClient
else:
connection_factory = MongoClient
connection_factory = MongoClient
_parsed_uri = {}
db_name = None
try:
_parsed_uri = pymongo.uri_parser.parse_uri(uri)
except pymongo.errors.InvalidURI:
# assume URI to be just the database name
db_name = uri
pass
else:
if "replicaset" in _parsed_uri["options"]:
connection_factory = MongoReplicaSetClient
db_name = _parsed_uri.get("database", "pysaml2")
if not 'tz_aware' in kwargs:
# default, but not forced
kwargs['tz_aware'] = True
if not "tz_aware" in kwargs:
# default, but not forced
kwargs["tz_aware"] = True
_conn = connection_factory(uri, **kwargs)
if not _conn:
_conn = MongoClient()
_conn = connection_factory(uri, **kwargs)
_db = _conn[db_name]
_parsed_uri = pymongo.uri_parser.parse_uri(uri)
_db = _conn[_parsed_uri.get("database")]
if _parsed_uri.get("username", None):
if "username" in _parsed_uri:
_db.authenticate(
_parsed_uri.get("username", None),
_parsed_uri.get("password", None)