diff --git a/taskflow/persistence/backends/sqlalchemy/alembic/env.py b/taskflow/persistence/backends/sqlalchemy/alembic/env.py index a864ac07..4e0a3ebf 100644 --- a/taskflow/persistence/backends/sqlalchemy/alembic/env.py +++ b/taskflow/persistence/backends/sqlalchemy/alembic/env.py @@ -60,18 +60,16 @@ def run_migrations_online(): and associate a connection with the context. """ - engine = engine_from_config(config.get_section(config.config_ini_section), - prefix='sqlalchemy.', - poolclass=pool.NullPool) - - connection = engine.connect() - context.configure(connection=connection, target_metadata=target_metadata) - - try: + connectable = config.attributes.get('connection', None) + if connectable is None: + connectable = engine_from_config( + config.get_section(config.config_ini_section), + prefix='sqlalchemy.', poolclass=pool.NullPool) + with connectable.connect() as connection: + context.configure(connection=connection, + target_metadata=target_metadata) with context.begin_transaction(): context.run_migrations() - finally: - connection.close() if context.is_offline_mode(): diff --git a/taskflow/persistence/backends/sqlalchemy/migration.py b/taskflow/persistence/backends/sqlalchemy/migration.py index 6d91e346..ca073246 100644 --- a/taskflow/persistence/backends/sqlalchemy/migration.py +++ b/taskflow/persistence/backends/sqlalchemy/migration.py @@ -18,25 +18,16 @@ import os -from alembic import config as a_config -from alembic import environment as a_env -from alembic import script as a_script +from alembic import command +from alembic import config -def _alembic_config(): +def _make_alembic_config(): path = os.path.join(os.path.dirname(__file__), 'alembic', 'alembic.ini') - return a_config.Config(path) + return config.Config(path) def db_sync(connection, revision='head'): - script = a_script.ScriptDirectory.from_config(_alembic_config()) - - def upgrade(rev, context): - return script._upgrade_revs(revision, rev) - - config = _alembic_config() - with a_env.EnvironmentContext(config, script, fn=upgrade, as_sql=False, - starting_rev=None, destination_rev=revision, - tag=None) as context: - context.configure(connection=connection) - context.run_migrations() + cfg = _make_alembic_config() + cfg.attributes['connection'] = connection + command.upgrade(cfg, revision)