diff --git a/tools/schema_generator.py b/tools/schema_generator.py index 12e03bf86..8202ab241 100755 --- a/tools/schema_generator.py +++ b/tools/schema_generator.py @@ -17,6 +17,7 @@ import contextlib import re +import sqlalchemy as sa import tabulate from taskflow.persistence.backends import impl_sqlalchemy @@ -53,29 +54,30 @@ def main(): with contextlib.closing(backend.get_connection()) as conn: conn.upgrade() # Now make a prettier version of that schema... - tables = backend.engine.execute(TABLE_QUERY) - table_names = [r[0] for r in tables] - for i, table_name in enumerate(table_names): - pretty_name = NAME_MAPPING.get(table_name, table_name) - print("*" + pretty_name + "*") - # http://www.sqlite.org/faq.html#q24 - table_name = table_name.replace("\"", "\"\"") - rows = [] - for r in backend.engine.execute(SCHEMA_QUERY % table_name): - # Cut out the numbers from things like VARCHAR(12) since - # this is not very useful to show users who just want to - # see the basic schema... - row_type = re.sub(r"\(.*?\)", "", r['type']).strip() - if not row_type: - raise ValueError("Row %s of table '%s' was empty after" - " cleaning" % (r['cid'], table_name)) - rows.append([r['name'], row_type, to_bool_string(r['pk'])]) - contents = tabulate.tabulate( - rows, headers=['Name', 'Type', 'Primary Key'], - tablefmt="rst") - print("\n%s" % contents.strip()) - if i + 1 != len(table_names): - print("") + with backend.engine.connect() as conn, conn.begin(): + tables = conn.execute(sa.text(TABLE_QUERY)) + table_names = [r[0] for r in tables] + for i, table_name in enumerate(table_names): + pretty_name = NAME_MAPPING.get(table_name, table_name) + print("*" + pretty_name + "*") + # http://www.sqlite.org/faq.html#q24 + table_name = table_name.replace("\"", "\"\"") + rows = [] + for r in conn.execute(sa.text(SCHEMA_QUERY % table_name)): + # Cut out the numbers from things like VARCHAR(12) since + # this is not very useful to show users who just want to + # see the basic schema... + row_type = re.sub(r"\(.*?\)", "", r['type']).strip() + if not row_type: + raise ValueError("Row %s of table '%s' was empty after" + " cleaning" % (r['cid'], table_name)) + rows.append([r['name'], row_type, to_bool_string(r['pk'])]) + contents = tabulate.tabulate( + rows, headers=['Name', 'Type', 'Primary Key'], + tablefmt="rst") + print("\n%s" % contents.strip()) + if i + 1 != len(table_names): + print("") if __name__ == '__main__':