diff --git a/swift/common/db.py b/swift/common/db.py index 6425e85034..bd88f881e1 100644 --- a/swift/common/db.py +++ b/swift/common/db.py @@ -427,6 +427,14 @@ class DatabaseBroker(object): if self.conn: self.conn.timeout = old_timeout + @contextmanager + def maybe_get(self, conn): + if conn: + yield conn + else: + with self.get() as conn: + yield conn + @contextmanager def get(self): """Use with the "with" statement; returns a database connection.""" diff --git a/swift/container/backend.py b/swift/container/backend.py index 69ea59ffbc..f61ee77706 100644 --- a/swift/container/backend.py +++ b/swift/container/backend.py @@ -1653,11 +1653,8 @@ class ContainerBroker(DatabaseBroker): return [row for row in data] try: - if connection: - return do_query(connection) - else: - with self.get() as conn: - return do_query(conn) + with self.maybe_get(connection) as conn: + return do_query(conn) except sqlite3.OperationalError as err: if ('no such table: %s' % SHARD_RANGE_TABLE) not in str(err): raise diff --git a/test/unit/common/test_db.py b/test/unit/common/test_db.py index acdf271364..0853f06878 100644 --- a/test/unit/common/test_db.py +++ b/test/unit/common/test_db.py @@ -604,6 +604,25 @@ class TestExampleBroker(unittest.TestCase): broker.get_info() self.assertEqual(1, broker.get_info()[count_key]) + @with_tempdir + def test_maybe_get(self, tempdir): + broker = self.broker_class(os.path.join(tempdir, 'test.db'), + account='a', container='c') + broker.initialize(next(self.ts), + storage_policy_index=int(self.policy)) + qry = 'select account from %s_stat' % broker.db_type + with broker.maybe_get(None) as conn: + rows = [dict(x) for x in conn.execute(qry)] + self.assertEqual([{'account': 'a'}], rows) + self.assertEqual(conn, broker.conn) + with broker.get() as other_conn: + self.assertEqual(broker.conn, None) + with broker.maybe_get(other_conn) as identity_conn: + self.assertEqual(other_conn, identity_conn) + self.assertEqual(broker.conn, None) + self.assertEqual(broker.conn, None) + self.assertEqual(broker.conn, conn) + class TestDatabaseBroker(unittest.TestCase):