Use create_all for empty databases

Rather than running through all of the migrations when starting Zuul
with an empty database, this uses sqlalchemy's create_all method to
create it from the declarative schema.

To make sure that stays in sync with alembic, a test is added to run
DB creation both ways and compare.

The declaritive schema had one column with an incorrect type, and
several columns out of order; this adjusts the schema to match the
migrations.

Contrary to expectations, using sqlalchemy to create the schema actually
adds about 0.05 seconds on averate to test runtime.

Change-Id: I594b6980f5efa5fa4b8ca387c5d0ab4373b86394
This commit is contained in:
James E. Blair
2021-06-19 10:07:35 -07:00
parent 0a5e330891
commit abe2a482be
2 changed files with 103 additions and 10 deletions

View File

@@ -0,0 +1,73 @@
# Copyright 2021 Acme Gating, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import re
from zuul.driver.sql import SQLDriver
from tests.base import BaseTestCase, MySQLSchemaFixture
class TestDatabase(BaseTestCase):
def setUp(self):
super().setUp()
f = MySQLSchemaFixture()
self.useFixture(f)
config = dict(dburi=f.dburi)
driver = SQLDriver()
self.connection = driver.getConnection('database', config)
self.connection.onLoad()
def compareMysql(self, alembic_text, sqlalchemy_text):
alembic_lines = alembic_text.split('\n')
sqlalchemy_lines = sqlalchemy_text.split('\n')
self.assertEqual(len(alembic_lines), len(sqlalchemy_lines))
alembic_constraints = []
sqlalchemy_constraints = []
for i in range(len(alembic_lines)):
if alembic_lines[i].startswith(" `"):
# Column
self.assertEqual(alembic_lines[i], sqlalchemy_lines[i])
elif alembic_lines[i].startswith(" "):
# Constraints can be unordered
# strip trailing commas since the last line omits it
alembic_constraints.append(
re.sub(',$', '', alembic_lines[i]))
sqlalchemy_constraints.append(
re.sub(',$', '', sqlalchemy_lines[i]))
else:
self.assertEqual(alembic_lines[i], sqlalchemy_lines[i])
alembic_constraints.sort()
sqlalchemy_constraints.sort()
self.assertEqual(alembic_constraints, sqlalchemy_constraints)
def test_migration(self):
# Test that SQLAlchemy create_all produces the same output as
# a full migration run.
sqlalchemy_tables = {}
self.connection.engine.execute("set foreign_key_checks=0")
for table in self.connection.engine.execute("show tables"):
table = table[0]
sqlalchemy_tables[table] = self.connection.engine.execute(
f"show create table {table}").one()[1]
self.connection.engine.execute(f"drop table {table}")
self.connection.engine.execute("set foreign_key_checks=1")
self.connection.force_migrations = True
self.connection.onLoad()
for table in self.connection.engine.execute("show tables"):
table = table[0]
create = self.connection.engine.execute(
f"show create table {table}").one()[1]
self.compareMysql(create, sqlalchemy_tables[table])

View File

@@ -180,6 +180,8 @@ class DatabaseSession(object):
class SQLConnection(BaseConnection):
driver_name = 'sql'
log = logging.getLogger("zuul.SQLConnection")
# This is used by tests only
force_migrations = False
def __init__(self, driver, connection_name, connection_config):
@@ -195,6 +197,7 @@ class SQLConnection(BaseConnection):
try:
self.dburi = self.connection_config.get('dburi')
self.metadata = sa.MetaData()
self._setup_models()
# Recycle connections if they've been idle for more than 1 second.
@@ -239,7 +242,12 @@ class SQLConnection(BaseConnection):
# Alembic lets us add arbitrary data in the tag argument. We can
# leverage that to tell the upgrade scripts about the table prefix.
tag = {'table_prefix': self.table_prefix}
alembic.command.upgrade(config, 'head', tag=tag)
if current_rev is None and not self.force_migrations:
self.metadata.create_all(self.engine)
alembic.command.stamp(config, "head", tag=tag)
else:
alembic.command.upgrade(config, 'head', tag=tag)
def onLoad(self):
while True:
@@ -253,27 +261,34 @@ class SQLConnection(BaseConnection):
time.sleep(10)
def _setup_models(self):
Base = declarative_base(metadata=sa.MetaData())
Base = declarative_base(metadata=self.metadata)
class BuildSetModel(Base):
__tablename__ = self.table_prefix + BUILDSET_TABLE
id = sa.Column(sa.Integer, primary_key=True)
uuid = sa.Column(sa.String(36))
zuul_ref = sa.Column(sa.String(255))
pipeline = sa.Column(sa.String(255))
project = sa.Column(sa.String(255))
branch = sa.Column(sa.String(255))
change = sa.Column(sa.Integer, nullable=True)
patchset = sa.Column(sa.String(255), nullable=True)
ref = sa.Column(sa.String(255))
oldrev = sa.Column(sa.String(255))
newrev = sa.Column(sa.String(255))
ref_url = sa.Column(sa.String(255))
result = sa.Column(sa.String(255))
message = sa.Column(sa.TEXT())
tenant = sa.Column(sa.String(255))
result = sa.Column(sa.String(255))
ref_url = sa.Column(sa.String(255))
oldrev = sa.Column(sa.String(255))
newrev = sa.Column(sa.String(255))
branch = sa.Column(sa.String(255))
uuid = sa.Column(sa.String(36))
event_id = sa.Column(sa.String(255), nullable=True)
sa.Index(self.table_prefix + 'project_pipeline_idx',
project, pipeline)
sa.Index(self.table_prefix + 'project_change_idx',
project, change)
sa.Index(self.table_prefix + 'change_idx', change)
sa.Index(self.table_prefix + 'uuid_idx', uuid)
def createBuild(self, *args, **kw):
session = orm.session.Session.object_session(self)
b = BuildModel(*args, **kw)
@@ -286,12 +301,11 @@ class SQLConnection(BaseConnection):
class BuildModel(Base):
__tablename__ = self.table_prefix + BUILD_TABLE
id = sa.Column(sa.Integer, primary_key=True)
buildset_id = sa.Column(sa.String, sa.ForeignKey(
buildset_id = sa.Column(sa.Integer, sa.ForeignKey(
self.table_prefix + BUILDSET_TABLE + ".id"))
uuid = sa.Column(sa.String(36))
job_name = sa.Column(sa.String(255))
result = sa.Column(sa.String(255))
held = sa.Column(sa.Boolean)
start_time = sa.Column(sa.DateTime)
end_time = sa.Column(sa.DateTime)
voting = sa.Column(sa.Boolean)
@@ -299,8 +313,14 @@ class SQLConnection(BaseConnection):
node_name = sa.Column(sa.String(255))
error_detail = sa.Column(sa.TEXT())
final = sa.Column(sa.Boolean)
held = sa.Column(sa.Boolean)
buildset = orm.relationship(BuildSetModel, backref="builds")
sa.Index(self.table_prefix + 'job_name_buildset_id_idx',
job_name, buildset_id)
sa.Index(self.table_prefix + 'uuid_buildset_id_idx',
uuid, buildset_id)
def createArtifact(self, *args, **kw):
session = orm.session.Session.object_session(self)
# SQLAlchemy reserves the 'metadata' attribute on