diff --git a/migrate/tests/versioning/test_script.py b/migrate/tests/versioning/test_script.py index c26b03b..954bc0d 100644 --- a/migrate/tests/versioning/test_script.py +++ b/migrate/tests/versioning/test_script.py @@ -271,3 +271,25 @@ class TestSqlScript(fixture.Pathed, fixture.DB): sqls = SqlScript(src) sqls.run(self.engine) tmp_sql_table.metadata.drop_all(self.engine, checkfirst=True) + + @fixture.usedb() + def test_transaction_management_statements(self): + """ + Test that we can successfully execute SQL scripts with transaction + management statements. + """ + for script_pattern in ( + "BEGIN TRANSACTION; %s; COMMIT;", + "BEGIN; %s; END TRANSACTION;", + ): + + test_statement = ("CREATE TABLE TEST1 (field1 int); " + "DROP TABLE TEST1") + script = script_pattern % test_statement + src = self.tmp() + + with open(src, 'wt') as f: + f.write(script) + + sqls = SqlScript(src) + sqls.run(self.engine) diff --git a/migrate/versioning/script/sql.py b/migrate/versioning/script/sql.py index 70b49ec..4b0536d 100644 --- a/migrate/versioning/script/sql.py +++ b/migrate/versioning/script/sql.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- import logging +import re import shutil import sqlparse @@ -36,13 +37,29 @@ class SqlScript(base.BaseScript): try: trans = conn.begin() try: + # ignore transaction management statements that are + # redundant in SQL script context and result in + # operational error being returned. + # + # Note: we don't ignore ROLLBACK in migration scripts + # since its usage would be insane anyway, and we're + # better to fail on its occurance instead of ignoring it + # (and committing transaction, which is contradictory to + # the whole idea of ROLLBACK) + ignored_statements = ('BEGIN', 'END', 'COMMIT') + ignored_regex = re.compile('^\s*(%s).*;?$' % '|'.join(ignored_statements), + re.IGNORECASE) + # NOTE(ihrachys): script may contain multiple statements, and # not all drivers reliably handle multistatement queries or # commands passed to .execute(), so split them and execute one # by one for statement in sqlparse.split(text): if statement: - conn.execute(statement) + if re.match(ignored_regex, statement): + log.warning('"%s" found in SQL script; ignoring' % statement) + else: + conn.execute(statement) trans.commit() except Exception as e: log.error("SQL script %s failed: %s", self.path, e)