From 43b253cd60bd68bce108969d6b3adc0565b8f95e Mon Sep 17 00:00:00 2001 From: Stephen Finucane Date: Thu, 1 Apr 2021 12:14:33 +0100 Subject: [PATCH] db: Post reshuffle cleanup Introduce a new 'nova.db.api.api' module to hold API database-specific helpers, plus a generic 'nova.db.utils' module to hold code suitable for both main and API databases. This highlights a level of complexity around connection management that is present for the main database but not for the API database. This is because we need to handle the complexity of cells for the former but not the latter. Change-Id: Ia5304c552ce552ae3c5223a2bfb3a9cd543ec57c Signed-off-by: Stephen Finucane --- nova/cmd/status.py | 7 +- nova/compute/api.py | 11 +- nova/config.py | 6 +- nova/db/api/api.py | 50 +++++++ nova/db/main/api.py | 130 ++--------------- nova/db/migration.py | 7 +- nova/db/utils.py | 109 ++++++++++++++ nova/objects/aggregate.py | 28 ++-- nova/objects/build_request.py | 17 ++- nova/objects/cell_mapping.py | 46 +++--- nova/objects/flavor.py | 31 ++-- nova/objects/host_mapping.py | 14 +- nova/objects/instance_group.py | 26 ++-- nova/objects/instance_mapping.py | 28 ++-- nova/objects/keypair.py | 24 +-- nova/objects/quotas.py | 99 +++++++------ nova/objects/request_spec.py | 12 +- nova/objects/virtual_interface.py | 45 +++--- nova/quota.py | 10 +- nova/tests/fixtures/nova.py | 42 ++++-- nova/tests/functional/db/test_aggregate.py | 10 +- nova/tests/functional/db/test_flavor.py | 4 +- .../tests/unit/api/openstack/test_wsgi_app.py | 12 +- nova/tests/unit/conductor/test_conductor.py | 9 +- nova/tests/unit/db/api/__init__.py | 0 nova/tests/unit/db/api/test_api.py | 24 +++ nova/tests/unit/db/main/test_api.py | 138 ++---------------- nova/tests/unit/db/test_migration.py | 15 +- nova/tests/unit/db/test_utils.py | 123 ++++++++++++++++ nova/tests/unit/objects/test_flavor.py | 4 +- nova/tests/unit/test_conf.py | 13 +- nova/tests/unit/test_fixtures.py | 15 +- 32 files changed, 625 insertions(+), 484 deletions(-) create mode 100644 nova/db/api/api.py create mode 100644 nova/db/utils.py create mode 100644 nova/tests/unit/db/api/__init__.py create mode 100644 nova/tests/unit/db/api/test_api.py create mode 100644 nova/tests/unit/db/test_utils.py diff --git a/nova/cmd/status.py b/nova/cmd/status.py index 8e8f8b0dcdd8..2f14e3fed2ff 100644 --- a/nova/cmd/status.py +++ b/nova/cmd/status.py @@ -32,7 +32,8 @@ from nova.cmd import common as cmd_common import nova.conf from nova import config from nova import context as nova_context -from nova.db.main import api as db_session +from nova.db.api import api as api_db_api +from nova.db.main import api as main_db_api from nova import exception from nova.i18n import _ from nova.objects import cell_mapping as cell_mapping_obj @@ -86,7 +87,7 @@ class UpgradeCommands(upgradecheck.UpgradeCommands): # table, or by only counting compute nodes with a service version of at # least 15 which was the highest service version when Newton was # released. - meta = sa.MetaData(bind=db_session.get_engine(context=context)) + meta = sa.MetaData(bind=main_db_api.get_engine(context=context)) compute_nodes = sa.Table('compute_nodes', meta, autoload=True) return sa.select([sqlfunc.count()]).select_from(compute_nodes).where( compute_nodes.c.deleted == 0).scalar() @@ -103,7 +104,7 @@ class UpgradeCommands(upgradecheck.UpgradeCommands): for compute nodes if there are no host mappings on a fresh install. """ meta = sa.MetaData() - meta.bind = db_session.get_api_engine() + meta.bind = api_db_api.get_engine() cell_mappings = self._get_cell_mappings() count = len(cell_mappings) diff --git a/nova/compute/api.py b/nova/compute/api.py index 909cd52962c8..bc75399af70c 100644 --- a/nova/compute/api.py +++ b/nova/compute/api.py @@ -53,7 +53,8 @@ from nova import conductor import nova.conf from nova import context as nova_context from nova import crypto -from nova.db.main import api as db_api +from nova.db.api import api as api_db_api +from nova.db.main import api as main_db_api from nova import exception from nova import exception_wrapper from nova.i18n import _ @@ -1081,7 +1082,7 @@ class API: network_metadata) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _create_reqspec_buildreq_instmapping(context, rs, br, im): """Create the request spec, build request, and instance mapping in a single database transaction. @@ -5082,7 +5083,7 @@ class API: def get_instance_metadata(self, context, instance): """Get all metadata associated with an instance.""" - return db_api.instance_metadata_get(context, instance.uuid) + return main_db_api.instance_metadata_get(context, instance.uuid) @check_instance_lock @check_instance_state(vm_state=[vm_states.ACTIVE, vm_states.PAUSED, @@ -5962,7 +5963,7 @@ class HostAPI: """Return the task logs within a given range, optionally filtering by host and/or state. """ - return db_api.task_log_get_all( + return main_db_api.task_log_get_all( context, task_name, period_beginning, period_ending, host=host, state=state) @@ -6055,7 +6056,7 @@ class HostAPI: if cell.uuid == objects.CellMapping.CELL0_UUID: continue with nova_context.target_cell(context, cell) as cctxt: - cell_stats.append(db_api.compute_node_statistics(cctxt)) + cell_stats.append(main_db_api.compute_node_statistics(cctxt)) if cell_stats: keys = cell_stats[0].keys() diff --git a/nova/config.py b/nova/config.py index 1af845897646..98aad0c73f18 100644 --- a/nova/config.py +++ b/nova/config.py @@ -22,7 +22,8 @@ from oslo_policy import opts as policy_opts from oslo_utils import importutils import nova.conf -from nova.db.main import api as db_api +from nova.db.api import api as api_db_api +from nova.db.main import api as main_db_api from nova import middleware from nova import policy from nova import rpc @@ -100,4 +101,5 @@ def parse_args(argv, default_config_files=None, configure_db=True, rpc.init(CONF) if configure_db: - db_api.configure(CONF) + main_db_api.configure(CONF) + api_db_api.configure(CONF) diff --git a/nova/db/api/api.py b/nova/db/api/api.py new file mode 100644 index 000000000000..713562e4338e --- /dev/null +++ b/nova/db/api/api.py @@ -0,0 +1,50 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo_db.sqlalchemy import enginefacade +from oslo_utils import importutils +import sqlalchemy as sa + +import nova.conf + +profiler_sqlalchemy = importutils.try_import('osprofiler.sqlalchemy') + +CONF = nova.conf.CONF + +context_manager = enginefacade.transaction_context() + +# NOTE(stephenfin): We don't need equivalents of the 'get_context_manager' or +# 'create_context_manager' APIs found in 'nova.db.main.api' since we don't need +# to be cell-aware here + + +def _get_db_conf(conf_group, connection=None): + kw = dict(conf_group.items()) + if connection is not None: + kw['connection'] = connection + return kw + + +def configure(conf): + context_manager.configure(**_get_db_conf(conf.api_database)) + + if ( + profiler_sqlalchemy and + CONF.profiler.enabled and + CONF.profiler.trace_sqlalchemy + ): + context_manager.append_on_engine_create( + lambda eng: profiler_sqlalchemy.add_tracing(sa, eng, "db")) + + +def get_engine(): + return context_manager.writer.get_engine() diff --git a/nova/db/main/api.py b/nova/db/main/api.py index b19d761b5cd4..83e390bf6d37 100644 --- a/nova/db/main/api.py +++ b/nova/db/main/api.py @@ -22,7 +22,6 @@ import copy import datetime import functools import inspect -import sys import traceback from oslo_db import api as oslo_db_api @@ -49,6 +48,8 @@ from nova.compute import vm_states import nova.conf import nova.context from nova.db.main import models +from nova.db import utils as db_utils +from nova.db.utils import require_context from nova import exception from nova.i18n import _ from nova import safe_utils @@ -60,8 +61,7 @@ LOG = logging.getLogger(__name__) DISABLE_DB_ACCESS = False -main_context_manager = enginefacade.transaction_context() -api_context_manager = enginefacade.transaction_context() +context_manager = enginefacade.transaction_context() def _get_db_conf(conf_group, connection=None): @@ -89,15 +89,14 @@ def _joinedload_all(column): def configure(conf): - main_context_manager.configure(**_get_db_conf(conf.database)) - api_context_manager.configure(**_get_db_conf(conf.api_database)) + context_manager.configure(**_get_db_conf(conf.database)) - if profiler_sqlalchemy and CONF.profiler.enabled \ - and CONF.profiler.trace_sqlalchemy: - - main_context_manager.append_on_engine_create( - lambda eng: profiler_sqlalchemy.add_tracing(sa, eng, "db")) - api_context_manager.append_on_engine_create( + if ( + profiler_sqlalchemy and + CONF.profiler.enabled and + CONF.profiler.trace_sqlalchemy + ): + context_manager.append_on_engine_create( lambda eng: profiler_sqlalchemy.add_tracing(sa, eng, "db")) @@ -116,7 +115,7 @@ def get_context_manager(context): :param context: The request context that can contain a context manager """ - return _context_manager_from_context(context) or main_context_manager + return _context_manager_from_context(context) or context_manager def get_engine(use_slave=False, context=None): @@ -131,40 +130,11 @@ def get_engine(use_slave=False, context=None): return ctxt_mgr.writer.get_engine() -def get_api_engine(): - return api_context_manager.writer.get_engine() - - _SHADOW_TABLE_PREFIX = 'shadow_' _DEFAULT_QUOTA_NAME = 'default' PER_PROJECT_QUOTAS = ['fixed_ips', 'floating_ips', 'networks'] -# NOTE(stephenfin): This is required and used by oslo.db -def get_backend(): - """The backend is this module itself.""" - return sys.modules[__name__] - - -def require_context(f): - """Decorator to require *any* user or admin context. - - This does no authorization for user or project access matching, see - :py:func:`nova.context.authorize_project_context` and - :py:func:`nova.context.authorize_user_context`. - - The first argument to the wrapped function must be the context. - - """ - - @functools.wraps(f) - def wrapper(*args, **kwargs): - nova.context.require_context(args[0]) - return f(*args, **kwargs) - wrapper.__signature__ = inspect.signature(f) - return wrapper - - def select_db_reader_mode(f): """Decorator to select synchronous or asynchronous reader mode. @@ -1662,9 +1632,8 @@ def instance_get_all_by_filters_sort(context, filters, limit=None, marker=None, if limit == 0: return [] - sort_keys, sort_dirs = process_sort_params(sort_keys, - sort_dirs, - default_dir='desc') + sort_keys, sort_dirs = db_utils.process_sort_params( + sort_keys, sort_dirs, default_dir='desc') if columns_to_join is None: columns_to_join_new = ['info_cache', 'security_groups'] @@ -2043,75 +2012,6 @@ def _exact_instance_filter(query, filters, legal_keys): return query -def process_sort_params(sort_keys, sort_dirs, - default_keys=['created_at', 'id'], - default_dir='asc'): - """Process the sort parameters to include default keys. - - Creates a list of sort keys and a list of sort directions. Adds the default - keys to the end of the list if they are not already included. - - When adding the default keys to the sort keys list, the associated - direction is: - 1) The first element in the 'sort_dirs' list (if specified), else - 2) 'default_dir' value (Note that 'asc' is the default value since this is - the default in sqlalchemy.utils.paginate_query) - - :param sort_keys: List of sort keys to include in the processed list - :param sort_dirs: List of sort directions to include in the processed list - :param default_keys: List of sort keys that need to be included in the - processed list, they are added at the end of the list - if not already specified. - :param default_dir: Sort direction associated with each of the default - keys that are not supplied, used when they are added - to the processed list - :returns: list of sort keys, list of sort directions - :raise exception.InvalidInput: If more sort directions than sort keys - are specified or if an invalid sort - direction is specified - """ - # Determine direction to use for when adding default keys - if sort_dirs and len(sort_dirs) != 0: - default_dir_value = sort_dirs[0] - else: - default_dir_value = default_dir - - # Create list of keys (do not modify the input list) - if sort_keys: - result_keys = list(sort_keys) - else: - result_keys = [] - - # If a list of directions is not provided, use the default sort direction - # for all provided keys - if sort_dirs: - result_dirs = [] - # Verify sort direction - for sort_dir in sort_dirs: - if sort_dir not in ('asc', 'desc'): - msg = _("Unknown sort direction, must be 'desc' or 'asc'") - raise exception.InvalidInput(reason=msg) - result_dirs.append(sort_dir) - else: - result_dirs = [default_dir_value for _sort_key in result_keys] - - # Ensure that the key and direction length match - while len(result_dirs) < len(result_keys): - result_dirs.append(default_dir_value) - # Unless more direction are specified, which is an error - if len(result_dirs) > len(result_keys): - msg = _("Sort direction size exceeds sort key size") - raise exception.InvalidInput(reason=msg) - - # Ensure defaults are included - for key in default_keys: - if key not in result_keys: - result_keys.append(key) - result_dirs.append(default_dir_value) - - return result_keys, result_dirs - - @require_context @pick_context_manager_reader_allow_async def instance_get_active_by_window_joined(context, begin, end=None, @@ -3507,8 +3407,8 @@ def migration_get_all_by_filters(context, filters, raise exception.MarkerNotFound(marker=marker) if limit or marker or sort_keys or sort_dirs: # Default sort by desc(['created_at', 'id']) - sort_keys, sort_dirs = process_sort_params(sort_keys, sort_dirs, - default_dir='desc') + sort_keys, sort_dirs = db_utils.process_sort_params( + sort_keys, sort_dirs, default_dir='desc') return sqlalchemyutils.paginate_query(query, models.Migration, limit=limit, diff --git a/nova/db/migration.py b/nova/db/migration.py index a1e49a216025..d5ae303a3ee3 100644 --- a/nova/db/migration.py +++ b/nova/db/migration.py @@ -22,7 +22,8 @@ from migrate.versioning.repository import Repository from oslo_log import log as logging import sqlalchemy -from nova.db.main import api as db_session +from nova.db.api import api as api_db_api +from nova.db.main import api as main_db_api from nova import exception from nova.i18n import _ @@ -36,10 +37,10 @@ LOG = logging.getLogger(__name__) def get_engine(database='main', context=None): if database == 'main': - return db_session.get_engine(context=context) + return main_db_api.get_engine(context=context) if database == 'api': - return db_session.get_api_engine() + return api_db_api.get_engine() def find_migrate_repo(database='main'): diff --git a/nova/db/utils.py b/nova/db/utils.py new file mode 100644 index 000000000000..234845a359be --- /dev/null +++ b/nova/db/utils.py @@ -0,0 +1,109 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import functools +import inspect + +import nova.context +from nova import exception +from nova.i18n import _ + + +def require_context(f): + """Decorator to require *any* user or admin context. + + This does no authorization for user or project access matching, see + :py:func:`nova.context.authorize_project_context` and + :py:func:`nova.context.authorize_user_context`. + + The first argument to the wrapped function must be the context. + + """ + + @functools.wraps(f) + def wrapper(*args, **kwargs): + nova.context.require_context(args[0]) + return f(*args, **kwargs) + wrapper.__signature__ = inspect.signature(f) + return wrapper + + +def process_sort_params( + sort_keys, + sort_dirs, + default_keys=['created_at', 'id'], + default_dir='asc', +): + """Process the sort parameters to include default keys. + + Creates a list of sort keys and a list of sort directions. Adds the default + keys to the end of the list if they are not already included. + + When adding the default keys to the sort keys list, the associated + direction is: + + 1. The first element in the 'sort_dirs' list (if specified), else + 2. 'default_dir' value (Note that 'asc' is the default value since this is + the default in sqlalchemy.utils.paginate_query) + + :param sort_keys: List of sort keys to include in the processed list + :param sort_dirs: List of sort directions to include in the processed list + :param default_keys: List of sort keys that need to be included in the + processed list, they are added at the end of the list if not already + specified. + :param default_dir: Sort direction associated with each of the default + keys that are not supplied, used when they are added to the processed + list + :returns: list of sort keys, list of sort directions + :raise exception.InvalidInput: If more sort directions than sort keys + are specified or if an invalid sort direction is specified + """ + # Determine direction to use for when adding default keys + if sort_dirs and len(sort_dirs) != 0: + default_dir_value = sort_dirs[0] + else: + default_dir_value = default_dir + + # Create list of keys (do not modify the input list) + if sort_keys: + result_keys = list(sort_keys) + else: + result_keys = [] + + # If a list of directions is not provided, use the default sort direction + # for all provided keys + if sort_dirs: + result_dirs = [] + # Verify sort direction + for sort_dir in sort_dirs: + if sort_dir not in ('asc', 'desc'): + msg = _("Unknown sort direction, must be 'desc' or 'asc'") + raise exception.InvalidInput(reason=msg) + result_dirs.append(sort_dir) + else: + result_dirs = [default_dir_value for _sort_key in result_keys] + + # Ensure that the key and direction length match + while len(result_dirs) < len(result_keys): + result_dirs.append(default_dir_value) + # Unless more direction are specified, which is an error + if len(result_dirs) > len(result_keys): + msg = _("Sort direction size exceeds sort key size") + raise exception.InvalidInput(reason=msg) + + # Ensure defaults are included + for key in default_keys: + if key not in result_keys: + result_keys.append(key) + result_dirs.append(default_dir_value) + + return result_keys, result_dirs diff --git a/nova/objects/aggregate.py b/nova/objects/aggregate.py index 5880165cf355..2aa802cf9b8f 100644 --- a/nova/objects/aggregate.py +++ b/nova/objects/aggregate.py @@ -19,8 +19,8 @@ from oslo_utils import uuidutils from sqlalchemy import orm from nova.compute import utils as compute_utils +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db_api from nova import exception from nova.i18n import _ from nova import objects @@ -32,7 +32,7 @@ LOG = logging.getLogger(__name__) DEPRECATED_FIELDS = ['deleted', 'deleted_at'] -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _aggregate_get_from_db(context, aggregate_id): query = context.session.query(api_models.Aggregate).\ options(orm.joinedload('_hosts')).\ @@ -47,7 +47,7 @@ def _aggregate_get_from_db(context, aggregate_id): return aggregate -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _aggregate_get_from_db_by_uuid(context, aggregate_uuid): query = context.session.query(api_models.Aggregate).\ options(orm.joinedload('_hosts')).\ @@ -64,7 +64,7 @@ def _aggregate_get_from_db_by_uuid(context, aggregate_uuid): def _host_add_to_db(context, aggregate_id, host): try: - with db_api.api_context_manager.writer.using(context): + with api_db_api.context_manager.writer.using(context): # Check to see if the aggregate exists _aggregate_get_from_db(context, aggregate_id) @@ -79,7 +79,7 @@ def _host_add_to_db(context, aggregate_id, host): def _host_delete_from_db(context, aggregate_id, host): count = 0 - with db_api.api_context_manager.writer.using(context): + with api_db_api.context_manager.writer.using(context): # Check to see if the aggregate exists _aggregate_get_from_db(context, aggregate_id) @@ -98,7 +98,7 @@ def _metadata_add_to_db(context, aggregate_id, metadata, max_retries=10, all_keys = metadata.keys() for attempt in range(max_retries): try: - with db_api.api_context_manager.writer.using(context): + with api_db_api.context_manager.writer.using(context): query = context.session.query(api_models.AggregateMetadata).\ filter_by(aggregate_id=aggregate_id) @@ -142,7 +142,7 @@ def _metadata_add_to_db(context, aggregate_id, metadata, max_retries=10, LOG.warning(msg) -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _metadata_delete_from_db(context, aggregate_id, key): # Check to see if the aggregate exists _aggregate_get_from_db(context, aggregate_id) @@ -157,7 +157,7 @@ def _metadata_delete_from_db(context, aggregate_id, key): aggregate_id=aggregate_id, metadata_key=key) -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _aggregate_create_in_db(context, values, metadata=None): query = context.session.query(api_models.Aggregate) query = query.filter(api_models.Aggregate.name == values['name']) @@ -181,7 +181,7 @@ def _aggregate_create_in_db(context, values, metadata=None): return aggregate -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _aggregate_delete_from_db(context, aggregate_id): # Delete Metadata first context.session.query(api_models.AggregateMetadata).\ @@ -196,7 +196,7 @@ def _aggregate_delete_from_db(context, aggregate_id): raise exception.AggregateNotFound(aggregate_id=aggregate_id) -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _aggregate_update_to_db(context, aggregate_id, values): aggregate = _aggregate_get_from_db(context, aggregate_id) @@ -411,7 +411,7 @@ class Aggregate(base.NovaPersistentObject, base.NovaObject): return self.metadata.get('availability_zone', None) -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _get_all_from_db(context): query = context.session.query(api_models.Aggregate).\ options(orm.joinedload('_hosts')).\ @@ -420,7 +420,7 @@ def _get_all_from_db(context): return query.all() -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _get_by_host_from_db(context, host, key=None): query = context.session.query(api_models.Aggregate).\ options(orm.joinedload('_hosts')).\ @@ -435,7 +435,7 @@ def _get_by_host_from_db(context, host, key=None): return query.all() -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _get_by_metadata_from_db(context, key=None, value=None): assert(key is not None or value is not None) query = context.session.query(api_models.Aggregate) @@ -450,7 +450,7 @@ def _get_by_metadata_from_db(context, key=None, value=None): return query.all() -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _get_non_matching_by_metadata_keys_from_db(context, ignored_keys, key_prefix, value): """Filter aggregates based on non matching metadata. diff --git a/nova/objects/build_request.py b/nova/objects/build_request.py index 67f271cb94d6..11ca05cbc70a 100644 --- a/nova/objects/build_request.py +++ b/nova/objects/build_request.py @@ -18,8 +18,9 @@ from oslo_serialization import jsonutils from oslo_utils import versionutils from oslo_versionedobjects import exception as ovoo_exc +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db +from nova.db import utils as db_utils from nova import exception from nova import objects from nova.objects import base @@ -163,7 +164,7 @@ class BuildRequest(base.NovaObject): return req @staticmethod - @db.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_instance_uuid_from_db(context, instance_uuid): db_req = context.session.query(api_models.BuildRequest).filter_by( instance_uuid=instance_uuid).first() @@ -177,7 +178,7 @@ class BuildRequest(base.NovaObject): return cls._from_db_object(context, cls(), db_req) @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _create_in_db(context, updates): db_req = api_models.BuildRequest() db_req.update(updates) @@ -206,7 +207,7 @@ class BuildRequest(base.NovaObject): self._from_db_object(self._context, self, db_req) @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_in_db(context, instance_uuid): result = context.session.query(api_models.BuildRequest).filter_by( instance_uuid=instance_uuid).delete() @@ -217,7 +218,7 @@ class BuildRequest(base.NovaObject): def destroy(self): self._destroy_in_db(self._context, self.instance_uuid) - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _save_in_db(self, context, req_id, updates): db_req = context.session.query( api_models.BuildRequest).filter_by(id=req_id).first() @@ -262,7 +263,7 @@ class BuildRequestList(base.ObjectListBase, base.NovaObject): } @staticmethod - @db.api_context_manager.reader + @api_db_api.context_manager.reader def _get_all_from_db(context): query = context.session.query(api_models.BuildRequest) @@ -396,8 +397,8 @@ class BuildRequestList(base.ObjectListBase, base.NovaObject): # exists. So it can be ignored. # 'deleted' and 'cleaned' are handled above. - sort_keys, sort_dirs = db.process_sort_params(sort_keys, sort_dirs, - default_dir='desc') + sort_keys, sort_dirs = db_utils.process_sort_params( + sort_keys, sort_dirs, default_dir='desc') # For other filters that don't match this, we will do regexp matching # Taken from db/sqlalchemy/api.py diff --git a/nova/objects/cell_mapping.py b/nova/objects/cell_mapping.py index f19f2f190842..595ec43e480e 100644 --- a/nova/objects/cell_mapping.py +++ b/nova/objects/cell_mapping.py @@ -18,8 +18,8 @@ from sqlalchemy import sql from sqlalchemy.sql import expression import nova.conf -from nova.db.api import models as api_models -from nova.db.main import api as db_api +from nova.db.api import api as api_db_api +from nova.db.api import models as api_db_models from nova import exception from nova.objects import base from nova.objects import fields @@ -168,11 +168,11 @@ class CellMapping(base.NovaTimestampObject, base.NovaObject): return cell_mapping @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_uuid_from_db(context, uuid): - db_mapping = context.session.query(api_models.CellMapping).filter_by( - uuid=uuid).first() + db_mapping = context.session\ + .query(api_db_models.CellMapping).filter_by(uuid=uuid).first() if not db_mapping: raise exception.CellMappingNotFound(uuid=uuid) @@ -185,10 +185,10 @@ class CellMapping(base.NovaTimestampObject, base.NovaObject): return cls._from_db_object(context, cls(), db_mapping) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _create_in_db(context, updates): - db_mapping = api_models.CellMapping() + db_mapping = api_db_models.CellMapping() db_mapping.update(updates) db_mapping.save(context.session) return db_mapping @@ -199,11 +199,11 @@ class CellMapping(base.NovaTimestampObject, base.NovaObject): self._from_db_object(self._context, self, db_mapping) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _save_in_db(context, uuid, updates): db_mapping = context.session.query( - api_models.CellMapping).filter_by(uuid=uuid).first() + api_db_models.CellMapping).filter_by(uuid=uuid).first() if not db_mapping: raise exception.CellMappingNotFound(uuid=uuid) @@ -219,10 +219,10 @@ class CellMapping(base.NovaTimestampObject, base.NovaObject): self.obj_reset_changes() @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_in_db(context, uuid): - result = context.session.query(api_models.CellMapping).filter_by( + result = context.session.query(api_db_models.CellMapping).filter_by( uuid=uuid).delete() if not result: raise exception.CellMappingNotFound(uuid=uuid) @@ -246,10 +246,10 @@ class CellMappingList(base.ObjectListBase, base.NovaObject): } @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_all_from_db(context): - return context.session.query(api_models.CellMapping).order_by( - expression.asc(api_models.CellMapping.id)).all() + return context.session.query(api_db_models.CellMapping).order_by( + expression.asc(api_db_models.CellMapping.id)).all() @base.remotable_classmethod def get_all(cls, context): @@ -257,16 +257,16 @@ class CellMappingList(base.ObjectListBase, base.NovaObject): return base.obj_make_list(context, cls(), CellMapping, db_mappings) @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_disabled_from_db(context, disabled): if disabled: - return context.session.query(api_models.CellMapping)\ + return context.session.query(api_db_models.CellMapping)\ .filter_by(disabled=sql.true())\ - .order_by(expression.asc(api_models.CellMapping.id)).all() + .order_by(expression.asc(api_db_models.CellMapping.id)).all() else: - return context.session.query(api_models.CellMapping)\ + return context.session.query(api_db_models.CellMapping)\ .filter_by(disabled=sql.false())\ - .order_by(expression.asc(api_models.CellMapping.id)).all() + .order_by(expression.asc(api_db_models.CellMapping.id)).all() @base.remotable_classmethod def get_by_disabled(cls, context, disabled): @@ -274,16 +274,16 @@ class CellMappingList(base.ObjectListBase, base.NovaObject): return base.obj_make_list(context, cls(), CellMapping, db_mappings) @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_project_id_from_db(context, project_id): # SELECT DISTINCT cell_id FROM instance_mappings \ # WHERE project_id = $project_id; cell_ids = context.session.query( - api_models.InstanceMapping.cell_id).filter_by( + api_db_models.InstanceMapping.cell_id).filter_by( project_id=project_id).distinct().subquery() # SELECT cell_mappings WHERE cell_id IN ($cell_ids); - return context.session.query(api_models.CellMapping).filter( - api_models.CellMapping.id.in_(cell_ids)).all() + return context.session.query(api_db_models.CellMapping).filter( + api_db_models.CellMapping.id.in_(cell_ids)).all() @classmethod def get_by_project_id(cls, context, project_id): diff --git a/nova/objects/flavor.py b/nova/objects/flavor.py index ae27b6389288..20378b0dae53 100644 --- a/nova/objects/flavor.py +++ b/nova/objects/flavor.py @@ -21,8 +21,9 @@ from sqlalchemy import sql from sqlalchemy.sql import expression import nova.conf +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db_api +from nova.db import utils as db_utils from nova import exception from nova.notifications.objects import base as notification from nova.notifications.objects import flavor as flavor_notification @@ -51,7 +52,7 @@ def _dict_with_extra_specs(flavor_model): # decorators with static methods. We pull these out for now and can # move them back into the actual staticmethods on the object when those # issues are resolved. -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _get_projects_from_db(context, flavorid): db_flavor = context.session.query(api_models.Flavors).\ filter_by(flavorid=flavorid).\ @@ -62,7 +63,7 @@ def _get_projects_from_db(context, flavorid): return [x['project_id'] for x in db_flavor['projects']] -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _flavor_add_project(context, flavor_id, project_id): project = api_models.FlavorProjects() project.update({'flavor_id': flavor_id, @@ -74,7 +75,7 @@ def _flavor_add_project(context, flavor_id, project_id): project_id=project_id) -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _flavor_del_project(context, flavor_id, project_id): result = context.session.query(api_models.FlavorProjects).\ filter_by(project_id=project_id).\ @@ -85,9 +86,9 @@ def _flavor_del_project(context, flavor_id, project_id): project_id=project_id) -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _flavor_extra_specs_add(context, flavor_id, specs, max_retries=10): - writer = db_api.api_context_manager.writer + writer = api_db_api.context_manager.writer for attempt in range(max_retries): try: spec_refs = context.session.query( @@ -122,7 +123,7 @@ def _flavor_extra_specs_add(context, flavor_id, specs, max_retries=10): id=flavor_id, retries=max_retries) -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _flavor_extra_specs_del(context, flavor_id, key): result = context.session.query(api_models.FlavorExtraSpecs).\ filter_by(flavor_id=flavor_id).\ @@ -133,7 +134,7 @@ def _flavor_extra_specs_del(context, flavor_id, key): extra_specs_key=key, flavor_id=flavor_id) -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _flavor_create(context, values): specs = values.get('extra_specs') db_specs = [] @@ -169,7 +170,7 @@ def _flavor_create(context, values): return _dict_with_extra_specs(db_flavor) -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _flavor_destroy(context, flavor_id=None, flavorid=None): query = context.session.query(api_models.Flavors) @@ -268,7 +269,7 @@ class Flavor(base.NovaPersistentObject, base.NovaObject, return flavor @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _flavor_get_query_from_db(context): query = context.session.query(api_models.Flavors).\ options(orm.joinedload('extra_specs')) @@ -281,7 +282,7 @@ class Flavor(base.NovaPersistentObject, base.NovaObject, return query @staticmethod - @db_api.require_context + @db_utils.require_context def _flavor_get_from_db(context, id): """Returns a dict describing specific flavor.""" result = Flavor._flavor_get_query_from_db(context).\ @@ -292,7 +293,7 @@ class Flavor(base.NovaPersistentObject, base.NovaObject, return _dict_with_extra_specs(result) @staticmethod - @db_api.require_context + @db_utils.require_context def _flavor_get_by_name_from_db(context, name): """Returns a dict describing specific flavor.""" result = Flavor._flavor_get_query_from_db(context).\ @@ -303,7 +304,7 @@ class Flavor(base.NovaPersistentObject, base.NovaObject, return _dict_with_extra_specs(result) @staticmethod - @db_api.require_context + @db_utils.require_context def _flavor_get_by_flavor_id_from_db(context, flavor_id): """Returns a dict describing specific flavor_id.""" result = Flavor._flavor_get_query_from_db(context).\ @@ -485,7 +486,7 @@ class Flavor(base.NovaPersistentObject, base.NovaObject, # NOTE(mriedem): This method is not remotable since we only expect the API # to be able to make updates to a flavor. - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _save(self, context, values): db_flavor = context.session.query(api_models.Flavors).\ filter_by(id=self.id).first() @@ -581,7 +582,7 @@ class Flavor(base.NovaPersistentObject, base.NovaObject, payload=payload).emit(self._context) -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _flavor_get_all_from_db(context, inactive, filters, sort_key, sort_dir, limit, marker): """Returns all flavors. diff --git a/nova/objects/host_mapping.py b/nova/objects/host_mapping.py index 98e36f3edfdf..09dfe81354b3 100644 --- a/nova/objects/host_mapping.py +++ b/nova/objects/host_mapping.py @@ -14,8 +14,8 @@ from oslo_db import exception as db_exc from sqlalchemy import orm from nova import context +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db_api from nova import exception from nova.i18n import _ from nova.objects import base @@ -52,7 +52,7 @@ class HostMapping(base.NovaTimestampObject, base.NovaObject): } def _get_cell_mapping(self): - with db_api.api_context_manager.reader.using(self._context) as session: + with api_db_api.context_manager.reader.using(self._context) as session: cell_map = (session.query(api_models.CellMapping) .join(api_models.HostMapping) .filter(api_models.HostMapping.host == self.host) @@ -87,7 +87,7 @@ class HostMapping(base.NovaTimestampObject, base.NovaObject): return host_mapping @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_host_from_db(context, host): db_mapping = context.session.query(api_models.HostMapping)\ .options(orm.joinedload('cell_mapping'))\ @@ -102,7 +102,7 @@ class HostMapping(base.NovaTimestampObject, base.NovaObject): return cls._from_db_object(context, cls(), db_mapping) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _create_in_db(context, updates): db_mapping = api_models.HostMapping() return _apply_updates(context, db_mapping, updates) @@ -116,7 +116,7 @@ class HostMapping(base.NovaTimestampObject, base.NovaObject): self._from_db_object(self._context, self, db_mapping) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _save_in_db(context, obj, updates): db_mapping = context.session.query(api_models.HostMapping).filter_by( id=obj.id).first() @@ -134,7 +134,7 @@ class HostMapping(base.NovaTimestampObject, base.NovaObject): self.obj_reset_changes() @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_in_db(context, host): result = context.session.query(api_models.HostMapping).filter_by( host=host).delete() @@ -157,7 +157,7 @@ class HostMappingList(base.ObjectListBase, base.NovaObject): } @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_from_db(context, cell_id=None): query = (context.session.query(api_models.HostMapping) .options(orm.joinedload('cell_mapping'))) diff --git a/nova/objects/instance_group.py b/nova/objects/instance_group.py index 26df07df92b9..34cf40b7fe4a 100644 --- a/nova/objects/instance_group.py +++ b/nova/objects/instance_group.py @@ -22,8 +22,8 @@ from oslo_utils import versionutils from sqlalchemy import orm from nova.compute import utils as compute_utils +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db_api from nova import exception from nova import objects from nova.objects import base @@ -213,7 +213,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject, return instance_group @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_from_db_by_uuid(context, uuid): grp = _instance_group_get_query(context, id_field=api_models.InstanceGroup.uuid, @@ -223,7 +223,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject, return grp @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_from_db_by_id(context, id): grp = _instance_group_get_query(context, id_field=api_models.InstanceGroup.id, @@ -233,7 +233,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject, return grp @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_from_db_by_name(context, name): grp = _instance_group_get_query(context).filter_by(name=name).first() if not grp: @@ -241,7 +241,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject, return grp @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_from_db_by_instance(context, instance_uuid): grp_member = context.session.query(api_models.InstanceGroupMember).\ filter_by(instance_uuid=instance_uuid).first() @@ -251,7 +251,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject, return grp @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _save_in_db(context, group_uuid, values): grp = InstanceGroup._get_from_db_by_uuid(context, group_uuid) values_copy = copy.copy(values) @@ -265,7 +265,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject, return grp @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _create_in_db(context, values, policies=None, members=None, policy=None, rules=None): try: @@ -301,7 +301,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject, return group @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_in_db(context, group_uuid): qry = _instance_group_get_query(context, id_field=api_models.InstanceGroup.uuid, @@ -319,13 +319,13 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject, qry.delete() @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _add_members_in_db(context, group_uuid, members): return _instance_group_members_add_by_uuid(context, group_uuid, members) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _remove_members_in_db(context, group_id, instance_uuids): # There is no public method provided for removing members because the # user-facing API doesn't allow removal of instance group members. We @@ -337,7 +337,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject, delete(synchronize_session=False) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_members_bulk_in_db(context, instance_uuids): return context.session.query(api_models.InstanceGroupMember).filter( api_models.InstanceGroupMember.instance_uuid.in_(instance_uuids)).\ @@ -537,7 +537,7 @@ class InstanceGroupList(base.ObjectListBase, base.NovaObject): } @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_from_db(context, project_id=None): query = _instance_group_get_query(context) if project_id is not None: @@ -545,7 +545,7 @@ class InstanceGroupList(base.ObjectListBase, base.NovaObject): return query.all() @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_counts_from_db(context, project_id, user_id=None): query = context.session.query(api_models.InstanceGroup.id).\ filter_by(project_id=project_id) diff --git a/nova/objects/instance_mapping.py b/nova/objects/instance_mapping.py index 43fa5333ba3b..68f45cd8cc3c 100644 --- a/nova/objects/instance_mapping.py +++ b/nova/objects/instance_mapping.py @@ -20,8 +20,8 @@ from sqlalchemy import sql from sqlalchemy.sql import func from nova import context as nova_context +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db_api from nova import exception from nova.i18n import _ from nova import objects @@ -96,7 +96,7 @@ class InstanceMapping(base.NovaTimestampObject, base.NovaObject): return instance_mapping @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_instance_uuid_from_db(context, instance_uuid): db_mapping = context.session.query(api_models.InstanceMapping)\ .options(orm.joinedload('cell_mapping'))\ @@ -113,7 +113,7 @@ class InstanceMapping(base.NovaTimestampObject, base.NovaObject): return cls._from_db_object(context, cls(), db_mapping) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _create_in_db(context, updates): db_mapping = api_models.InstanceMapping() db_mapping.update(updates) @@ -138,7 +138,7 @@ class InstanceMapping(base.NovaTimestampObject, base.NovaObject): self._from_db_object(self._context, self, db_mapping) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _save_in_db(context, instance_uuid, updates): db_mapping = context.session.query( api_models.InstanceMapping).filter_by( @@ -173,7 +173,7 @@ class InstanceMapping(base.NovaTimestampObject, base.NovaObject): self.obj_reset_changes() @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_in_db(context, instance_uuid): result = context.session.query(api_models.InstanceMapping).filter_by( instance_uuid=instance_uuid).delete() @@ -185,7 +185,7 @@ class InstanceMapping(base.NovaTimestampObject, base.NovaObject): self._destroy_in_db(self._context, self.instance_uuid) -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def populate_queued_for_delete(context, max_count): cells = objects.CellMappingList.get_all(context) processed = 0 @@ -229,7 +229,7 @@ def populate_queued_for_delete(context, max_count): return processed, processed -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def populate_user_id(context, max_count): cells = objects.CellMappingList.get_all(context) cms_by_id = {cell.id: cell for cell in cells} @@ -309,7 +309,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject): } @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_project_id_from_db(context, project_id): return context.session.query(api_models.InstanceMapping)\ .options(orm.joinedload('cell_mapping'))\ @@ -323,7 +323,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject): db_mappings) @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_cell_id_from_db(context, cell_id): return context.session.query(api_models.InstanceMapping)\ .options(orm.joinedload('cell_mapping'))\ @@ -336,7 +336,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject): db_mappings) @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_instance_uuids_from_db(context, uuids): return context.session.query(api_models.InstanceMapping)\ .options(orm.joinedload('cell_mapping'))\ @@ -350,7 +350,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject): db_mappings) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_bulk_in_db(context, instance_uuids): return context.session.query(api_models.InstanceMapping).filter( api_models.InstanceMapping.instance_uuid.in_(instance_uuids)).\ @@ -361,7 +361,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject): return cls._destroy_bulk_in_db(context, instance_uuids) @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_not_deleted_by_cell_and_project_from_db(context, cell_uuid, project_id, limit): query = context.session.query(api_models.InstanceMapping) @@ -400,7 +400,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject): db_mappings) @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_counts_in_db(context, project_id, user_id=None): project_query = context.session.query( func.count(api_models.InstanceMapping.id)).\ @@ -435,7 +435,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject): return cls._get_counts_in_db(context, project_id, user_id=user_id) @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_count_by_uuids_and_user_in_db(context, uuids, user_id): query = (context.session.query( func.count(api_models.InstanceMapping.id)) diff --git a/nova/objects/keypair.py b/nova/objects/keypair.py index aa6075c5f961..0b3c4315d5c9 100644 --- a/nova/objects/keypair.py +++ b/nova/objects/keypair.py @@ -17,8 +17,9 @@ from oslo_db.sqlalchemy import utils as sqlalchemyutils from oslo_log import log as logging from oslo_utils import versionutils +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db +from nova.db.main import api as main_db_api from nova import exception from nova import objects from nova.objects import base @@ -29,7 +30,7 @@ KEYPAIR_TYPE_X509 = 'x509' LOG = logging.getLogger(__name__) -@db.api_context_manager.reader +@api_db_api.context_manager.reader def _get_from_db(context, user_id, name=None, limit=None, marker=None): query = context.session.query(api_models.KeyPair).\ filter(api_models.KeyPair.user_id == user_id) @@ -54,14 +55,14 @@ def _get_from_db(context, user_id, name=None, limit=None, marker=None): return query.all() -@db.api_context_manager.reader +@api_db_api.context_manager.reader def _get_count_from_db(context, user_id): return context.session.query(api_models.KeyPair).\ filter(api_models.KeyPair.user_id == user_id).\ count() -@db.api_context_manager.writer +@api_db_api.context_manager.writer def _create_in_db(context, values): kp = api_models.KeyPair() kp.update(values) @@ -72,7 +73,7 @@ def _create_in_db(context, values): return kp -@db.api_context_manager.writer +@api_db_api.context_manager.writer def _destroy_in_db(context, user_id, name): result = context.session.query(api_models.KeyPair).\ filter_by(user_id=user_id).\ @@ -143,7 +144,7 @@ class KeyPair(base.NovaPersistentObject, base.NovaObject, except exception.KeypairNotFound: pass if db_keypair is None: - db_keypair = db.key_pair_get(context, user_id, name) + db_keypair = main_db_api.key_pair_get(context, user_id, name) return cls._from_db_object(context, cls(), db_keypair) @base.remotable_classmethod @@ -151,7 +152,7 @@ class KeyPair(base.NovaPersistentObject, base.NovaObject, try: cls._destroy_in_db(context, user_id, name) except exception.KeypairNotFound: - db.key_pair_destroy(context, user_id, name) + main_db_api.key_pair_destroy(context, user_id, name) @base.remotable def create(self): @@ -163,7 +164,7 @@ class KeyPair(base.NovaPersistentObject, base.NovaObject, # letting them create in the API DB, since we won't get protection # from the UC. try: - db.key_pair_get(self._context, self.user_id, self.name) + main_db_api.key_pair_get(self._context, self.user_id, self.name) raise exception.KeyPairExists(key_name=self.name) except exception.KeypairNotFound: pass @@ -180,7 +181,8 @@ class KeyPair(base.NovaPersistentObject, base.NovaObject, try: self._destroy_in_db(self._context, self.user_id, self.name) except exception.KeypairNotFound: - db.key_pair_destroy(self._context, self.user_id, self.name) + main_db_api.key_pair_destroy( + self._context, self.user_id, self.name) @base.NovaObjectRegistry.register @@ -222,7 +224,7 @@ class KeyPairList(base.ObjectListBase, base.NovaObject): limit_more = None if limit_more is None or limit_more > 0: - main_db_keypairs = db.key_pair_get_all_by_user( + main_db_keypairs = main_db_api.key_pair_get_all_by_user( context, user_id, limit=limit_more, marker=marker) else: main_db_keypairs = [] @@ -233,4 +235,4 @@ class KeyPairList(base.ObjectListBase, base.NovaObject): @base.remotable_classmethod def get_count_by_user(cls, context, user_id): return (cls._get_count_from_db(context, user_id) + - db.key_pair_count_by_user(context, user_id)) + main_db_api.key_pair_count_by_user(context, user_id)) diff --git a/nova/objects/quotas.py b/nova/objects/quotas.py index 39126dbb6de4..4bbb64fdf2bd 100644 --- a/nova/objects/quotas.py +++ b/nova/objects/quotas.py @@ -16,9 +16,11 @@ import collections from oslo_db import exception as db_exc +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db +from nova.db.main import api as main_db_api from nova.db.main import models as main_models +from nova.db import utils as db_utils from nova import exception from nova.objects import base from nova.objects import fields @@ -82,7 +84,7 @@ class Quotas(base.NovaObject): self.obj_reset_changes(fields=[attr]) @staticmethod - @db.api_context_manager.reader + @api_db_api.context_manager.reader def _get_from_db(context, project_id, resource, user_id=None): model = api_models.ProjectUserQuota if user_id else api_models.Quota query = context.session.query(model).\ @@ -100,14 +102,14 @@ class Quotas(base.NovaObject): return result @staticmethod - @db.api_context_manager.reader + @api_db_api.context_manager.reader def _get_all_from_db(context, project_id): return context.session.query(api_models.ProjectUserQuota).\ filter_by(project_id=project_id).\ all() @staticmethod - @db.api_context_manager.reader + @api_db_api.context_manager.reader def _get_all_from_db_by_project(context, project_id): # by_project refers to the returned dict that has a 'project_id' key rows = context.session.query(api_models.Quota).\ @@ -119,7 +121,7 @@ class Quotas(base.NovaObject): return result @staticmethod - @db.api_context_manager.reader + @api_db_api.context_manager.reader def _get_all_from_db_by_project_and_user(context, project_id, user_id): # by_project_and_user refers to the returned dict that has # 'project_id' and 'user_id' keys @@ -135,7 +137,7 @@ class Quotas(base.NovaObject): return result @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_all_in_db_by_project(context, project_id): per_project = context.session.query(api_models.Quota).\ filter_by(project_id=project_id).\ @@ -147,7 +149,7 @@ class Quotas(base.NovaObject): raise exception.ProjectQuotaNotFound(project_id=project_id) @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_all_in_db_by_project_and_user(context, project_id, user_id): result = context.session.query(api_models.ProjectUserQuota).\ filter_by(project_id=project_id).\ @@ -158,7 +160,7 @@ class Quotas(base.NovaObject): user_id=user_id) @staticmethod - @db.api_context_manager.reader + @api_db_api.context_manager.reader def _get_class_from_db(context, class_name, resource): result = context.session.query(api_models.QuotaClass).\ filter_by(class_name=class_name).\ @@ -169,7 +171,7 @@ class Quotas(base.NovaObject): return result @staticmethod - @db.api_context_manager.reader + @api_db_api.context_manager.reader def _get_all_class_from_db_by_name(context, class_name): # by_name refers to the returned dict that has a 'class_name' key rows = context.session.query(api_models.QuotaClass).\ @@ -181,14 +183,16 @@ class Quotas(base.NovaObject): return result @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _create_limit_in_db(context, project_id, resource, limit, user_id=None): # TODO(melwitt): We won't have per project resources after nova-network # is removed. # TODO(stephenfin): We need to do something here now...but what? - per_user = (user_id and - resource not in db.quota_get_per_project_resources()) + per_user = ( + user_id and + resource not in main_db_api.quota_get_per_project_resources() + ) quota_ref = (api_models.ProjectUserQuota() if per_user else api_models.Quota()) if per_user: @@ -204,14 +208,16 @@ class Quotas(base.NovaObject): return quota_ref @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _update_limit_in_db(context, project_id, resource, limit, user_id=None): # TODO(melwitt): We won't have per project resources after nova-network # is removed. # TODO(stephenfin): We need to do something here now...but what? - per_user = (user_id and - resource not in db.quota_get_per_project_resources()) + per_user = ( + user_id and + resource not in main_db_api.quota_get_per_project_resources() + ) model = api_models.ProjectUserQuota if per_user else api_models.Quota query = context.session.query(model).\ filter_by(project_id=project_id).\ @@ -228,7 +234,7 @@ class Quotas(base.NovaObject): raise exception.ProjectQuotaNotFound(project_id=project_id) @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _create_class_in_db(context, class_name, resource, limit): # NOTE(melwitt): There's no unique constraint on the QuotaClass model, # so check for duplicate manually. @@ -247,7 +253,7 @@ class Quotas(base.NovaObject): return quota_class_ref @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _update_class_in_db(context, class_name, resource, limit): result = context.session.query(api_models.QuotaClass).\ filter_by(class_name=class_name).\ @@ -366,7 +372,8 @@ class Quotas(base.NovaObject): @base.remotable_classmethod def create_limit(cls, context, project_id, resource, limit, user_id=None): try: - db.quota_get(context, project_id, resource, user_id=user_id) + main_db_api.quota_get( + context, project_id, resource, user_id=user_id) except exception.QuotaNotFound: cls._create_limit_in_db(context, project_id, resource, limit, user_id=user_id) @@ -380,13 +387,13 @@ class Quotas(base.NovaObject): cls._update_limit_in_db(context, project_id, resource, limit, user_id=user_id) except exception.QuotaNotFound: - db.quota_update(context, project_id, resource, limit, + main_db_api.quota_update(context, project_id, resource, limit, user_id=user_id) @classmethod def create_class(cls, context, class_name, resource, limit): try: - db.quota_class_get(context, class_name, resource) + main_db_api.quota_class_get(context, class_name, resource) except exception.QuotaClassNotFound: cls._create_class_in_db(context, class_name, resource, limit) else: @@ -398,7 +405,8 @@ class Quotas(base.NovaObject): try: cls._update_class_in_db(context, class_name, resource, limit) except exception.QuotaClassNotFound: - db.quota_class_update(context, class_name, resource, limit) + main_db_api.quota_class_update( + context, class_name, resource, limit) # NOTE(melwitt): The following methods are not remotable and return # dict-like database model objects. We are using classmethods to provide @@ -409,21 +417,22 @@ class Quotas(base.NovaObject): quota = cls._get_from_db(context, project_id, resource, user_id=user_id) except exception.QuotaNotFound: - quota = db.quota_get(context, project_id, resource, + quota = main_db_api.quota_get(context, project_id, resource, user_id=user_id) return quota @classmethod def get_all(cls, context, project_id): api_db_quotas = cls._get_all_from_db(context, project_id) - main_db_quotas = db.quota_get_all(context, project_id) + main_db_quotas = main_db_api.quota_get_all(context, project_id) return api_db_quotas + main_db_quotas @classmethod def get_all_by_project(cls, context, project_id): api_db_quotas_dict = cls._get_all_from_db_by_project(context, project_id) - main_db_quotas_dict = db.quota_get_all_by_project(context, project_id) + main_db_quotas_dict = main_db_api.quota_get_all_by_project( + context, project_id) for k, v in api_db_quotas_dict.items(): main_db_quotas_dict[k] = v return main_db_quotas_dict @@ -432,7 +441,7 @@ class Quotas(base.NovaObject): def get_all_by_project_and_user(cls, context, project_id, user_id): api_db_quotas_dict = cls._get_all_from_db_by_project_and_user( context, project_id, user_id) - main_db_quotas_dict = db.quota_get_all_by_project_and_user( + main_db_quotas_dict = main_db_api.quota_get_all_by_project_and_user( context, project_id, user_id) for k, v in api_db_quotas_dict.items(): main_db_quotas_dict[k] = v @@ -443,7 +452,7 @@ class Quotas(base.NovaObject): try: cls._destroy_all_in_db_by_project(context, project_id) except exception.ProjectQuotaNotFound: - db.quota_destroy_all_by_project(context, project_id) + main_db_api.quota_destroy_all_by_project(context, project_id) @classmethod def destroy_all_by_project_and_user(cls, context, project_id, user_id): @@ -451,31 +460,31 @@ class Quotas(base.NovaObject): cls._destroy_all_in_db_by_project_and_user(context, project_id, user_id) except exception.ProjectUserQuotaNotFound: - db.quota_destroy_all_by_project_and_user(context, project_id, - user_id) + main_db_api.quota_destroy_all_by_project_and_user( + context, project_id, user_id) @classmethod def get_class(cls, context, class_name, resource): try: qclass = cls._get_class_from_db(context, class_name, resource) except exception.QuotaClassNotFound: - qclass = db.quota_class_get(context, class_name, resource) + qclass = main_db_api.quota_class_get(context, class_name, resource) return qclass @classmethod def get_default_class(cls, context): try: qclass = cls._get_all_class_from_db_by_name( - context, db._DEFAULT_QUOTA_NAME) + context, main_db_api._DEFAULT_QUOTA_NAME) except exception.QuotaClassNotFound: - qclass = db.quota_class_get_default(context) + qclass = main_db_api.quota_class_get_default(context) return qclass @classmethod def get_all_class_by_name(cls, context, class_name): api_db_quotas_dict = cls._get_all_class_from_db_by_name(context, class_name) - main_db_quotas_dict = db.quota_class_get_all_by_name(context, + main_db_quotas_dict = main_db_api.quota_class_get_all_by_name(context, class_name) for k, v in api_db_quotas_dict.items(): main_db_quotas_dict[k] = v @@ -501,8 +510,8 @@ class QuotasNoOp(Quotas): pass -@db.require_context -@db.pick_context_manager_reader +@db_utils.require_context +@main_db_api.pick_context_manager_reader def _get_main_per_project_limits(context, limit): return context.session.query(main_models.Quota).\ filter_by(deleted=0).\ @@ -510,8 +519,8 @@ def _get_main_per_project_limits(context, limit): all() -@db.require_context -@db.pick_context_manager_reader +@db_utils.require_context +@main_db_api.pick_context_manager_reader def _get_main_per_user_limits(context, limit): return context.session.query(main_models.ProjectUserQuota).\ filter_by(deleted=0).\ @@ -519,8 +528,8 @@ def _get_main_per_user_limits(context, limit): all() -@db.require_context -@db.pick_context_manager_writer +@db_utils.require_context +@main_db_api.pick_context_manager_writer def _destroy_main_per_project_limits(context, project_id, resource): context.session.query(main_models.Quota).\ filter_by(deleted=0).\ @@ -529,8 +538,8 @@ def _destroy_main_per_project_limits(context, project_id, resource): soft_delete(synchronize_session=False) -@db.require_context -@db.pick_context_manager_writer +@db_utils.require_context +@main_db_api.pick_context_manager_writer def _destroy_main_per_user_limits(context, project_id, resource, user_id): context.session.query(main_models.ProjectUserQuota).\ filter_by(deleted=0).\ @@ -540,7 +549,7 @@ def _destroy_main_per_user_limits(context, project_id, resource, user_id): soft_delete(synchronize_session=False) -@db.api_context_manager.writer +@api_db_api.context_manager.writer def _create_limits_in_api_db(context, db_limits, per_user=False): for db_limit in db_limits: user_id = db_limit.user_id if per_user else None @@ -587,8 +596,8 @@ def migrate_quota_limits_to_api_db(context, count): return len(main_per_project_limits) + len(main_per_user_limits), done -@db.require_context -@db.pick_context_manager_reader +@db_utils.require_context +@main_db_api.pick_context_manager_reader def _get_main_quota_classes(context, limit): return context.session.query(main_models.QuotaClass).\ filter_by(deleted=0).\ @@ -596,7 +605,7 @@ def _get_main_quota_classes(context, limit): all() -@db.pick_context_manager_writer +@main_db_api.pick_context_manager_writer def _destroy_main_quota_classes(context, db_classes): for db_class in db_classes: context.session.query(main_models.QuotaClass).\ @@ -605,7 +614,7 @@ def _destroy_main_quota_classes(context, db_classes): soft_delete(synchronize_session=False) -@db.api_context_manager.writer +@api_db_api.context_manager.writer def _create_classes_in_api_db(context, db_classes): for db_class in db_classes: Quotas._create_class_in_db(context, db_class.class_name, diff --git a/nova/objects/request_spec.py b/nova/objects/request_spec.py index cef1c6633484..7d57903078af 100644 --- a/nova/objects/request_spec.py +++ b/nova/objects/request_spec.py @@ -20,8 +20,8 @@ from oslo_log import log as logging from oslo_serialization import jsonutils from oslo_utils import versionutils +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db from nova import exception from nova import objects from nova.objects import base @@ -630,7 +630,7 @@ class RequestSpec(base.NovaObject): return spec @staticmethod - @db.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_instance_uuid_from_db(context, instance_uuid): db_spec = context.session.query(api_models.RequestSpec).filter_by( instance_uuid=instance_uuid).first() @@ -645,7 +645,7 @@ class RequestSpec(base.NovaObject): return cls._from_db_object(context, cls(), db_spec) @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _create_in_db(context, updates): db_spec = api_models.RequestSpec() db_spec.update(updates) @@ -709,7 +709,7 @@ class RequestSpec(base.NovaObject): self._from_db_object(self._context, self, db_spec) @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _save_in_db(context, instance_uuid, updates): # FIXME(sbauza): Provide a classmethod when oslo.db bug #1520195 is # fixed and released @@ -729,7 +729,7 @@ class RequestSpec(base.NovaObject): self.obj_reset_changes() @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_in_db(context, instance_uuid): result = context.session.query(api_models.RequestSpec).filter_by( instance_uuid=instance_uuid).delete() @@ -741,7 +741,7 @@ class RequestSpec(base.NovaObject): self._destroy_in_db(self._context, self.instance_uuid) @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_bulk_in_db(context, instance_uuids): return context.session.query(api_models.RequestSpec).filter( api_models.RequestSpec.instance_uuid.in_(instance_uuids)).\ diff --git a/nova/objects/virtual_interface.py b/nova/objects/virtual_interface.py index e9a769a7b513..7ce418aca202 100644 --- a/nova/objects/virtual_interface.py +++ b/nova/objects/virtual_interface.py @@ -16,8 +16,9 @@ from oslo_log import log as logging from oslo_utils import versionutils from nova import context as nova_context -from nova.db.main import api as db -from nova.db.main import models +from nova.db.api import api as api_db_api +from nova.db.main import api as main_db_api +from nova.db.main import models as main_db_models from nova import exception from nova import objects from nova.objects import base @@ -70,26 +71,26 @@ class VirtualInterface(base.NovaPersistentObject, base.NovaObject): @base.remotable_classmethod def get_by_id(cls, context, vif_id): - db_vif = db.virtual_interface_get(context, vif_id) + db_vif = main_db_api.virtual_interface_get(context, vif_id) if db_vif: return cls._from_db_object(context, cls(), db_vif) @base.remotable_classmethod def get_by_uuid(cls, context, vif_uuid): - db_vif = db.virtual_interface_get_by_uuid(context, vif_uuid) + db_vif = main_db_api.virtual_interface_get_by_uuid(context, vif_uuid) if db_vif: return cls._from_db_object(context, cls(), db_vif) @base.remotable_classmethod def get_by_address(cls, context, address): - db_vif = db.virtual_interface_get_by_address(context, address) + db_vif = main_db_api.virtual_interface_get_by_address(context, address) if db_vif: return cls._from_db_object(context, cls(), db_vif) @base.remotable_classmethod def get_by_instance_and_network(cls, context, instance_uuid, network_id): - db_vif = db.virtual_interface_get_by_instance_and_network(context, - instance_uuid, network_id) + db_vif = main_db_api.virtual_interface_get_by_instance_and_network( + context, instance_uuid, network_id) if db_vif: return cls._from_db_object(context, cls(), db_vif) @@ -99,7 +100,7 @@ class VirtualInterface(base.NovaPersistentObject, base.NovaObject): raise exception.ObjectActionError(action='create', reason='already created') updates = self.obj_get_changes() - db_vif = db.virtual_interface_create(self._context, updates) + db_vif = main_db_api.virtual_interface_create(self._context, updates) self._from_db_object(self._context, self, db_vif) @base.remotable @@ -108,17 +109,18 @@ class VirtualInterface(base.NovaPersistentObject, base.NovaObject): if 'address' in updates: raise exception.ObjectActionError(action='save', reason='address is not mutable') - db_vif = db.virtual_interface_update(self._context, self.address, - updates) + db_vif = main_db_api.virtual_interface_update( + self._context, self.address, updates) return self._from_db_object(self._context, self, db_vif) @base.remotable_classmethod def delete_by_instance_uuid(cls, context, instance_uuid): - db.virtual_interface_delete_by_instance(context, instance_uuid) + main_db_api.virtual_interface_delete_by_instance( + context, instance_uuid) @base.remotable def destroy(self): - db.virtual_interface_delete(self._context, self.id) + main_db_api.virtual_interface_delete(self._context, self.id) @base.NovaObjectRegistry.register @@ -131,15 +133,16 @@ class VirtualInterfaceList(base.ObjectListBase, base.NovaObject): @base.remotable_classmethod def get_all(cls, context): - db_vifs = db.virtual_interface_get_all(context) + db_vifs = main_db_api.virtual_interface_get_all(context) return base.obj_make_list(context, cls(context), objects.VirtualInterface, db_vifs) @staticmethod - @db.select_db_reader_mode + @main_db_api.select_db_reader_mode def _db_virtual_interface_get_by_instance(context, instance_uuid, use_slave=False): - return db.virtual_interface_get_by_instance(context, instance_uuid) + return main_db_api.virtual_interface_get_by_instance( + context, instance_uuid) @base.remotable_classmethod def get_by_instance_uuid(cls, context, instance_uuid, use_slave=False): @@ -149,7 +152,7 @@ class VirtualInterfaceList(base.ObjectListBase, base.NovaObject): objects.VirtualInterface, db_vifs) -@db.api_context_manager.writer +@api_db_api.context_manager.writer def fill_virtual_interface_list(context, max_count): """This fills missing VirtualInterface Objects in Nova DB""" count_hit = 0 @@ -287,14 +290,14 @@ def fill_virtual_interface_list(context, max_count): # we checked. # Please notice that because of virtual_interfaces_instance_uuid_fkey # we need to have FAKE_UUID instance object, even deleted one. -@db.pick_context_manager_writer +@main_db_api.pick_context_manager_writer def _set_or_delete_marker_for_migrate_instances(context, marker=None): - context.session.query(models.VirtualInterface).filter_by( + context.session.query(main_db_models.VirtualInterface).filter_by( instance_uuid=FAKE_UUID).delete() # Create FAKE_UUID instance objects, only for marker, if doesn't exist. # It is needed due constraint: virtual_interfaces_instance_uuid_fkey - instance = context.session.query(models.Instance).filter_by( + instance = context.session.query(main_db_models.Instance).filter_by( uuid=FAKE_UUID).first() if not instance: instance = objects.Instance(context) @@ -316,9 +319,9 @@ def _set_or_delete_marker_for_migrate_instances(context, marker=None): db_mapping.create() -@db.pick_context_manager_reader +@main_db_api.pick_context_manager_reader def _get_marker_for_migrate_instances(context): - vif = (context.session.query(models.VirtualInterface).filter_by( + vif = (context.session.query(main_db_models.VirtualInterface).filter_by( instance_uuid=FAKE_UUID)).first() marker = vif['tag'] if vif else None return marker diff --git a/nova/quota.py b/nova/quota.py index e3950e9d99c2..a311ecc87fae 100644 --- a/nova/quota.py +++ b/nova/quota.py @@ -24,8 +24,9 @@ from sqlalchemy import sql import nova.conf from nova import context as nova_context +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db +from nova.db.main import api as main_db_api from nova import exception from nova import objects from nova.scheduler.client import report @@ -177,7 +178,10 @@ class DbQuotaDriver(object): # displaying used limits. They are always zero. usages[resource.name] = {'in_use': 0} else: - if resource.name in db.quota_get_per_project_resources(): + if ( + resource.name in + main_db_api.quota_get_per_project_resources() + ): count = resource.count_as_dict(context, project_id) key = 'project' else: @@ -1045,7 +1049,7 @@ class QuotaEngine(object): return 0 -@db.api_context_manager.reader +@api_db_api.context_manager.reader def _user_id_queued_for_delete_populated(context, project_id=None): """Determine whether user_id and queued_for_delete are set. diff --git a/nova/tests/fixtures/nova.py b/nova/tests/fixtures/nova.py index 048adbd1b9d8..d5d96f5122e9 100644 --- a/nova/tests/fixtures/nova.py +++ b/nova/tests/fixtures/nova.py @@ -43,7 +43,8 @@ from nova.api import wsgi from nova.compute import multi_cell_list from nova.compute import rpcapi as compute_rpcapi from nova import context -from nova.db.main import api as session +from nova.db.api import api as api_db_api +from nova.db.main import api as main_db_api from nova.db import migration from nova import exception from nova import objects @@ -61,7 +62,6 @@ LOG = logging.getLogger(__name__) DB_SCHEMA = collections.defaultdict(str) SESSION_CONFIGURED = False - PROJECT_ID = '6f70656e737461636b20342065766572' @@ -532,7 +532,7 @@ class CellDatabases(fixtures.Fixture): # will house the sqlite:// connection for this cell's in-memory # database. Store/index it by the connection string, which is # how we identify cells in CellMapping. - ctxt_mgr = session.create_context_manager() + ctxt_mgr = main_db_api.create_context_manager() self._ctxt_mgrs[connection_str] = ctxt_mgr # NOTE(melwitt): The first DB access through service start is @@ -595,30 +595,45 @@ class CellDatabases(fixtures.Fixture): class Database(fixtures.Fixture): + + # TODO(stephenfin): The 'version' argument is unused and can be removed def __init__(self, database='main', version=None, connection=None): """Create a database fixture. :param database: The type of database, 'main', or 'api' :param connection: The connection string to use """ - super(Database, self).__init__() - # NOTE(pkholkin): oslo_db.enginefacade is configured in tests the same - # way as it is done for any other service that uses db + super().__init__() + + # NOTE(pkholkin): oslo_db.enginefacade is configured in tests the + # same way as it is done for any other service that uses DB global SESSION_CONFIGURED if not SESSION_CONFIGURED: - session.configure(CONF) + main_db_api.configure(CONF) + api_db_api.configure(CONF) SESSION_CONFIGURED = True + + assert database in {'main', 'api'}, f'Unrecognized database {database}' + self.database = database self.version = version + if database == 'main': if connection is not None: - ctxt_mgr = session.create_context_manager( - connection=connection) + ctxt_mgr = main_db_api.create_context_manager( + connection=connection) self.get_engine = ctxt_mgr.writer.get_engine else: - self.get_engine = session.get_engine + self.get_engine = main_db_api.get_engine elif database == 'api': - self.get_engine = session.get_api_engine + assert connection is None, 'Not supported for the API database' + + self.get_engine = api_db_api.get_engine + + def setUp(self): + super(Database, self).setUp() + self.reset() + self.addCleanup(self.cleanup) def _cache_schema(self): global DB_SCHEMA @@ -642,11 +657,6 @@ class Database(fixtures.Fixture): conn.connection.executescript( DB_SCHEMA[(self.database, self.version)]) - def setUp(self): - super(Database, self).setUp() - self.reset() - self.addCleanup(self.cleanup) - class DefaultFlavorsFixture(fixtures.Fixture): def setUp(self): diff --git a/nova/tests/functional/db/test_aggregate.py b/nova/tests/functional/db/test_aggregate.py index 3c1cf8ded083..35d9024576e3 100644 --- a/nova/tests/functional/db/test_aggregate.py +++ b/nova/tests/functional/db/test_aggregate.py @@ -18,8 +18,8 @@ from oslo_utils.fixture import uuidsentinel from oslo_utils import timeutils from nova import context +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db_api from nova import exception import nova.objects.aggregate as aggregate_obj from nova import test @@ -59,7 +59,7 @@ def _get_fake_metadata(db_id): 'unique_key': 'unique_value_' + str(db_id)} -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _create_aggregate(context, values=_get_fake_aggregate(1, result=False), metadata=_get_fake_metadata(1)): aggregate = api_models.Aggregate() @@ -77,7 +77,7 @@ def _create_aggregate(context, values=_get_fake_aggregate(1, result=False), return aggregate -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _create_aggregate_with_hosts(context, values=_get_fake_aggregate(1, result=False), metadata=_get_fake_metadata(1), @@ -92,13 +92,13 @@ def _create_aggregate_with_hosts(context, return aggregate -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _aggregate_host_get_all(context, aggregate_id): return context.session.query(api_models.AggregateHost).\ filter_by(aggregate_id=aggregate_id).all() -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _aggregate_metadata_get_all(context, aggregate_id): results = context.session.query(api_models.AggregateMetadata).\ filter_by(aggregate_id=aggregate_id).all() diff --git a/nova/tests/functional/db/test_flavor.py b/nova/tests/functional/db/test_flavor.py index 3ce192fec754..c0435b5cbc64 100644 --- a/nova/tests/functional/db/test_flavor.py +++ b/nova/tests/functional/db/test_flavor.py @@ -11,8 +11,8 @@ # under the License. from nova import context +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db_api from nova import exception from nova import objects from nova import test @@ -111,7 +111,7 @@ class FlavorObjectTestCase(test.NoDBTestCase): self.assertEqual(set(projects), set(flavor2.projects)) @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _collect_flavor_residue_api(context, flavor): flavors = context.session.query(api_models.Flavors).\ filter_by(id=flavor.id).all() diff --git a/nova/tests/unit/api/openstack/test_wsgi_app.py b/nova/tests/unit/api/openstack/test_wsgi_app.py index b6306af0e12a..d2bd7c4bd6e1 100644 --- a/nova/tests/unit/api/openstack/test_wsgi_app.py +++ b/nova/tests/unit/api/openstack/test_wsgi_app.py @@ -46,11 +46,14 @@ document_root = /tmp self.useFixture(config_fixture.Config()) @mock.patch('sys.argv', return_value=mock.sentinel.argv) + @mock.patch('nova.db.api.api.configure') @mock.patch('nova.db.main.api.configure') @mock.patch('nova.api.openstack.wsgi_app._setup_service') @mock.patch('nova.api.openstack.wsgi_app._get_config_files') - def test_init_application_called_twice(self, mock_get_files, mock_setup, - mock_db_configure, mock_argv): + def test_init_application_called_twice( + self, mock_get_files, mock_setup, mock_main_db_configure, + mock_api_db_configure, mock_argv, + ): """Test that init_application can tolerate being called twice in a single python interpreter instance. @@ -65,14 +68,15 @@ document_root = /tmp """ mock_get_files.return_value = [self.conf.name] mock_setup.side_effect = [test.TestingException, None] - # We need to mock the global database configure() method, else we will + # We need to mock the global database configure() methods, else we will # be affected by global database state altered by other tests that ran # before this test, causing this test to fail with # oslo_db.sqlalchemy.enginefacade.AlreadyStartedError. We can instead # mock the method to raise an exception if it's called a second time in # this test to simulate the fact that the database does not tolerate # re-init [after a database query has been made]. - mock_db_configure.side_effect = [None, test.TestingException] + mock_main_db_configure.side_effect = [None, test.TestingException] + mock_api_db_configure.side_effect = [None, test.TestingException] # Run init_application the first time, simulating an exception being # raised during it. self.assertRaises(test.TestingException, wsgi_app.init_application, diff --git a/nova/tests/unit/conductor/test_conductor.py b/nova/tests/unit/conductor/test_conductor.py index 152a625286c7..ccf84e0929f6 100644 --- a/nova/tests/unit/conductor/test_conductor.py +++ b/nova/tests/unit/conductor/test_conductor.py @@ -39,8 +39,9 @@ from nova.conductor.tasks import live_migrate from nova.conductor.tasks import migrate from nova import conf from nova import context +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db +from nova.db.main import api as main_db_api from nova import exception as exc from nova.image import glance as image_api from nova import objects @@ -440,7 +441,7 @@ class _BaseTaskTestCase(object): @mock.patch.object(conductor_manager.ComputeTaskManager, '_create_and_bind_arqs') @mock.patch.object(compute_rpcapi.ComputeAPI, 'build_and_run_instance') - @mock.patch.object(db, 'block_device_mapping_get_all_by_instance', + @mock.patch.object(main_db_api, 'block_device_mapping_get_all_by_instance', return_value=[]) @mock.patch.object(conductor_manager.ComputeTaskManager, '_schedule_instances') @@ -547,7 +548,7 @@ class _BaseTaskTestCase(object): @mock.patch.object(conductor_manager.ComputeTaskManager, '_create_and_bind_arqs') @mock.patch.object(compute_rpcapi.ComputeAPI, 'build_and_run_instance') - @mock.patch.object(db, 'block_device_mapping_get_all_by_instance', + @mock.patch.object(main_db_api, 'block_device_mapping_get_all_by_instance', return_value=[]) @mock.patch.object(conductor_manager.ComputeTaskManager, '_schedule_instances') @@ -2740,7 +2741,7 @@ class ConductorTaskTestCase(_BaseTaskTestCase, test_compute.BaseTestCase): self.assertEqual(0, len(build_requests)) - @db.api_context_manager.reader + @api_db_api.context_manager.reader def request_spec_get_all(context): return context.session.query(api_models.RequestSpec).all() diff --git a/nova/tests/unit/db/api/__init__.py b/nova/tests/unit/db/api/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nova/tests/unit/db/api/test_api.py b/nova/tests/unit/db/api/test_api.py new file mode 100644 index 000000000000..251407612fad --- /dev/null +++ b/nova/tests/unit/db/api/test_api.py @@ -0,0 +1,24 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import mock + +from nova.db.api import api as db_api +from nova import test + + +class SqlAlchemyDbApiNoDbTestCase(test.NoDBTestCase): + + @mock.patch.object(db_api, 'context_manager') + def test_get_engine(self, mock_ctxt_mgr): + db_api.get_engine() + mock_ctxt_mgr.writer.get_engine.assert_called_once_with() diff --git a/nova/tests/unit/db/main/test_api.py b/nova/tests/unit/db/main/test_api.py index 7254f9262b82..de92f5feb7c8 100644 --- a/nova/tests/unit/db/main/test_api.py +++ b/nova/tests/unit/db/main/test_api.py @@ -1,5 +1,3 @@ -# encoding=UTF8 - # Copyright 2010 United States Government as represented by the # Administrator of the National Aeronautics and Space Administration. # All Rights Reserved. @@ -52,6 +50,7 @@ from nova import context from nova.db.main import api as db from nova.db.main import models from nova.db import types as col_types +from nova.db import utils as db_utils from nova import exception from nova.objects import fields from nova import test @@ -207,7 +206,7 @@ class DecoratorTestCase(test.TestCase): self.assertEqual(test_func.__module__, decorated_func.__module__) def test_require_context_decorator_wraps_functions_properly(self): - self._test_decorator_wraps_helper(db.require_context) + self._test_decorator_wraps_helper(db_utils.require_context) def test_require_deadlock_retry_wraps_functions_properly(self): self._test_decorator_wraps_helper( @@ -628,7 +627,7 @@ class ModelQueryTestCase(DbTestCase): @mock.patch.object(sqlalchemyutils, 'model_query') def test_model_query_use_context_session(self, mock_model_query): - @db.main_context_manager.reader + @db.context_manager.reader def fake_method(context): session = context.session db.model_query(context, models.Instance) @@ -642,15 +641,15 @@ class ModelQueryTestCase(DbTestCase): class EngineFacadeTestCase(DbTestCase): def test_use_single_context_session_writer(self): # Checks that session in context would not be overwritten by - # annotation @db.main_context_manager.writer if annotation + # annotation @db.context_manager.writer if annotation # is used twice. - @db.main_context_manager.writer + @db.context_manager.writer def fake_parent_method(context): session = context.session return fake_child_method(context), session - @db.main_context_manager.writer + @db.context_manager.writer def fake_child_method(context): session = context.session db.model_query(context, models.Instance) @@ -661,15 +660,15 @@ class EngineFacadeTestCase(DbTestCase): def test_use_single_context_session_reader(self): # Checks that session in context would not be overwritten by - # annotation @db.main_context_manager.reader if annotation + # annotation @db.context_manager.reader if annotation # is used twice. - @db.main_context_manager.reader + @db.context_manager.reader def fake_parent_method(context): session = context.session return fake_child_method(context), session - @db.main_context_manager.reader + @db.context_manager.reader def fake_child_method(context): session = context.session db.model_query(context, models.Instance) @@ -757,12 +756,12 @@ class SqlAlchemyDbApiNoDbTestCase(test.NoDBTestCase): self.assertEqual('|', filter('|')) self.assertEqual('LIKE', op) - @mock.patch.object(db, 'main_context_manager') + @mock.patch.object(db, 'context_manager') def test_get_engine(self, mock_ctxt_mgr): db.get_engine() mock_ctxt_mgr.writer.get_engine.assert_called_once_with() - @mock.patch.object(db, 'main_context_manager') + @mock.patch.object(db, 'context_manager') def test_get_engine_use_slave(self, mock_ctxt_mgr): db.get_engine(use_slave=True) mock_ctxt_mgr.reader.get_engine.assert_called_once_with() @@ -774,11 +773,6 @@ class SqlAlchemyDbApiNoDbTestCase(test.NoDBTestCase): connection='fake://') self.assertEqual('fake://', db_conf['connection']) - @mock.patch.object(db, 'api_context_manager') - def test_get_api_engine(self, mock_ctxt_mgr): - db.get_api_engine() - mock_ctxt_mgr.writer.get_engine.assert_called_once_with() - @mock.patch.object(db, '_instance_get_by_uuid') @mock.patch.object(db, '_instances_fill_metadata') @mock.patch('oslo_db.sqlalchemy.utils.paginate_query') @@ -1013,114 +1007,6 @@ class SqlAlchemyDbApiTestCase(DbTestCase): self.assertEqual(1, len(instances)) -class ProcessSortParamTestCase(test.TestCase): - - def test_process_sort_params_defaults(self): - '''Verifies default sort parameters.''' - sort_keys, sort_dirs = db.process_sort_params([], []) - self.assertEqual(['created_at', 'id'], sort_keys) - self.assertEqual(['asc', 'asc'], sort_dirs) - - sort_keys, sort_dirs = db.process_sort_params(None, None) - self.assertEqual(['created_at', 'id'], sort_keys) - self.assertEqual(['asc', 'asc'], sort_dirs) - - def test_process_sort_params_override_default_keys(self): - '''Verifies that the default keys can be overridden.''' - sort_keys, sort_dirs = db.process_sort_params( - [], [], default_keys=['key1', 'key2', 'key3']) - self.assertEqual(['key1', 'key2', 'key3'], sort_keys) - self.assertEqual(['asc', 'asc', 'asc'], sort_dirs) - - def test_process_sort_params_override_default_dir(self): - '''Verifies that the default direction can be overridden.''' - sort_keys, sort_dirs = db.process_sort_params( - [], [], default_dir='dir1') - self.assertEqual(['created_at', 'id'], sort_keys) - self.assertEqual(['dir1', 'dir1'], sort_dirs) - - def test_process_sort_params_override_default_key_and_dir(self): - '''Verifies that the default key and dir can be overridden.''' - sort_keys, sort_dirs = db.process_sort_params( - [], [], default_keys=['key1', 'key2', 'key3'], - default_dir='dir1') - self.assertEqual(['key1', 'key2', 'key3'], sort_keys) - self.assertEqual(['dir1', 'dir1', 'dir1'], sort_dirs) - - sort_keys, sort_dirs = db.process_sort_params( - [], [], default_keys=[], default_dir='dir1') - self.assertEqual([], sort_keys) - self.assertEqual([], sort_dirs) - - def test_process_sort_params_non_default(self): - '''Verifies that non-default keys are added correctly.''' - sort_keys, sort_dirs = db.process_sort_params( - ['key1', 'key2'], ['asc', 'desc']) - self.assertEqual(['key1', 'key2', 'created_at', 'id'], sort_keys) - # First sort_dir in list is used when adding the default keys - self.assertEqual(['asc', 'desc', 'asc', 'asc'], sort_dirs) - - def test_process_sort_params_default(self): - '''Verifies that default keys are added correctly.''' - sort_keys, sort_dirs = db.process_sort_params( - ['id', 'key2'], ['asc', 'desc']) - self.assertEqual(['id', 'key2', 'created_at'], sort_keys) - self.assertEqual(['asc', 'desc', 'asc'], sort_dirs) - - # Include default key value, rely on default direction - sort_keys, sort_dirs = db.process_sort_params( - ['id', 'key2'], []) - self.assertEqual(['id', 'key2', 'created_at'], sort_keys) - self.assertEqual(['asc', 'asc', 'asc'], sort_dirs) - - def test_process_sort_params_default_dir(self): - '''Verifies that the default dir is applied to all keys.''' - # Direction is set, ignore default dir - sort_keys, sort_dirs = db.process_sort_params( - ['id', 'key2'], ['desc'], default_dir='dir') - self.assertEqual(['id', 'key2', 'created_at'], sort_keys) - self.assertEqual(['desc', 'desc', 'desc'], sort_dirs) - - # But should be used if no direction is set - sort_keys, sort_dirs = db.process_sort_params( - ['id', 'key2'], [], default_dir='dir') - self.assertEqual(['id', 'key2', 'created_at'], sort_keys) - self.assertEqual(['dir', 'dir', 'dir'], sort_dirs) - - def test_process_sort_params_unequal_length(self): - '''Verifies that a sort direction list is applied correctly.''' - sort_keys, sort_dirs = db.process_sort_params( - ['id', 'key2', 'key3'], ['desc']) - self.assertEqual(['id', 'key2', 'key3', 'created_at'], sort_keys) - self.assertEqual(['desc', 'desc', 'desc', 'desc'], sort_dirs) - - # Default direction is the first key in the list - sort_keys, sort_dirs = db.process_sort_params( - ['id', 'key2', 'key3'], ['desc', 'asc']) - self.assertEqual(['id', 'key2', 'key3', 'created_at'], sort_keys) - self.assertEqual(['desc', 'asc', 'desc', 'desc'], sort_dirs) - - sort_keys, sort_dirs = db.process_sort_params( - ['id', 'key2', 'key3'], ['desc', 'asc', 'asc']) - self.assertEqual(['id', 'key2', 'key3', 'created_at'], sort_keys) - self.assertEqual(['desc', 'asc', 'asc', 'desc'], sort_dirs) - - def test_process_sort_params_extra_dirs_lengths(self): - '''InvalidInput raised if more directions are given.''' - self.assertRaises(exception.InvalidInput, - db.process_sort_params, - ['key1', 'key2'], - ['asc', 'desc', 'desc']) - - def test_process_sort_params_invalid_sort_dir(self): - '''InvalidInput raised if invalid directions are given.''' - for dirs in [['foo'], ['asc', 'foo'], ['asc', 'desc', 'foo']]: - self.assertRaises(exception.InvalidInput, - db.process_sort_params, - ['key'], - dirs) - - class MigrationTestCase(test.TestCase): def setUp(self): @@ -1877,7 +1763,7 @@ class InstanceTestCase(test.TestCase, ModelsObjectComparatorMixin): @mock.patch.object(query.Query, 'filter') def test_instance_metadata_get_multi_no_uuids(self, mock_query_filter): - with db.main_context_manager.reader.using(self.ctxt): + with db.context_manager.reader.using(self.ctxt): db._instance_metadata_get_multi(self.ctxt, []) self.assertFalse(mock_query_filter.called) diff --git a/nova/tests/unit/db/test_migration.py b/nova/tests/unit/db/test_migration.py index f9c4d7e72e1e..66469e42a596 100644 --- a/nova/tests/unit/db/test_migration.py +++ b/nova/tests/unit/db/test_migration.py @@ -17,7 +17,8 @@ from migrate.versioning import api as versioning_api import mock import sqlalchemy -from nova.db.main import api as db_api +from nova.db.api import api as api_db_api +from nova.db.main import api as main_db_api from nova.db import migration from nova import test @@ -144,15 +145,17 @@ class TestDbVersionControl(test.NoDBTestCase): class TestGetEngine(test.NoDBTestCase): def test_get_main_engine(self): - with mock.patch.object(db_api, 'get_engine', - return_value='engine') as mock_get_engine: + with mock.patch.object( + main_db_api, 'get_engine', return_value='engine', + ) as mock_get_engine: engine = migration.get_engine() self.assertEqual('engine', engine) mock_get_engine.assert_called_once_with(context=None) def test_get_api_engine(self): - with mock.patch.object(db_api, 'get_api_engine', - return_value='api_engine') as mock_get_engine: + with mock.patch.object( + api_db_api, 'get_engine', return_value='engine', + ) as mock_get_engine: engine = migration.get_engine('api') - self.assertEqual('api_engine', engine) + self.assertEqual('engine', engine) mock_get_engine.assert_called_once_with() diff --git a/nova/tests/unit/db/test_utils.py b/nova/tests/unit/db/test_utils.py new file mode 100644 index 000000000000..5f293723f8c7 --- /dev/null +++ b/nova/tests/unit/db/test_utils.py @@ -0,0 +1,123 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from nova.db import utils +from nova import exception +from nova import test + + +class ProcessSortParamTestCase(test.TestCase): + + def test_process_sort_params_defaults(self): + """Verifies default sort parameters.""" + sort_keys, sort_dirs = utils.process_sort_params([], []) + self.assertEqual(['created_at', 'id'], sort_keys) + self.assertEqual(['asc', 'asc'], sort_dirs) + + sort_keys, sort_dirs = utils.process_sort_params(None, None) + self.assertEqual(['created_at', 'id'], sort_keys) + self.assertEqual(['asc', 'asc'], sort_dirs) + + def test_process_sort_params_override_default_keys(self): + """Verifies that the default keys can be overridden.""" + sort_keys, sort_dirs = utils.process_sort_params( + [], [], default_keys=['key1', 'key2', 'key3']) + self.assertEqual(['key1', 'key2', 'key3'], sort_keys) + self.assertEqual(['asc', 'asc', 'asc'], sort_dirs) + + def test_process_sort_params_override_default_dir(self): + """Verifies that the default direction can be overridden.""" + sort_keys, sort_dirs = utils.process_sort_params( + [], [], default_dir='dir1') + self.assertEqual(['created_at', 'id'], sort_keys) + self.assertEqual(['dir1', 'dir1'], sort_dirs) + + def test_process_sort_params_override_default_key_and_dir(self): + """Verifies that the default key and dir can be overridden.""" + sort_keys, sort_dirs = utils.process_sort_params( + [], [], default_keys=['key1', 'key2', 'key3'], + default_dir='dir1') + self.assertEqual(['key1', 'key2', 'key3'], sort_keys) + self.assertEqual(['dir1', 'dir1', 'dir1'], sort_dirs) + + sort_keys, sort_dirs = utils.process_sort_params( + [], [], default_keys=[], default_dir='dir1') + self.assertEqual([], sort_keys) + self.assertEqual([], sort_dirs) + + def test_process_sort_params_non_default(self): + """Verifies that non-default keys are added correctly.""" + sort_keys, sort_dirs = utils.process_sort_params( + ['key1', 'key2'], ['asc', 'desc']) + self.assertEqual(['key1', 'key2', 'created_at', 'id'], sort_keys) + # First sort_dir in list is used when adding the default keys + self.assertEqual(['asc', 'desc', 'asc', 'asc'], sort_dirs) + + def test_process_sort_params_default(self): + """Verifies that default keys are added correctly.""" + sort_keys, sort_dirs = utils.process_sort_params( + ['id', 'key2'], ['asc', 'desc']) + self.assertEqual(['id', 'key2', 'created_at'], sort_keys) + self.assertEqual(['asc', 'desc', 'asc'], sort_dirs) + + # Include default key value, rely on default direction + sort_keys, sort_dirs = utils.process_sort_params( + ['id', 'key2'], []) + self.assertEqual(['id', 'key2', 'created_at'], sort_keys) + self.assertEqual(['asc', 'asc', 'asc'], sort_dirs) + + def test_process_sort_params_default_dir(self): + """Verifies that the default dir is applied to all keys.""" + # Direction is set, ignore default dir + sort_keys, sort_dirs = utils.process_sort_params( + ['id', 'key2'], ['desc'], default_dir='dir') + self.assertEqual(['id', 'key2', 'created_at'], sort_keys) + self.assertEqual(['desc', 'desc', 'desc'], sort_dirs) + + # But should be used if no direction is set + sort_keys, sort_dirs = utils.process_sort_params( + ['id', 'key2'], [], default_dir='dir') + self.assertEqual(['id', 'key2', 'created_at'], sort_keys) + self.assertEqual(['dir', 'dir', 'dir'], sort_dirs) + + def test_process_sort_params_unequal_length(self): + """Verifies that a sort direction list is applied correctly.""" + sort_keys, sort_dirs = utils.process_sort_params( + ['id', 'key2', 'key3'], ['desc']) + self.assertEqual(['id', 'key2', 'key3', 'created_at'], sort_keys) + self.assertEqual(['desc', 'desc', 'desc', 'desc'], sort_dirs) + + # Default direction is the first key in the list + sort_keys, sort_dirs = utils.process_sort_params( + ['id', 'key2', 'key3'], ['desc', 'asc']) + self.assertEqual(['id', 'key2', 'key3', 'created_at'], sort_keys) + self.assertEqual(['desc', 'asc', 'desc', 'desc'], sort_dirs) + + sort_keys, sort_dirs = utils.process_sort_params( + ['id', 'key2', 'key3'], ['desc', 'asc', 'asc']) + self.assertEqual(['id', 'key2', 'key3', 'created_at'], sort_keys) + self.assertEqual(['desc', 'asc', 'asc', 'desc'], sort_dirs) + + def test_process_sort_params_extra_dirs_lengths(self): + """InvalidInput raised if more directions are given.""" + self.assertRaises( + exception.InvalidInput, + utils.process_sort_params, + ['key1', 'key2'], + ['asc', 'desc', 'desc']) + + def test_process_sort_params_invalid_sort_dir(self): + """InvalidInput raised if invalid directions are given.""" + for dirs in [['foo'], ['asc', 'foo'], ['asc', 'desc', 'foo']]: + self.assertRaises( + exception.InvalidInput, + utils.process_sort_params, ['key'], dirs) diff --git a/nova/tests/unit/objects/test_flavor.py b/nova/tests/unit/objects/test_flavor.py index 92adbc6df904..3f61a8fafbc0 100644 --- a/nova/tests/unit/objects/test_flavor.py +++ b/nova/tests/unit/objects/test_flavor.py @@ -19,8 +19,8 @@ from oslo_db import exception as db_exc from oslo_utils import uuidutils from nova import context as nova_context +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db_api from nova import exception from nova import objects from nova.objects import fields @@ -89,7 +89,7 @@ class _TestFlavor(object): mock_get.assert_called_once_with(self.context, 'm1.foo') @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _create_api_flavor(context, altid=None): fake_db_flavor = dict(fake_flavor) del fake_db_flavor['extra_specs'] diff --git a/nova/tests/unit/test_conf.py b/nova/tests/unit/test_conf.py index 722f32d4c186..95a7c45114cc 100644 --- a/nova/tests/unit/test_conf.py +++ b/nova/tests/unit/test_conf.py @@ -87,19 +87,24 @@ class TestParseArgs(test.NoDBTestCase): def setUp(self): super(TestParseArgs, self).setUp() m = mock.patch('nova.db.main.api.configure') - self.nova_db_config_mock = m.start() - self.addCleanup(self.nova_db_config_mock.stop) + self.main_db_config_mock = m.start() + self.addCleanup(self.main_db_config_mock.stop) + m = mock.patch('nova.db.api.api.configure') + self.api_db_config_mock = m.start() + self.addCleanup(self.api_db_config_mock.stop) @mock.patch.object(config.log, 'register_options') def test_parse_args_glance_debug_false(self, register_options): self.flags(debug=False, group='glance') config.parse_args([], configure_db=False, init_rpc=False) self.assertIn('glanceclient=WARN', config.CONF.default_log_levels) - self.nova_db_config_mock.assert_not_called() + self.main_db_config_mock.assert_not_called() + self.api_db_config_mock.assert_not_called() @mock.patch.object(config.log, 'register_options') def test_parse_args_glance_debug_true(self, register_options): self.flags(debug=True, group='glance') config.parse_args([], configure_db=True, init_rpc=False) self.assertIn('glanceclient=DEBUG', config.CONF.default_log_levels) - self.nova_db_config_mock.assert_called_once_with(config.CONF) + self.main_db_config_mock.assert_called_once_with(config.CONF) + self.api_db_config_mock.assert_called_once_with(config.CONF) diff --git a/nova/tests/unit/test_fixtures.py b/nova/tests/unit/test_fixtures.py index 48913d05ad69..212b8107394d 100644 --- a/nova/tests/unit/test_fixtures.py +++ b/nova/tests/unit/test_fixtures.py @@ -34,7 +34,8 @@ import testtools from nova.compute import rpcapi as compute_rpcapi from nova import conductor from nova import context -from nova.db.main import api as session +from nova.db.api import api as api_db_api +from nova.db.main import api as main_db_api from nova import exception from nova.network import neutron as neutron_api from nova import objects @@ -121,7 +122,7 @@ class TestDatabaseFixture(testtools.TestCase): # because this sets up reasonable db connection strings self.useFixture(fixtures.ConfFixture()) self.useFixture(fixtures.Database()) - engine = session.get_engine() + engine = main_db_api.get_engine() conn = engine.connect() result = conn.execute("select * from instance_types") rows = result.fetchall() @@ -152,7 +153,7 @@ class TestDatabaseFixture(testtools.TestCase): # This sets up reasonable db connection strings self.useFixture(fixtures.ConfFixture()) self.useFixture(fixtures.Database(database='api')) - engine = session.get_api_engine() + engine = api_db_api.get_engine() conn = engine.connect() result = conn.execute("select * from cell_mappings") rows = result.fetchall() @@ -186,7 +187,7 @@ class TestDatabaseFixture(testtools.TestCase): fix.cleanup() # ensure the db contains nothing - engine = session.get_engine() + engine = main_db_api.get_engine() conn = engine.connect() schema = "".join(line for line in conn.connection.iterdump()) self.assertEqual(schema, "BEGIN TRANSACTION;COMMIT;") @@ -198,7 +199,7 @@ class TestDatabaseFixture(testtools.TestCase): self.useFixture(fix) # No data inserted by migrations so we need to add a row - engine = session.get_api_engine() + engine = api_db_api.get_engine() conn = engine.connect() uuid = uuidutils.generate_uuid() conn.execute("insert into cell_mappings (uuid, name) VALUES " @@ -211,7 +212,7 @@ class TestDatabaseFixture(testtools.TestCase): fix.cleanup() # Ensure the db contains nothing - engine = session.get_api_engine() + engine = api_db_api.get_engine() conn = engine.connect() schema = "".join(line for line in conn.connection.iterdump()) self.assertEqual("BEGIN TRANSACTION;COMMIT;", schema) @@ -224,7 +225,7 @@ class TestDefaultFlavorsFixture(testtools.TestCase): self.useFixture(fixtures.Database()) self.useFixture(fixtures.Database(database='api')) - engine = session.get_api_engine() + engine = api_db_api.get_engine() conn = engine.connect() result = conn.execute("select * from flavors") rows = result.fetchall()