diff --git a/oslo_db/sqlalchemy/enginefacade.py b/oslo_db/sqlalchemy/enginefacade.py index 8150875..cd5d74b 100644 --- a/oslo_db/sqlalchemy/enginefacade.py +++ b/oslo_db/sqlalchemy/enginefacade.py @@ -287,6 +287,15 @@ class _TransactionFactory(object): else: return self._writer_maker(**kw) + def _create_factory_copy(self): + factory = _TransactionFactory() + factory._url_cfg.update(self._url_cfg) + factory._engine_cfg.update(self._engine_cfg) + factory._maker_cfg.update(self._maker_cfg) + factory._transaction_ctx_cfg.update(self._transaction_ctx_cfg) + factory._facade_cfg.update(self._facade_cfg) + return factory + def _args_for_conf(self, default_cfg, conf): if conf is None: return dict( @@ -308,7 +317,9 @@ class _TransactionFactory(object): return self._args_for_conf(self._engine_cfg, conf) def _maker_args_for_conf(self, conf): - return self._args_for_conf(self._maker_cfg, conf) + maker_args = self._args_for_conf(self._maker_cfg, conf) + maker_args['autocommit'] = maker_args.pop('__autocommit') + return maker_args def dispose_pool(self): """Call engine.pool.dispose() on underlying Engine objects.""" @@ -345,7 +356,6 @@ class _TransactionFactory(object): url_args['slave_connection'] = slave_connection engine_args = self._engine_args_for_conf(conf) maker_args = self._maker_args_for_conf(conf) - maker_args['autocommit'] = maker_args.pop('__autocommit') self._writer_engine, self._writer_maker = \ self._setup_for_connection( @@ -660,6 +670,81 @@ class _TransactionContextManager(object): """Call engine.pool.dispose() on underlying Engine objects.""" self._factory.dispose_pool() + def make_new_manager(self): + """Create a new, independent _TransactionContextManager from this one. + + Copies the underlying _TransactionFactory to a new one, so that + it can be further configured with new options. + + Used for test environments where the application-wide + _TransactionContextManager may be used as a factory for test-local + managers. + + """ + new = self._clone() + new._root = new + new._root_factory = self._root_factory._create_factory_copy() + assert not new._factory._started + return new + + def patch_factory(self, factory_or_manager): + """Patch a _TransactionFactory into this manager. + + Replaces this manager's factory with the given one, and returns + a callable that will reset the factory back to what we + started with. + + Only works for root factories. Is intended for test suites + that need to patch in alternate database configurations. + + The given argument may be a _TransactionContextManager or a + _TransactionFactory. + + """ + + if isinstance(factory_or_manager, _TransactionContextManager): + factory = factory_or_manager._factory + elif isinstance(factory_or_manager, _TransactionFactory): + factory = factory_or_manager + else: + raise ValueError( + "_TransactionContextManager or " + "_TransactionFactory expected.") + assert self._root is self + existing_factory = self._root_factory + self._root_factory = factory + + def reset(): + self._root_factory = existing_factory + + return reset + + def patch_engine(self, engine): + """Patch an Engine into this manager. + + Replaces this manager's factory with a _TestTransactionFactory + that will use the given Engine, and returns + a callable that will reset the factory back to what we + started with. + + Only works for root factories. Is intended for test suites + that need to patch in alternate database configurations. + + """ + + existing_factory = self._factory + maker = existing_factory._writer_maker + maker_kwargs = existing_factory._maker_args_for_conf(cfg.CONF) + maker = orm.get_maker(engine=engine, **maker_kwargs) + + factory = _TestTransactionFactory( + engine, maker, + apply_global=False, + synchronous_reader=existing_factory. + _facade_cfg['synchronous_reader'] + ) + return self.patch_factory(factory) + @property def replace(self): """Modifier to replace the global transaction factory with this one.""" diff --git a/oslo_db/tests/sqlalchemy/test_enginefacade.py b/oslo_db/tests/sqlalchemy/test_enginefacade.py index 72663d2..fb63856 100644 --- a/oslo_db/tests/sqlalchemy/test_enginefacade.py +++ b/oslo_db/tests/sqlalchemy/test_enginefacade.py @@ -1076,6 +1076,154 @@ class MockFacadeTest(oslo_test_base.BaseTestCase): session.execute("test") +class PatchFactoryTest(oslo_test_base.BaseTestCase): + + def test_patch_manager(self): + normal_mgr = enginefacade.transaction_context() + normal_mgr.configure(connection="sqlite:///foo.db") + alt_mgr = enginefacade.transaction_context() + alt_mgr.configure(connection="sqlite:///bar.db") + + @normal_mgr.writer + def go1(context): + s1 = context.session + self.assertEqual( + s1.bind.url, "sqlite:///foo.db") + self.assertIs( + s1.bind, + normal_mgr._factory._writer_engine) + + @normal_mgr.writer + def go2(context): + s1 = context.session + + self.assertEqual( + s1.bind.url, + "sqlite:///bar.db") + + self.assertIs( + normal_mgr._factory._writer_engine, + alt_mgr._factory._writer_engine + ) + + def create_engine(sql_connection, **kw): + return mock.Mock(url=sql_connection) + + with mock.patch( + "oslo_db.sqlalchemy.engines.create_engine", create_engine): + context = oslo_context.RequestContext() + go1(context) + reset = normal_mgr.patch_factory(alt_mgr) + go2(context) + reset() + go1(context) + + def test_patch_factory(self): + normal_mgr = enginefacade.transaction_context() + normal_mgr.configure(connection="sqlite:///foo.db") + alt_mgr = enginefacade.transaction_context() + alt_mgr.configure(connection="sqlite:///bar.db") + + @normal_mgr.writer + def go1(context): + s1 = context.session + self.assertEqual( + s1.bind.url, "sqlite:///foo.db") + self.assertIs( + s1.bind, + normal_mgr._factory._writer_engine) + + @normal_mgr.writer + def go2(context): + s1 = context.session + + self.assertEqual( + s1.bind.url, + "sqlite:///bar.db") + + self.assertIs( + normal_mgr._factory._writer_engine, + alt_mgr._factory._writer_engine + ) + + def create_engine(sql_connection, **kw): + return mock.Mock(url=sql_connection) + + with mock.patch( + "oslo_db.sqlalchemy.engines.create_engine", create_engine): + context = oslo_context.RequestContext() + go1(context) + reset = normal_mgr.patch_factory(alt_mgr._factory) + go2(context) + reset() + go1(context) + + def test_patch_engine(self): + normal_mgr = enginefacade.transaction_context() + normal_mgr.configure(connection="sqlite:///foo.db") + + @normal_mgr.writer + def go1(context): + s1 = context.session + self.assertEqual( + s1.bind.url, "sqlite:///foo.db") + self.assertIs( + s1.bind, + normal_mgr._factory._writer_engine) + + @normal_mgr.writer + def go2(context): + s1 = context.session + + self.assertEqual( + s1.bind.url, + "sqlite:///bar.db") + + def create_engine(sql_connection, **kw): + return mock.Mock(url=sql_connection) + + with mock.patch( + "oslo_db.sqlalchemy.engines.create_engine", create_engine): + mock_engine = create_engine("sqlite:///bar.db") + + context = oslo_context.RequestContext() + go1(context) + reset = normal_mgr.patch_engine(mock_engine) + go2(context) + self.assertIs( + normal_mgr._factory._writer_engine, mock_engine) + reset() + go1(context) + + def test_new_manager_from_config(self): + normal_mgr = enginefacade.transaction_context() + normal_mgr.configure( + connection="sqlite://", + sqlite_fk=True, + mysql_sql_mode="FOOBAR", + max_overflow=38 + ) + + normal_mgr._factory._start() + + copied_mgr = normal_mgr.make_new_manager() + + self.assertTrue(normal_mgr._factory._started) + self.assertIsNotNone(normal_mgr._factory._writer_engine) + + self.assertIsNot(copied_mgr._factory, normal_mgr._factory) + self.assertFalse(copied_mgr._factory._started) + copied_mgr._factory._start() + self.assertIsNot( + normal_mgr._factory._writer_engine, + copied_mgr._factory._writer_engine) + + engine_args = copied_mgr._factory._engine_args_for_conf(None) + self.assertTrue(engine_args['sqlite_fk']) + self.assertEqual("FOOBAR", engine_args["mysql_sql_mode"]) + self.assertEqual(38, engine_args["max_overflow"]) + + class SynchronousReaderWSlaveMockFacadeTest(MockFacadeTest): synchronous_reader = True