Sync db.sqlalchemy code from Oslo

Patch contains common code from Oslo, we need to sync this modules to fix
a problem with a newer version of sqlite and the db.sqlachemy module that
is causing some of tests to fail if the version of the sqlite is >=3.7.16.

The patch also includes the code needed with the transition from migrate
tool to alembic, these base migration classes depends on the py3kcompat
and test modules which were also added to the project.

Closes-Bug: #1272500
Change-Id: Ifcd1254f16506a4247523056d6ad4ad8d7aefdac
This commit is contained in:
Lucas Alvares Gomes 2014-01-24 19:43:29 +00:00 committed by Devananda van der Veen
parent 2823630cd5
commit 4f8b5fb961
16 changed files with 1355 additions and 87 deletions

View File

@ -98,10 +98,10 @@
# Options defined in ironic.openstack.common.db.sqlalchemy.session
#
# the filename to use with sqlite (string value)
# The file name to use with SQLite (string value)
#sqlite_db=ironic.sqlite
# If true, use synchronous mode for sqlite (boolean value)
# If True, SQLite uses synchronous mode (boolean value)
#sqlite_synchronous=true
@ -485,10 +485,11 @@
# slave database (string value)
#slave_connection=
# timeout before idle sql connections are reaped (integer
# Timeout before idle sql connections are reaped (integer
# value)
# Deprecated group/name - [DEFAULT]/sql_idle_timeout
# Deprecated group/name - [DATABASE]/sql_idle_timeout
# Deprecated group/name - [sql]/idle_timeout
#idle_timeout=3600
# Minimum number of SQL connections to keep open in a pool
@ -503,13 +504,13 @@
# Deprecated group/name - [DATABASE]/sql_max_pool_size
#max_pool_size=<None>
# maximum db connection retries during startup. (setting -1
# Maximum db connection retries during startup. (setting -1
# implies an infinite retry count) (integer value)
# Deprecated group/name - [DEFAULT]/sql_max_retries
# Deprecated group/name - [DATABASE]/sql_max_retries
#max_retries=10
# interval between retries of opening a sql connection
# Interval between retries of opening a sql connection
# (integer value)
# Deprecated group/name - [DEFAULT]/sql_retry_interval
# Deprecated group/name - [DATABASE]/reconnect_interval

View File

@ -0,0 +1,265 @@
# coding: utf-8
#
# Copyright (c) 2013 OpenStack Foundation
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# Base on code in migrate/changeset/databases/sqlite.py which is under
# the following license:
#
# The MIT License
#
# Copyright (c) 2009 Evan Rosson, Jan Dittberner, Domen Kožar
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import os
import re
from migrate.changeset import ansisql
from migrate.changeset.databases import sqlite
from migrate import exceptions as versioning_exceptions
from migrate.versioning import api as versioning_api
from migrate.versioning.repository import Repository
import sqlalchemy
from sqlalchemy.schema import UniqueConstraint
from ironic.openstack.common.db import exception
from ironic.openstack.common.db.sqlalchemy import session as db_session
from ironic.openstack.common.gettextutils import _
get_engine = db_session.get_engine
def _get_unique_constraints(self, table):
"""Retrieve information about existing unique constraints of the table
This feature is needed for _recreate_table() to work properly.
Unfortunately, it's not available in sqlalchemy 0.7.x/0.8.x.
"""
data = table.metadata.bind.execute(
"""SELECT sql
FROM sqlite_master
WHERE
type='table' AND
name=:table_name""",
table_name=table.name
).fetchone()[0]
UNIQUE_PATTERN = "CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)"
return [
UniqueConstraint(
*[getattr(table.columns, c.strip(' "')) for c in cols.split(",")],
name=name
)
for name, cols in re.findall(UNIQUE_PATTERN, data)
]
def _recreate_table(self, table, column=None, delta=None, omit_uniques=None):
"""Recreate the table properly
Unlike the corresponding original method of sqlalchemy-migrate this one
doesn't drop existing unique constraints when creating a new one.
"""
table_name = self.preparer.format_table(table)
# we remove all indexes so as not to have
# problems during copy and re-create
for index in table.indexes:
index.drop()
# reflect existing unique constraints
for uc in self._get_unique_constraints(table):
table.append_constraint(uc)
# omit given unique constraints when creating a new table if required
table.constraints = set([
cons for cons in table.constraints
if omit_uniques is None or cons.name not in omit_uniques
])
self.append('ALTER TABLE %s RENAME TO migration_tmp' % table_name)
self.execute()
insertion_string = self._modify_table(table, column, delta)
table.create(bind=self.connection)
self.append(insertion_string % {'table_name': table_name})
self.execute()
self.append('DROP TABLE migration_tmp')
self.execute()
def _visit_migrate_unique_constraint(self, *p, **k):
"""Drop the given unique constraint
The corresponding original method of sqlalchemy-migrate just
raises NotImplemented error
"""
self.recreate_table(p[0].table, omit_uniques=[p[0].name])
def patch_migrate():
"""A workaround for SQLite's inability to alter things
SQLite abilities to alter tables are very limited (please read
http://www.sqlite.org/lang_altertable.html for more details).
E. g. one can't drop a column or a constraint in SQLite. The
workaround for this is to recreate the original table omitting
the corresponding constraint (or column).
sqlalchemy-migrate library has recreate_table() method that
implements this workaround, but it does it wrong:
- information about unique constraints of a table
is not retrieved. So if you have a table with one
unique constraint and a migration adding another one
you will end up with a table that has only the
latter unique constraint, and the former will be lost
- dropping of unique constraints is not supported at all
The proper way to fix this is to provide a pull-request to
sqlalchemy-migrate, but the project seems to be dead. So we
can go on with monkey-patching of the lib at least for now.
"""
# this patch is needed to ensure that recreate_table() doesn't drop
# existing unique constraints of the table when creating a new one
helper_cls = sqlite.SQLiteHelper
helper_cls.recreate_table = _recreate_table
helper_cls._get_unique_constraints = _get_unique_constraints
# this patch is needed to be able to drop existing unique constraints
constraint_cls = sqlite.SQLiteConstraintDropper
constraint_cls.visit_migrate_unique_constraint = \
_visit_migrate_unique_constraint
constraint_cls.__bases__ = (ansisql.ANSIColumnDropper,
sqlite.SQLiteConstraintGenerator)
def db_sync(abs_path, version=None, init_version=0):
"""Upgrade or downgrade a database.
Function runs the upgrade() or downgrade() functions in change scripts.
:param abs_path: Absolute path to migrate repository.
:param version: Database will upgrade/downgrade until this version.
If None - database will update to the latest
available version.
:param init_version: Initial database version
"""
if version is not None:
try:
version = int(version)
except ValueError:
raise exception.DbMigrationError(
message=_("version should be an integer"))
current_version = db_version(abs_path, init_version)
repository = _find_migrate_repo(abs_path)
_db_schema_sanity_check()
if version is None or version > current_version:
return versioning_api.upgrade(get_engine(), repository, version)
else:
return versioning_api.downgrade(get_engine(), repository,
version)
def _db_schema_sanity_check():
engine = get_engine()
if engine.name == 'mysql':
onlyutf8_sql = ('SELECT TABLE_NAME,TABLE_COLLATION '
'from information_schema.TABLES '
'where TABLE_SCHEMA=%s and '
'TABLE_COLLATION NOT LIKE "%%utf8%%"')
table_names = [res[0] for res in engine.execute(onlyutf8_sql,
engine.url.database)]
if len(table_names) > 0:
raise ValueError(_('Tables "%s" have non utf8 collation, '
'please make sure all tables are CHARSET=utf8'
) % ','.join(table_names))
def db_version(abs_path, init_version):
"""Show the current version of the repository.
:param abs_path: Absolute path to migrate repository
:param version: Initial database version
"""
repository = _find_migrate_repo(abs_path)
try:
return versioning_api.db_version(get_engine(), repository)
except versioning_exceptions.DatabaseNotControlledError:
meta = sqlalchemy.MetaData()
engine = get_engine()
meta.reflect(bind=engine)
tables = meta.tables
if len(tables) == 0 or 'alembic_version' in tables:
db_version_control(abs_path, init_version)
return versioning_api.db_version(get_engine(), repository)
else:
raise exception.DbMigrationError(
message=_(
"The database is not under version control, but has "
"tables. Please stamp the current version of the schema "
"manually."))
def db_version_control(abs_path, version=None):
"""Mark a database as under this repository's version control.
Once a database is under version control, schema changes should
only be done via change scripts in this repository.
:param abs_path: Absolute path to migrate repository
:param version: Initial database version
"""
repository = _find_migrate_repo(abs_path)
versioning_api.version_control(get_engine(), repository, version)
return version
def _find_migrate_repo(abs_path):
"""Get the project's change script repository
:param abs_path: Absolute path to migrate repository
"""
if not os.path.exists(abs_path):
raise exception.DbMigrationError("Path %s not found" % abs_path)
return Repository(abs_path)

View File

@ -0,0 +1,77 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
import alembic
from alembic import config as alembic_config
import alembic.migration as alembic_migration
from ironic.openstack.common.db.sqlalchemy.migration_cli import ext_base
from ironic.openstack.common.db.sqlalchemy import session as db_session
class AlembicExtension(ext_base.MigrationExtensionBase):
order = 2
@property
def enabled(self):
return os.path.exists(self.alembic_ini_path)
def __init__(self, migration_config):
"""Extension to provide alembic features.
:param migration_config: Stores specific configuration for migrations
:type migration_config: dict
"""
self.alembic_ini_path = migration_config.get('alembic_ini_path', '')
self.config = alembic_config.Config(self.alembic_ini_path)
# option should be used if script is not in default directory
repo_path = migration_config.get('alembic_repo_path')
if repo_path:
self.config.set_main_option('script_location', repo_path)
def upgrade(self, version):
return alembic.command.upgrade(self.config, version or 'head')
def downgrade(self, version):
if isinstance(version, int) or version is None or version.isdigit():
version = 'base'
return alembic.command.downgrade(self.config, version)
def version(self):
engine = db_session.get_engine()
with engine.connect() as conn:
context = alembic_migration.MigrationContext.configure(conn)
return context.get_current_revision()
def revision(self, message='', autogenerate=False):
"""Creates template for migration.
:param message: Text that will be used for migration title
:type message: string
:param autogenerate: If True - generates diff based on current database
state
:type autogenerate: bool
"""
return alembic.command.revision(self.config, message=message,
autogenerate=autogenerate)
def stamp(self, revision):
"""Stamps database with provided revision.
:param revision: Should match one from repository or head - to stamp
database with most recent revision
:type revision: string
"""
return alembic.command.stamp(self.config, revision=revision)

View File

@ -0,0 +1,79 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import abc
import six
@six.add_metaclass(abc.ABCMeta)
class MigrationExtensionBase(object):
#used to sort migration in logical order
order = 0
@property
def enabled(self):
"""Used for availability verification of a plugin.
:rtype: bool
"""
return False
@abc.abstractmethod
def upgrade(self, version):
"""Used for upgrading database.
:param version: Desired database version
:type version: string
"""
@abc.abstractmethod
def downgrade(self, version):
"""Used for downgrading database.
:param version: Desired database version
:type version: string
"""
@abc.abstractmethod
def version(self):
"""Current database version.
:returns: Databse version
:rtype: string
"""
def revision(self, *args, **kwargs):
"""Used to generate migration script.
In migration engines that support this feature, it should generate
new migration script.
Accept arbitrary set of arguments.
"""
raise NotImplementedError()
def stamp(self, *args, **kwargs):
"""Stamps database based on plugin features.
Accept arbitrary set of arguments.
"""
raise NotImplementedError()
def __cmp__(self, other):
"""Used for definition of plugin order.
:param other: MigrationExtensionBase instance
:rtype: bool
"""
return self.order > other.order

View File

@ -0,0 +1,66 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
from ironic.openstack.common.db.sqlalchemy import migration
from ironic.openstack.common.db.sqlalchemy.migration_cli import ext_base
from ironic.openstack.common.gettextutils import _ # noqa
from ironic.openstack.common import log as logging
LOG = logging.getLogger(__name__)
class MigrateExtension(ext_base.MigrationExtensionBase):
"""Extension to provide sqlalchemy-migrate features.
:param migration_config: Stores specific configuration for migrations
:type migration_config: dict
"""
order = 1
def __init__(self, migration_config):
self.repository = migration_config.get('migration_repo_path', '')
self.init_version = migration_config.get('init_version', 0)
@property
def enabled(self):
return os.path.exists(self.repository)
def upgrade(self, version):
version = None if version == 'head' else version
return migration.db_sync(
self.repository, version,
init_version=self.init_version)
def downgrade(self, version):
try:
#version for migrate should be valid int - else skip
if version in ('base', None):
version = self.init_version
version = int(version)
return migration.db_sync(
self.repository, version,
init_version=self.init_version)
except ValueError:
LOG.error(
_('Migration number for migrate plugin must be valid '
'integer or empty, if you want to downgrade '
'to initial state')
)
raise
def version(self):
return migration.db_version(
self.repository, init_version=self.init_version)

View File

@ -0,0 +1,71 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from stevedore import enabled
MIGRATION_NAMESPACE = 'ironic.openstack.common.migration'
def check_plugin_enabled(ext):
"""Used for EnabledExtensionManager"""
return ext.obj.enabled
class MigrationManager(object):
def __init__(self, migration_config):
self._manager = enabled.EnabledExtensionManager(
MIGRATION_NAMESPACE,
check_plugin_enabled,
invoke_kwds={'migration_config': migration_config},
invoke_on_load=True
)
if not self._plugins:
raise ValueError('There must be at least one plugin active.')
@property
def _plugins(self):
return sorted(ext.obj for ext in self._manager.extensions)
def upgrade(self, revision):
"""Upgrade database with all available backends."""
results = []
for plugin in self._plugins:
results.append(plugin.upgrade(revision))
return results
def downgrade(self, revision):
"""Downgrade database with available backends."""
#downgrading should be performed in reversed order
results = []
for plugin in reversed(self._plugins):
results.append(plugin.downgrade(revision))
return results
def version(self):
"""Return last version of db."""
last = None
for plugin in self._plugins:
version = plugin.version()
if version:
last = version
return last
def revision(self, message, autogenerate):
"""Generate template or autogenerated revision."""
#revision should be done only by last plugin
return self._plugins[-1].revision(message, autogenerate)
def stamp(self, revision):
"""Create stamp for a given revision."""
return self._plugins[-1].stamp(revision)

View File

@ -1,5 +1,3 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (c) 2011 X.commerce, a business unit of eBay Inc.
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
@ -41,13 +39,13 @@ class ModelBase(object):
if not session:
session = sa.get_session()
# NOTE(boris-42): This part of code should be look like:
# sesssion.add(self)
# session.add(self)
# session.flush()
# But there is a bug in sqlalchemy and eventlet that
# raises NoneType exception if there is no running
# transaction and rollback is called. As long as
# sqlalchemy has this bug we have to create transaction
# explicity.
# explicitly.
with session.begin(subtransactions=True):
session.add(self)
session.flush()
@ -61,7 +59,16 @@ class ModelBase(object):
def get(self, key, default=None):
return getattr(self, key, default)
def _get_extra_keys(self):
@property
def _extra_keys(self):
"""Specifies custom fields
Subclasses can override this property to return a list
of custom fields that should be included in their dict
representation.
For reference check tests/db/sqlalchemy/test_models.py
"""
return []
def __iter__(self):
@ -69,7 +76,7 @@ class ModelBase(object):
# NOTE(russellb): Allow models to specify other keys that can be looked
# up, beyond the actual db columns. An example would be the 'name'
# property for an Instance.
columns.extend(self._get_extra_keys())
columns.extend(self._extra_keys)
self._i = iter(columns)
return self
@ -91,12 +98,12 @@ class ModelBase(object):
joined = dict([(k, v) for k, v in six.iteritems(self.__dict__)
if not k[0] == '_'])
local.update(joined)
return local.iteritems()
return six.iteritems(local)
class TimestampMixin(object):
created_at = Column(DateTime, default=timeutils.utcnow)
updated_at = Column(DateTime, onupdate=timeutils.utcnow)
created_at = Column(DateTime, default=lambda: timeutils.utcnow())
updated_at = Column(DateTime, onupdate=lambda: timeutils.utcnow())
class SoftDeleteMixin(object):

View File

@ -0,0 +1,187 @@
# Copyright 2013 Mirantis.inc
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Provision test environment for specific DB backends"""
import argparse
import os
import random
import string
from six import moves
import sqlalchemy
from ironic.openstack.common.db import exception as exc
SQL_CONNECTION = os.getenv('OS_TEST_DBAPI_ADMIN_CONNECTION', 'sqlite://')
def _gen_credentials(*names):
"""Generate credentials."""
auth_dict = {}
for name in names:
val = ''.join(random.choice(string.ascii_lowercase)
for i in moves.range(10))
auth_dict[name] = val
return auth_dict
def _get_engine(uri=SQL_CONNECTION):
"""Engine creation
By default the uri is SQL_CONNECTION which is admin credentials.
Call the function without arguments to get admin connection. Admin
connection required to create temporary user and database for each
particular test. Otherwise use existing connection to recreate connection
to the temporary database.
"""
return sqlalchemy.create_engine(uri, poolclass=sqlalchemy.pool.NullPool)
def _execute_sql(engine, sql, driver):
"""Initialize connection, execute sql query and close it."""
try:
with engine.connect() as conn:
if driver == 'postgresql':
conn.connection.set_isolation_level(0)
for s in sql:
conn.execute(s)
except sqlalchemy.exc.OperationalError:
msg = ('%s does not match database admin '
'credentials or database does not exist.')
raise exc.DBConnectionError(msg % SQL_CONNECTION)
def create_database(engine):
"""Provide temporary user and database for each particular test."""
driver = engine.name
auth = _gen_credentials('database', 'user', 'passwd')
sqls = {
'mysql': [
"drop database if exists %(database)s;",
"grant all on %(database)s.* to '%(user)s'@'localhost'"
" identified by '%(passwd)s';",
"create database %(database)s;",
],
'postgresql': [
"drop database if exists %(database)s;",
"drop user if exists %(user)s;",
"create user %(user)s with password '%(passwd)s';",
"create database %(database)s owner %(user)s;",
]
}
if driver == 'sqlite':
return 'sqlite:////tmp/%s' % auth['database']
try:
sql_rows = sqls[driver]
except KeyError:
raise ValueError('Unsupported RDBMS %s' % driver)
sql_query = map(lambda x: x % auth, sql_rows)
_execute_sql(engine, sql_query, driver)
params = auth.copy()
params['backend'] = driver
return "%(backend)s://%(user)s:%(passwd)s@localhost/%(database)s" % params
def drop_database(engine, current_uri):
"""Drop temporary database and user after each particular test."""
engine = _get_engine(current_uri)
admin_engine = _get_engine()
driver = engine.name
auth = {'database': engine.url.database, 'user': engine.url.username}
if driver == 'sqlite':
try:
os.remove(auth['database'])
except OSError:
pass
return
sqls = {
'mysql': [
"drop database if exists %(database)s;",
"drop user '%(user)s'@'localhost';",
],
'postgresql': [
"drop database if exists %(database)s;",
"drop user if exists %(user)s;",
]
}
try:
sql_rows = sqls[driver]
except KeyError:
raise ValueError('Unsupported RDBMS %s' % driver)
sql_query = map(lambda x: x % auth, sql_rows)
_execute_sql(admin_engine, sql_query, driver)
def main():
"""Controller to handle commands
::create: Create test user and database with random names.
::drop: Drop user and database created by previous command.
"""
parser = argparse.ArgumentParser(
description='Controller to handle database creation and dropping'
' commands.',
epilog='Under normal circumstances is not used directly.'
' Used in .testr.conf to automate test database creation'
' and dropping processes.')
subparsers = parser.add_subparsers(
help='Subcommands to manipulate temporary test databases.')
create = subparsers.add_parser(
'create',
help='Create temporary test '
'databases and users.')
create.set_defaults(which='create')
create.add_argument(
'instances_count',
type=int,
help='Number of databases to create.')
drop = subparsers.add_parser(
'drop',
help='Drop temporary test databases and users.')
drop.set_defaults(which='drop')
drop.add_argument(
'instances',
nargs='+',
help='List of databases uri to be dropped.')
args = parser.parse_args()
engine = _get_engine()
which = args.which
if which == "create":
for i in range(int(args.instances_count)):
print(create_database(engine))
elif which == "drop":
for db in args.instances:
drop_database(engine, db)
if __name__ == "__main__":
main()

View File

@ -1,5 +1,3 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
@ -23,7 +21,7 @@ Initializing:
* Call set_defaults with the minimal of the following kwargs:
sql_connection, sqlite_db
Example:
Example::
session.set_defaults(
sql_connection="sqlite:///var/lib/ironic/sqlite.db",
@ -44,17 +42,17 @@ Recommended ways to use sessions within this framework:
functionality should be handled at a logical level. For an example, look at
the code around quotas and reservation_rollback().
Examples:
Examples::
def get_foo(context, foo):
return model_query(context, models.Foo).\
filter_by(foo=foo).\
first()
return (model_query(context, models.Foo).
filter_by(foo=foo).
first())
def update_foo(context, id, newfoo):
model_query(context, models.Foo).\
filter_by(id=id).\
update({'foo': newfoo})
(model_query(context, models.Foo).
filter_by(id=id).
update({'foo': newfoo}))
def create_foo(context, values):
foo_ref = models.Foo()
@ -68,14 +66,21 @@ Recommended ways to use sessions within this framework:
handler will take care of calling flush() and commit() for you.
If using this approach, you should not explicitly call flush() or commit().
Any error within the context of the session will cause the session to emit
a ROLLBACK. If the connection is dropped before this is possible, the
database will implicitly rollback the transaction.
a ROLLBACK. Database Errors like IntegrityError will be raised in
session's __exit__ handler, and any try/except within the context managed
by session will not be triggered. And catching other non-database errors in
the session will not trigger the ROLLBACK, so exception handlers should
always be outside the session, unless the developer wants to do a partial
commit on purpose. If the connection is dropped before this is possible,
the database will implicitly roll back the transaction.
Note: statements in the session scope will not be automatically retried.
If you create models within the session, they need to be added, but you
do not need to call model.save()
::
def create_many_foo(context, foos):
session = get_session()
with session.begin():
@ -87,33 +92,50 @@ Recommended ways to use sessions within this framework:
def update_bar(context, foo_id, newbar):
session = get_session()
with session.begin():
foo_ref = model_query(context, models.Foo, session).\
filter_by(id=foo_id).\
first()
model_query(context, models.Bar, session).\
filter_by(id=foo_ref['bar_id']).\
update({'bar': newbar})
foo_ref = (model_query(context, models.Foo, session).
filter_by(id=foo_id).
first())
(model_query(context, models.Bar, session).
filter_by(id=foo_ref['bar_id']).
update({'bar': newbar}))
Note: update_bar is a trivially simple example of using "with session.begin".
Whereas create_many_foo is a good example of when a transaction is needed,
it is always best to use as few queries as possible. The two queries in
update_bar can be better expressed using a single query which avoids
the need for an explicit transaction. It can be expressed like so:
the need for an explicit transaction. It can be expressed like so::
def update_bar(context, foo_id, newbar):
subq = model_query(context, models.Foo.id).\
filter_by(id=foo_id).\
limit(1).\
subquery()
model_query(context, models.Bar).\
filter_by(id=subq.as_scalar()).\
update({'bar': newbar})
subq = (model_query(context, models.Foo.id).
filter_by(id=foo_id).
limit(1).
subquery())
(model_query(context, models.Bar).
filter_by(id=subq.as_scalar()).
update({'bar': newbar}))
For reference, this emits approximagely the following SQL statement:
For reference, this emits approximately the following SQL statement::
UPDATE bar SET bar = ${newbar}
WHERE id=(SELECT bar_id FROM foo WHERE id = ${foo_id} LIMIT 1);
Note: create_duplicate_foo is a trivially simple example of catching an
exception while using "with session.begin". Here create two duplicate
instances with same primary key, must catch the exception out of context
managed by a single session:
def create_duplicate_foo(context):
foo1 = models.Foo()
foo2 = models.Foo()
foo1.id = foo2.id = 1
session = get_session()
try:
with session.begin():
session.add(foo1)
session.add(foo2)
except exception.DBDuplicateEntry as e:
handle_error(e)
* Passing an active session between methods. Sessions should only be passed
to private methods. The private method must use a subtransaction; otherwise
SQLAlchemy will throw an error when you call session.begin() on an existing
@ -129,6 +151,8 @@ Recommended ways to use sessions within this framework:
becomes less clear in this situation. When this is needed for code clarity,
it should be clearly documented.
::
def myfunc(foo):
session = get_session()
with session.begin():
@ -173,7 +197,7 @@ There are some things which it is best to avoid:
Enabling soft deletes:
* To use/enable soft-deletes, the SoftDeleteMixin must be added
to your model class. For example:
to your model class. For example::
class NovaBase(models.SoftDeleteMixin, models.ModelBase):
pass
@ -181,14 +205,15 @@ Enabling soft deletes:
Efficient use of soft deletes:
* There are two possible ways to mark a record as deleted:
* There are two possible ways to mark a record as deleted::
model.soft_delete() and query.soft_delete().
model.soft_delete() method works with single already fetched entry.
query.soft_delete() makes only one db request for all entries that correspond
to query.
* In almost all cases you should use query.soft_delete(). Some examples:
* In almost all cases you should use query.soft_delete(). Some examples::
def soft_delete_bar():
count = model_query(BarModel).find(some_condition).soft_delete()
@ -199,9 +224,9 @@ Efficient use of soft deletes:
if session is None:
session = get_session()
with session.begin(subtransactions=True):
count = model_query(BarModel).\
find(some_condition).\
soft_delete(synchronize_session=True)
count = (model_query(BarModel).
find(some_condition).
soft_delete(synchronize_session=True))
# Here synchronize_session is required, because we
# don't know what is going on in outer session.
if count == 0:
@ -211,6 +236,8 @@ Efficient use of soft deletes:
you fetch a single record, work with it, and mark it as deleted in the same
transaction.
::
def soft_delete_bar_model():
session = get_session()
with session.begin():
@ -219,13 +246,13 @@ Efficient use of soft deletes:
bar_ref.soft_delete(session=session)
However, if you need to work with all entries that correspond to query and
then soft delete them you should use query.soft_delete() method:
then soft delete them you should use query.soft_delete() method::
def soft_delete_multi_models():
session = get_session()
with session.begin():
query = model_query(BarModel, session=session).\
find(some_condition)
query = (model_query(BarModel, session=session).
find(some_condition))
model_refs = query.all()
# Work with model_refs
query.soft_delete(synchronize_session=False)
@ -236,6 +263,8 @@ Efficient use of soft deletes:
which issues a single query. Using model.soft_delete(), as in the following
example, is very inefficient.
::
for bar_ref in bar_refs:
bar_ref.soft_delete(session=session)
# This will produce count(bar_refs) db requests.
@ -249,24 +278,23 @@ import time
from oslo.config import cfg
import six
from sqlalchemy import exc as sqla_exc
import sqlalchemy.interfaces
from sqlalchemy.interfaces import PoolListener
import sqlalchemy.orm
from sqlalchemy.pool import NullPool, StaticPool
from sqlalchemy.sql.expression import literal_column
from ironic.openstack.common.db import exception
from ironic.openstack.common.gettextutils import _ # noqa
from ironic.openstack.common.gettextutils import _
from ironic.openstack.common import log as logging
from ironic.openstack.common import timeutils
sqlite_db_opts = [
cfg.StrOpt('sqlite_db',
default='ironic.sqlite',
help='the filename to use with sqlite'),
help='The file name to use with SQLite'),
cfg.BoolOpt('sqlite_synchronous',
default=True,
help='If true, use synchronous mode for sqlite'),
help='If True, SQLite uses synchronous mode'),
]
database_opts = [
@ -276,6 +304,7 @@ database_opts = [
'../', '$sqlite_db')),
help='The SQLAlchemy connection string used to connect to the '
'database',
secret=True,
deprecated_opts=[cfg.DeprecatedOpt('sql_connection',
group='DEFAULT'),
cfg.DeprecatedOpt('sql_connection',
@ -284,6 +313,7 @@ database_opts = [
group='sql'), ]),
cfg.StrOpt('slave_connection',
default='',
secret=True,
help='The SQLAlchemy connection string used to connect to the '
'slave database'),
cfg.IntOpt('idle_timeout',
@ -291,8 +321,10 @@ database_opts = [
deprecated_opts=[cfg.DeprecatedOpt('sql_idle_timeout',
group='DEFAULT'),
cfg.DeprecatedOpt('sql_idle_timeout',
group='DATABASE')],
help='timeout before idle sql connections are reaped'),
group='DATABASE'),
cfg.DeprecatedOpt('idle_timeout',
group='sql')],
help='Timeout before idle sql connections are reaped'),
cfg.IntOpt('min_pool_size',
default=1,
deprecated_opts=[cfg.DeprecatedOpt('sql_min_pool_size',
@ -315,7 +347,7 @@ database_opts = [
group='DEFAULT'),
cfg.DeprecatedOpt('sql_max_retries',
group='DATABASE')],
help='maximum db connection retries during startup. '
help='Maximum db connection retries during startup. '
'(setting -1 implies an infinite retry count)'),
cfg.IntOpt('retry_interval',
default=10,
@ -323,7 +355,7 @@ database_opts = [
group='DEFAULT'),
cfg.DeprecatedOpt('reconnect_interval',
group='DATABASE')],
help='interval between retries of opening a sql connection'),
help='Interval between retries of opening a sql connection'),
cfg.IntOpt('max_overflow',
default=None,
deprecated_opts=[cfg.DeprecatedOpt('sql_max_overflow',
@ -409,8 +441,8 @@ class SqliteForeignKeysListener(PoolListener):
dbapi_con.execute('pragma foreign_keys=ON')
def get_session(autocommit=True, expire_on_commit=False,
sqlite_fk=False, slave_session=False):
def get_session(autocommit=True, expire_on_commit=False, sqlite_fk=False,
slave_session=False, mysql_traditional_mode=False):
"""Return a SQLAlchemy session."""
global _MAKER
global _SLAVE_MAKER
@ -420,7 +452,8 @@ def get_session(autocommit=True, expire_on_commit=False,
maker = _SLAVE_MAKER
if maker is None:
engine = get_engine(sqlite_fk=sqlite_fk, slave_engine=slave_session)
engine = get_engine(sqlite_fk=sqlite_fk, slave_engine=slave_session,
mysql_traditional_mode=mysql_traditional_mode)
maker = get_maker(engine, autocommit, expire_on_commit)
if slave_session:
@ -439,6 +472,11 @@ def get_session(autocommit=True, expire_on_commit=False,
# 1 column - (IntegrityError) column c1 is not unique
# N columns - (IntegrityError) column c1, c2, ..., N are not unique
#
# sqlite since 3.7.16:
# 1 column - (IntegrityError) UNIQUE constraint failed: k1
#
# N columns - (IntegrityError) UNIQUE constraint failed: k1, k2
#
# postgres:
# 1 column - (IntegrityError) duplicate key value violates unique
# constraint "users_c1_key"
@ -451,9 +489,10 @@ def get_session(autocommit=True, expire_on_commit=False,
# N columns - (IntegrityError) (1062, "Duplicate entry 'values joined
# with -' for key 'name_of_our_constraint'")
_DUP_KEY_RE_DB = {
"sqlite": re.compile(r"^.*columns?([^)]+)(is|are)\s+not\s+unique$"),
"postgresql": re.compile(r"^.*duplicate\s+key.*\"([^\"]+)\"\s*\n.*$"),
"mysql": re.compile(r"^.*\(1062,.*'([^\']+)'\"\)$")
"sqlite": (re.compile(r"^.*columns?([^)]+)(is|are)\s+not\s+unique$"),
re.compile(r"^.*UNIQUE\s+constraint\s+failed:\s+(.+)$")),
"postgresql": (re.compile(r"^.*duplicate\s+key.*\"([^\"]+)\"\s*\n.*$"),),
"mysql": (re.compile(r"^.*\(1062,.*'([^\']+)'\"\)$"),)
}
@ -483,10 +522,14 @@ def _raise_if_duplicate_entry_error(integrity_error, engine_name):
# SQLAlchemy can differ when using unicode() and accessing .message.
# An audit across all three supported engines will be necessary to
# ensure there are no regressions.
m = _DUP_KEY_RE_DB[engine_name].match(integrity_error.message)
if not m:
for pattern in _DUP_KEY_RE_DB[engine_name]:
match = pattern.match(integrity_error.message)
if match:
break
else:
return
columns = m.group(1)
columns = match.group(1)
if engine_name == "sqlite":
columns = columns.strip().split(", ")
@ -555,7 +598,8 @@ def _wrap_db_error(f):
return _wrap
def get_engine(sqlite_fk=False, slave_engine=False):
def get_engine(sqlite_fk=False, slave_engine=False,
mysql_traditional_mode=False):
"""Return a SQLAlchemy engine."""
global _ENGINE
global _SLAVE_ENGINE
@ -567,8 +611,8 @@ def get_engine(sqlite_fk=False, slave_engine=False):
db_uri = CONF.database.slave_connection
if engine is None:
engine = create_engine(db_uri,
sqlite_fk=sqlite_fk)
engine = create_engine(db_uri, sqlite_fk=sqlite_fk,
mysql_traditional_mode=mysql_traditional_mode)
if slave_engine:
_SLAVE_ENGINE = engine
else:
@ -603,22 +647,39 @@ def _thread_yield(dbapi_con, con_record):
time.sleep(0)
def _ping_listener(dbapi_conn, connection_rec, connection_proxy):
"""Ensures that MySQL connections checked out of the pool are alive.
def _ping_listener(engine, dbapi_conn, connection_rec, connection_proxy):
"""Ensures that MySQL and DB2 connections are alive.
Borrowed from:
http://groups.google.com/group/sqlalchemy/msg/a4ce563d802c929f
"""
cursor = dbapi_conn.cursor()
try:
dbapi_conn.cursor().execute('select 1')
except dbapi_conn.OperationalError as ex:
if ex.args[0] in (2006, 2013, 2014, 2045, 2055):
LOG.warn(_('Got mysql server has gone away: %s'), ex)
raise sqla_exc.DisconnectionError("Database server went away")
ping_sql = 'select 1'
if engine.name == 'ibm_db_sa':
# DB2 requires a table expression
ping_sql = 'select 1 from (values (1)) AS t1'
cursor.execute(ping_sql)
except Exception as ex:
if engine.dialect.is_disconnect(ex, dbapi_conn, cursor):
msg = _('Database server has gone away: %s') % ex
LOG.warning(msg)
raise sqla_exc.DisconnectionError(msg)
else:
raise
def _set_mode_traditional(dbapi_con, connection_rec, connection_proxy):
"""Set engine mode to 'traditional'.
Required to prevent silent truncates at insert or update operations
under MySQL. By default MySQL truncates inserted string if it longer
than a declared field just with warning. That is fraught with data
corruption.
"""
dbapi_con.cursor().execute("SET SESSION sql_mode = TRADITIONAL;")
def _is_db_connection_error(args):
"""Return True if error in connecting to db."""
# NOTE(adam_g): This is currently MySQL specific and needs to be extended
@ -631,7 +692,8 @@ def _is_db_connection_error(args):
return False
def create_engine(sql_connection, sqlite_fk=False):
def create_engine(sql_connection, sqlite_fk=False,
mysql_traditional_mode=False):
"""Return a new SQLAlchemy engine."""
# NOTE(geekinutah): At this point we could be connecting to the normal
# db handle or the slave db handle. Things like
@ -672,8 +734,16 @@ def create_engine(sql_connection, sqlite_fk=False):
sqlalchemy.event.listen(engine, 'checkin', _thread_yield)
if 'mysql' in connection_dict.drivername:
sqlalchemy.event.listen(engine, 'checkout', _ping_listener)
if engine.name in ['mysql', 'ibm_db_sa']:
callback = functools.partial(_ping_listener, engine)
sqlalchemy.event.listen(engine, 'checkout', callback)
if mysql_traditional_mode:
sqlalchemy.event.listen(engine, 'checkout', _set_mode_traditional)
else:
LOG.warning(_("This application has not enabled MySQL traditional"
" mode, which means silent data corruption may"
" occur. Please encourage the application"
" developers to enable this mode."))
elif 'sqlite' in connection_dict.drivername:
if not CONF.sqlite_synchronous:
sqlalchemy.event.listen(engine, 'connect',
@ -695,7 +765,7 @@ def create_engine(sql_connection, sqlite_fk=False):
remaining = 'infinite'
while True:
msg = _('SQL connection failed. %s attempts left.')
LOG.warn(msg % remaining)
LOG.warning(msg % remaining)
if remaining != 'infinite':
remaining -= 1
time.sleep(CONF.database.retry_interval)

View File

@ -0,0 +1,306 @@
# Copyright 2010-2011 OpenStack Foundation
# Copyright 2012-2013 IBM Corp.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import functools
import os
import subprocess
import lockfile
from six import moves
import sqlalchemy
import sqlalchemy.exc
from ironic.openstack.common.gettextutils import _
from ironic.openstack.common import log as logging
from ironic.openstack.common.py3kcompat import urlutils
from ironic.openstack.common import test
LOG = logging.getLogger(__name__)
def _get_connect_string(backend, user, passwd, database):
"""Get database connection
Try to get a connection with a very specific set of values, if we get
these then we'll run the tests, otherwise they are skipped
"""
if backend == "postgres":
backend = "postgresql+psycopg2"
elif backend == "mysql":
backend = "mysql+mysqldb"
else:
raise Exception("Unrecognized backend: '%s'" % backend)
return ("%(backend)s://%(user)s:%(passwd)s@localhost/%(database)s"
% {'backend': backend, 'user': user, 'passwd': passwd,
'database': database})
def _is_backend_avail(backend, user, passwd, database):
try:
connect_uri = _get_connect_string(backend, user, passwd, database)
engine = sqlalchemy.create_engine(connect_uri)
connection = engine.connect()
except Exception:
# intentionally catch all to handle exceptions even if we don't
# have any backend code loaded.
return False
else:
connection.close()
engine.dispose()
return True
def _have_mysql(user, passwd, database):
present = os.environ.get('TEST_MYSQL_PRESENT')
if present is None:
return _is_backend_avail('mysql', user, passwd, database)
return present.lower() in ('', 'true')
def _have_postgresql(user, passwd, database):
present = os.environ.get('TEST_POSTGRESQL_PRESENT')
if present is None:
return _is_backend_avail('postgres', user, passwd, database)
return present.lower() in ('', 'true')
def get_db_connection_info(conn_pieces):
database = conn_pieces.path.strip('/')
loc_pieces = conn_pieces.netloc.split('@')
host = loc_pieces[1]
auth_pieces = loc_pieces[0].split(':')
user = auth_pieces[0]
password = ""
if len(auth_pieces) > 1:
password = auth_pieces[1].strip()
return (user, password, database, host)
def _set_db_lock(lock_path=None, lock_prefix=None):
def decorator(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
try:
path = lock_path or os.environ.get("IRONIC_LOCK_PATH")
lock = lockfile.FileLock(os.path.join(path, lock_prefix))
with lock:
LOG.debug(_('Got lock "%s"') % f.__name__)
return f(*args, **kwargs)
finally:
LOG.debug(_('Lock released "%s"') % f.__name__)
return wrapper
return decorator
class BaseMigrationTestCase(test.BaseTestCase):
"""Base class fort testing of migration utils."""
def __init__(self, *args, **kwargs):
super(BaseMigrationTestCase, self).__init__(*args, **kwargs)
self.DEFAULT_CONFIG_FILE = os.path.join(os.path.dirname(__file__),
'test_migrations.conf')
# Test machines can set the TEST_MIGRATIONS_CONF variable
# to override the location of the config file for migration testing
self.CONFIG_FILE_PATH = os.environ.get('TEST_MIGRATIONS_CONF',
self.DEFAULT_CONFIG_FILE)
self.test_databases = {}
self.migration_api = None
def setUp(self):
super(BaseMigrationTestCase, self).setUp()
# Load test databases from the config file. Only do this
# once. No need to re-run this on each test...
LOG.debug('config_path is %s' % self.CONFIG_FILE_PATH)
if os.path.exists(self.CONFIG_FILE_PATH):
cp = moves.configparser.RawConfigParser()
try:
cp.read(self.CONFIG_FILE_PATH)
defaults = cp.defaults()
for key, value in defaults.items():
self.test_databases[key] = value
except moves.configparser.ParsingError as e:
self.fail("Failed to read test_migrations.conf config "
"file. Got error: %s" % e)
else:
self.fail("Failed to find test_migrations.conf config "
"file.")
self.engines = {}
for key, value in self.test_databases.items():
self.engines[key] = sqlalchemy.create_engine(value)
# We start each test case with a completely blank slate.
self._reset_databases()
def tearDown(self):
# We destroy the test data store between each test case,
# and recreate it, which ensures that we have no side-effects
# from the tests
self._reset_databases()
super(BaseMigrationTestCase, self).tearDown()
def execute_cmd(self, cmd=None):
process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
output = process.communicate()[0]
LOG.debug(output)
self.assertEqual(0, process.returncode,
"Failed to run: %s\n%s" % (cmd, output))
def _reset_pg(self, conn_pieces):
(user, password, database, host) = get_db_connection_info(conn_pieces)
os.environ['PGPASSWORD'] = password
os.environ['PGUSER'] = user
# note(boris-42): We must create and drop database, we can't
# drop database which we have connected to, so for such
# operations there is a special database template1.
sqlcmd = ("psql -w -U %(user)s -h %(host)s -c"
" '%(sql)s' -d template1")
sql = ("drop database if exists %s;") % database
droptable = sqlcmd % {'user': user, 'host': host, 'sql': sql}
self.execute_cmd(droptable)
sql = ("create database %s;") % database
createtable = sqlcmd % {'user': user, 'host': host, 'sql': sql}
self.execute_cmd(createtable)
os.unsetenv('PGPASSWORD')
os.unsetenv('PGUSER')
@_set_db_lock(lock_prefix='migration_tests-')
def _reset_databases(self):
for key, engine in self.engines.items():
conn_string = self.test_databases[key]
conn_pieces = urlutils.urlparse(conn_string)
engine.dispose()
if conn_string.startswith('sqlite'):
# We can just delete the SQLite database, which is
# the easiest and cleanest solution
db_path = conn_pieces.path.strip('/')
if os.path.exists(db_path):
os.unlink(db_path)
# No need to recreate the SQLite DB. SQLite will
# create it for us if it's not there...
elif conn_string.startswith('mysql'):
# We can execute the MySQL client to destroy and re-create
# the MYSQL database, which is easier and less error-prone
# than using SQLAlchemy to do this via MetaData...trust me.
(user, password, database, host) = \
get_db_connection_info(conn_pieces)
sql = ("drop database if exists %(db)s; "
"create database %(db)s;") % {'db': database}
cmd = ("mysql -u \"%(user)s\" -p\"%(password)s\" -h %(host)s "
"-e \"%(sql)s\"") % {'user': user, 'password': password,
'host': host, 'sql': sql}
self.execute_cmd(cmd)
elif conn_string.startswith('postgresql'):
self._reset_pg(conn_pieces)
class WalkVersionsMixin(object):
def _walk_versions(self, engine=None, snake_walk=False, downgrade=True):
# Determine latest version script from the repo, then
# upgrade from 1 through to the latest, with no data
# in the databases. This just checks that the schema itself
# upgrades successfully.
# Place the database under version control
self.migration_api.version_control(engine, self.REPOSITORY,
self.INIT_VERSION)
self.assertEqual(self.INIT_VERSION,
self.migration_api.db_version(engine,
self.REPOSITORY))
LOG.debug('latest version is %s' % self.REPOSITORY.latest)
versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1)
for version in versions:
# upgrade -> downgrade -> upgrade
self._migrate_up(engine, version, with_data=True)
if snake_walk:
downgraded = self._migrate_down(
engine, version - 1, with_data=True)
if downgraded:
self._migrate_up(engine, version)
if downgrade:
# Now walk it back down to 0 from the latest, testing
# the downgrade paths.
for version in reversed(versions):
# downgrade -> upgrade -> downgrade
downgraded = self._migrate_down(engine, version - 1)
if snake_walk and downgraded:
self._migrate_up(engine, version)
self._migrate_down(engine, version - 1)
def _migrate_down(self, engine, version, with_data=False):
try:
self.migration_api.downgrade(engine, self.REPOSITORY, version)
except NotImplementedError:
# NOTE(sirp): some migrations, namely release-level
# migrations, don't support a downgrade.
return False
self.assertEqual(
version, self.migration_api.db_version(engine, self.REPOSITORY))
# NOTE(sirp): `version` is what we're downgrading to (i.e. the 'target'
# version). So if we have any downgrade checks, they need to be run for
# the previous (higher numbered) migration.
if with_data:
post_downgrade = getattr(
self, "_post_downgrade_%03d" % (version + 1), None)
if post_downgrade:
post_downgrade(engine)
return True
def _migrate_up(self, engine, version, with_data=False):
"""migrate up to a new version of the db.
We allow for data insertion and post checks at every
migration version with special _pre_upgrade_### and
_check_### functions in the main test.
"""
# NOTE(sdague): try block is here because it's impossible to debug
# where a failed data migration happens otherwise
try:
if with_data:
data = None
pre_upgrade = getattr(
self, "_pre_upgrade_%03d" % version, None)
if pre_upgrade:
data = pre_upgrade(engine)
self.migration_api.upgrade(engine, self.REPOSITORY, version)
self.assertEqual(version,
self.migration_api.db_version(engine,
self.REPOSITORY))
if with_data:
check = getattr(self, "_check_%03d" % version, None)
if check:
check(engine, data)
except Exception:
LOG.error("Failed to migrate to version %s on engine %s" %
(version, engine))
raise

View File

@ -1,5 +1,3 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# Copyright 2010-2011 OpenStack Foundation.
@ -38,7 +36,7 @@ from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy.types import NullType
from ironic.openstack.common.gettextutils import _ # noqa
from ironic.openstack.common.gettextutils import _
from ironic.openstack.common import log as logging
from ironic.openstack.common import timeutils
@ -96,7 +94,7 @@ def paginate_query(query, model, limit, sort_keys, marker=None,
if 'id' not in sort_keys:
# TODO(justinsb): If this ever gives a false-positive, check
# the actual primary key, rather than assuming its id
LOG.warn(_('Id not in sort_keys; is sort_keys unique?'))
LOG.warning(_('Id not in sort_keys; is sort_keys unique?'))
assert(not (sort_dir and sort_dirs))
@ -135,9 +133,9 @@ def paginate_query(query, model, limit, sort_keys, marker=None,
# Build up an array of sort criteria as in the docstring
criteria_list = []
for i in range(0, len(sort_keys)):
for i in range(len(sort_keys)):
crit_attrs = []
for j in range(0, i):
for j in range(i):
model_attr = getattr(model, sort_keys[j])
crit_attrs.append((model_attr == marker_values[j]))

View File

@ -0,0 +1,67 @@
#
# Copyright 2013 Canonical Ltd.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
"""
Python2/Python3 compatibility layer for OpenStack
"""
import six
if six.PY3:
# python3
import urllib.error
import urllib.parse
import urllib.request
urlencode = urllib.parse.urlencode
urljoin = urllib.parse.urljoin
quote = urllib.parse.quote
quote_plus = urllib.parse.quote_plus
parse_qsl = urllib.parse.parse_qsl
unquote = urllib.parse.unquote
unquote_plus = urllib.parse.unquote_plus
urlparse = urllib.parse.urlparse
urlsplit = urllib.parse.urlsplit
urlunsplit = urllib.parse.urlunsplit
SplitResult = urllib.parse.SplitResult
urlopen = urllib.request.urlopen
URLError = urllib.error.URLError
pathname2url = urllib.request.pathname2url
else:
# python2
import urllib
import urllib2
import urlparse
urlencode = urllib.urlencode
quote = urllib.quote
quote_plus = urllib.quote_plus
unquote = urllib.unquote
unquote_plus = urllib.unquote_plus
parse = urlparse
parse_qsl = parse.parse_qsl
urljoin = parse.urljoin
urlparse = parse.urlparse
urlsplit = parse.urlsplit
urlunsplit = parse.urlunsplit
SplitResult = parse.SplitResult
urlopen = urllib2.urlopen
URLError = urllib2.URLError
pathname2url = urllib.pathname2url

View File

@ -0,0 +1,71 @@
# Copyright (c) 2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Common utilities used in testing"""
import logging
import os
import fixtures
import testtools
_TRUE_VALUES = ('True', 'true', '1', 'yes')
_LOG_FORMAT = "%(levelname)8s [%(name)s] %(message)s"
class BaseTestCase(testtools.TestCase):
def setUp(self):
super(BaseTestCase, self).setUp()
self._set_timeout()
self._fake_output()
self._fake_logs()
self.useFixture(fixtures.NestedTempfile())
self.useFixture(fixtures.TempHomeDir())
def _set_timeout(self):
test_timeout = os.environ.get('OS_TEST_TIMEOUT', 0)
try:
test_timeout = int(test_timeout)
except ValueError:
# If timeout value is invalid do not set a timeout.
test_timeout = 0
if test_timeout > 0:
self.useFixture(fixtures.Timeout(test_timeout, gentle=True))
def _fake_output(self):
if os.environ.get('OS_STDOUT_CAPTURE') in _TRUE_VALUES:
stdout = self.useFixture(fixtures.StringStream('stdout')).stream
self.useFixture(fixtures.MonkeyPatch('sys.stdout', stdout))
if os.environ.get('OS_STDERR_CAPTURE') in _TRUE_VALUES:
stderr = self.useFixture(fixtures.StringStream('stderr')).stream
self.useFixture(fixtures.MonkeyPatch('sys.stderr', stderr))
def _fake_logs(self):
if os.environ.get('OS_DEBUG') in _TRUE_VALUES:
level = logging.DEBUG
else:
level = logging.INFO
capture_logs = os.environ.get('OS_LOG_CAPTURE') in _TRUE_VALUES
if capture_logs:
self.useFixture(
fixtures.FakeLogger(
format=_LOG_FORMAT,
level=level,
nuke_handlers=capture_logs,
)
)
else:
logging.basicConfig(format=_LOG_FORMAT, level=level)

View File

@ -5,6 +5,7 @@ module=config.generator
module=context
module=db
module=db.sqlalchemy
module=db.sqlalchemy.migration_cli
module=eventlet_backdoor
module=excutils
module=fileutils
@ -23,10 +24,12 @@ module=notifier
module=periodic_task
module=policy
module=processutils
module=py3kcompat
module=rpc
module=setup
module=strutils
module=timeutils
module=test
module=version
# The base module to hold the copy of openstack.common