diff --git a/nova/cmd/status.py b/nova/cmd/status.py index 8e8f8b0dcdd8..2f14e3fed2ff 100644 --- a/nova/cmd/status.py +++ b/nova/cmd/status.py @@ -32,7 +32,8 @@ from nova.cmd import common as cmd_common import nova.conf from nova import config from nova import context as nova_context -from nova.db.main import api as db_session +from nova.db.api import api as api_db_api +from nova.db.main import api as main_db_api from nova import exception from nova.i18n import _ from nova.objects import cell_mapping as cell_mapping_obj @@ -86,7 +87,7 @@ class UpgradeCommands(upgradecheck.UpgradeCommands): # table, or by only counting compute nodes with a service version of at # least 15 which was the highest service version when Newton was # released. - meta = sa.MetaData(bind=db_session.get_engine(context=context)) + meta = sa.MetaData(bind=main_db_api.get_engine(context=context)) compute_nodes = sa.Table('compute_nodes', meta, autoload=True) return sa.select([sqlfunc.count()]).select_from(compute_nodes).where( compute_nodes.c.deleted == 0).scalar() @@ -103,7 +104,7 @@ class UpgradeCommands(upgradecheck.UpgradeCommands): for compute nodes if there are no host mappings on a fresh install. """ meta = sa.MetaData() - meta.bind = db_session.get_api_engine() + meta.bind = api_db_api.get_engine() cell_mappings = self._get_cell_mappings() count = len(cell_mappings) diff --git a/nova/compute/api.py b/nova/compute/api.py index 909cd52962c8..bc75399af70c 100644 --- a/nova/compute/api.py +++ b/nova/compute/api.py @@ -53,7 +53,8 @@ from nova import conductor import nova.conf from nova import context as nova_context from nova import crypto -from nova.db.main import api as db_api +from nova.db.api import api as api_db_api +from nova.db.main import api as main_db_api from nova import exception from nova import exception_wrapper from nova.i18n import _ @@ -1081,7 +1082,7 @@ class API: network_metadata) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _create_reqspec_buildreq_instmapping(context, rs, br, im): """Create the request spec, build request, and instance mapping in a single database transaction. @@ -5082,7 +5083,7 @@ class API: def get_instance_metadata(self, context, instance): """Get all metadata associated with an instance.""" - return db_api.instance_metadata_get(context, instance.uuid) + return main_db_api.instance_metadata_get(context, instance.uuid) @check_instance_lock @check_instance_state(vm_state=[vm_states.ACTIVE, vm_states.PAUSED, @@ -5962,7 +5963,7 @@ class HostAPI: """Return the task logs within a given range, optionally filtering by host and/or state. """ - return db_api.task_log_get_all( + return main_db_api.task_log_get_all( context, task_name, period_beginning, period_ending, host=host, state=state) @@ -6055,7 +6056,7 @@ class HostAPI: if cell.uuid == objects.CellMapping.CELL0_UUID: continue with nova_context.target_cell(context, cell) as cctxt: - cell_stats.append(db_api.compute_node_statistics(cctxt)) + cell_stats.append(main_db_api.compute_node_statistics(cctxt)) if cell_stats: keys = cell_stats[0].keys() diff --git a/nova/config.py b/nova/config.py index 1af845897646..98aad0c73f18 100644 --- a/nova/config.py +++ b/nova/config.py @@ -22,7 +22,8 @@ from oslo_policy import opts as policy_opts from oslo_utils import importutils import nova.conf -from nova.db.main import api as db_api +from nova.db.api import api as api_db_api +from nova.db.main import api as main_db_api from nova import middleware from nova import policy from nova import rpc @@ -100,4 +101,5 @@ def parse_args(argv, default_config_files=None, configure_db=True, rpc.init(CONF) if configure_db: - db_api.configure(CONF) + main_db_api.configure(CONF) + api_db_api.configure(CONF) diff --git a/nova/db/api/api.py b/nova/db/api/api.py new file mode 100644 index 000000000000..713562e4338e --- /dev/null +++ b/nova/db/api/api.py @@ -0,0 +1,50 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo_db.sqlalchemy import enginefacade +from oslo_utils import importutils +import sqlalchemy as sa + +import nova.conf + +profiler_sqlalchemy = importutils.try_import('osprofiler.sqlalchemy') + +CONF = nova.conf.CONF + +context_manager = enginefacade.transaction_context() + +# NOTE(stephenfin): We don't need equivalents of the 'get_context_manager' or +# 'create_context_manager' APIs found in 'nova.db.main.api' since we don't need +# to be cell-aware here + + +def _get_db_conf(conf_group, connection=None): + kw = dict(conf_group.items()) + if connection is not None: + kw['connection'] = connection + return kw + + +def configure(conf): + context_manager.configure(**_get_db_conf(conf.api_database)) + + if ( + profiler_sqlalchemy and + CONF.profiler.enabled and + CONF.profiler.trace_sqlalchemy + ): + context_manager.append_on_engine_create( + lambda eng: profiler_sqlalchemy.add_tracing(sa, eng, "db")) + + +def get_engine(): + return context_manager.writer.get_engine() diff --git a/nova/db/main/api.py b/nova/db/main/api.py index b19d761b5cd4..83e390bf6d37 100644 --- a/nova/db/main/api.py +++ b/nova/db/main/api.py @@ -22,7 +22,6 @@ import copy import datetime import functools import inspect -import sys import traceback from oslo_db import api as oslo_db_api @@ -49,6 +48,8 @@ from nova.compute import vm_states import nova.conf import nova.context from nova.db.main import models +from nova.db import utils as db_utils +from nova.db.utils import require_context from nova import exception from nova.i18n import _ from nova import safe_utils @@ -60,8 +61,7 @@ LOG = logging.getLogger(__name__) DISABLE_DB_ACCESS = False -main_context_manager = enginefacade.transaction_context() -api_context_manager = enginefacade.transaction_context() +context_manager = enginefacade.transaction_context() def _get_db_conf(conf_group, connection=None): @@ -89,15 +89,14 @@ def _joinedload_all(column): def configure(conf): - main_context_manager.configure(**_get_db_conf(conf.database)) - api_context_manager.configure(**_get_db_conf(conf.api_database)) + context_manager.configure(**_get_db_conf(conf.database)) - if profiler_sqlalchemy and CONF.profiler.enabled \ - and CONF.profiler.trace_sqlalchemy: - - main_context_manager.append_on_engine_create( - lambda eng: profiler_sqlalchemy.add_tracing(sa, eng, "db")) - api_context_manager.append_on_engine_create( + if ( + profiler_sqlalchemy and + CONF.profiler.enabled and + CONF.profiler.trace_sqlalchemy + ): + context_manager.append_on_engine_create( lambda eng: profiler_sqlalchemy.add_tracing(sa, eng, "db")) @@ -116,7 +115,7 @@ def get_context_manager(context): :param context: The request context that can contain a context manager """ - return _context_manager_from_context(context) or main_context_manager + return _context_manager_from_context(context) or context_manager def get_engine(use_slave=False, context=None): @@ -131,40 +130,11 @@ def get_engine(use_slave=False, context=None): return ctxt_mgr.writer.get_engine() -def get_api_engine(): - return api_context_manager.writer.get_engine() - - _SHADOW_TABLE_PREFIX = 'shadow_' _DEFAULT_QUOTA_NAME = 'default' PER_PROJECT_QUOTAS = ['fixed_ips', 'floating_ips', 'networks'] -# NOTE(stephenfin): This is required and used by oslo.db -def get_backend(): - """The backend is this module itself.""" - return sys.modules[__name__] - - -def require_context(f): - """Decorator to require *any* user or admin context. - - This does no authorization for user or project access matching, see - :py:func:`nova.context.authorize_project_context` and - :py:func:`nova.context.authorize_user_context`. - - The first argument to the wrapped function must be the context. - - """ - - @functools.wraps(f) - def wrapper(*args, **kwargs): - nova.context.require_context(args[0]) - return f(*args, **kwargs) - wrapper.__signature__ = inspect.signature(f) - return wrapper - - def select_db_reader_mode(f): """Decorator to select synchronous or asynchronous reader mode. @@ -1662,9 +1632,8 @@ def instance_get_all_by_filters_sort(context, filters, limit=None, marker=None, if limit == 0: return [] - sort_keys, sort_dirs = process_sort_params(sort_keys, - sort_dirs, - default_dir='desc') + sort_keys, sort_dirs = db_utils.process_sort_params( + sort_keys, sort_dirs, default_dir='desc') if columns_to_join is None: columns_to_join_new = ['info_cache', 'security_groups'] @@ -2043,75 +2012,6 @@ def _exact_instance_filter(query, filters, legal_keys): return query -def process_sort_params(sort_keys, sort_dirs, - default_keys=['created_at', 'id'], - default_dir='asc'): - """Process the sort parameters to include default keys. - - Creates a list of sort keys and a list of sort directions. Adds the default - keys to the end of the list if they are not already included. - - When adding the default keys to the sort keys list, the associated - direction is: - 1) The first element in the 'sort_dirs' list (if specified), else - 2) 'default_dir' value (Note that 'asc' is the default value since this is - the default in sqlalchemy.utils.paginate_query) - - :param sort_keys: List of sort keys to include in the processed list - :param sort_dirs: List of sort directions to include in the processed list - :param default_keys: List of sort keys that need to be included in the - processed list, they are added at the end of the list - if not already specified. - :param default_dir: Sort direction associated with each of the default - keys that are not supplied, used when they are added - to the processed list - :returns: list of sort keys, list of sort directions - :raise exception.InvalidInput: If more sort directions than sort keys - are specified or if an invalid sort - direction is specified - """ - # Determine direction to use for when adding default keys - if sort_dirs and len(sort_dirs) != 0: - default_dir_value = sort_dirs[0] - else: - default_dir_value = default_dir - - # Create list of keys (do not modify the input list) - if sort_keys: - result_keys = list(sort_keys) - else: - result_keys = [] - - # If a list of directions is not provided, use the default sort direction - # for all provided keys - if sort_dirs: - result_dirs = [] - # Verify sort direction - for sort_dir in sort_dirs: - if sort_dir not in ('asc', 'desc'): - msg = _("Unknown sort direction, must be 'desc' or 'asc'") - raise exception.InvalidInput(reason=msg) - result_dirs.append(sort_dir) - else: - result_dirs = [default_dir_value for _sort_key in result_keys] - - # Ensure that the key and direction length match - while len(result_dirs) < len(result_keys): - result_dirs.append(default_dir_value) - # Unless more direction are specified, which is an error - if len(result_dirs) > len(result_keys): - msg = _("Sort direction size exceeds sort key size") - raise exception.InvalidInput(reason=msg) - - # Ensure defaults are included - for key in default_keys: - if key not in result_keys: - result_keys.append(key) - result_dirs.append(default_dir_value) - - return result_keys, result_dirs - - @require_context @pick_context_manager_reader_allow_async def instance_get_active_by_window_joined(context, begin, end=None, @@ -3507,8 +3407,8 @@ def migration_get_all_by_filters(context, filters, raise exception.MarkerNotFound(marker=marker) if limit or marker or sort_keys or sort_dirs: # Default sort by desc(['created_at', 'id']) - sort_keys, sort_dirs = process_sort_params(sort_keys, sort_dirs, - default_dir='desc') + sort_keys, sort_dirs = db_utils.process_sort_params( + sort_keys, sort_dirs, default_dir='desc') return sqlalchemyutils.paginate_query(query, models.Migration, limit=limit, diff --git a/nova/db/migration.py b/nova/db/migration.py index a1e49a216025..d5ae303a3ee3 100644 --- a/nova/db/migration.py +++ b/nova/db/migration.py @@ -22,7 +22,8 @@ from migrate.versioning.repository import Repository from oslo_log import log as logging import sqlalchemy -from nova.db.main import api as db_session +from nova.db.api import api as api_db_api +from nova.db.main import api as main_db_api from nova import exception from nova.i18n import _ @@ -36,10 +37,10 @@ LOG = logging.getLogger(__name__) def get_engine(database='main', context=None): if database == 'main': - return db_session.get_engine(context=context) + return main_db_api.get_engine(context=context) if database == 'api': - return db_session.get_api_engine() + return api_db_api.get_engine() def find_migrate_repo(database='main'): diff --git a/nova/db/utils.py b/nova/db/utils.py new file mode 100644 index 000000000000..234845a359be --- /dev/null +++ b/nova/db/utils.py @@ -0,0 +1,109 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import functools +import inspect + +import nova.context +from nova import exception +from nova.i18n import _ + + +def require_context(f): + """Decorator to require *any* user or admin context. + + This does no authorization for user or project access matching, see + :py:func:`nova.context.authorize_project_context` and + :py:func:`nova.context.authorize_user_context`. + + The first argument to the wrapped function must be the context. + + """ + + @functools.wraps(f) + def wrapper(*args, **kwargs): + nova.context.require_context(args[0]) + return f(*args, **kwargs) + wrapper.__signature__ = inspect.signature(f) + return wrapper + + +def process_sort_params( + sort_keys, + sort_dirs, + default_keys=['created_at', 'id'], + default_dir='asc', +): + """Process the sort parameters to include default keys. + + Creates a list of sort keys and a list of sort directions. Adds the default + keys to the end of the list if they are not already included. + + When adding the default keys to the sort keys list, the associated + direction is: + + 1. The first element in the 'sort_dirs' list (if specified), else + 2. 'default_dir' value (Note that 'asc' is the default value since this is + the default in sqlalchemy.utils.paginate_query) + + :param sort_keys: List of sort keys to include in the processed list + :param sort_dirs: List of sort directions to include in the processed list + :param default_keys: List of sort keys that need to be included in the + processed list, they are added at the end of the list if not already + specified. + :param default_dir: Sort direction associated with each of the default + keys that are not supplied, used when they are added to the processed + list + :returns: list of sort keys, list of sort directions + :raise exception.InvalidInput: If more sort directions than sort keys + are specified or if an invalid sort direction is specified + """ + # Determine direction to use for when adding default keys + if sort_dirs and len(sort_dirs) != 0: + default_dir_value = sort_dirs[0] + else: + default_dir_value = default_dir + + # Create list of keys (do not modify the input list) + if sort_keys: + result_keys = list(sort_keys) + else: + result_keys = [] + + # If a list of directions is not provided, use the default sort direction + # for all provided keys + if sort_dirs: + result_dirs = [] + # Verify sort direction + for sort_dir in sort_dirs: + if sort_dir not in ('asc', 'desc'): + msg = _("Unknown sort direction, must be 'desc' or 'asc'") + raise exception.InvalidInput(reason=msg) + result_dirs.append(sort_dir) + else: + result_dirs = [default_dir_value for _sort_key in result_keys] + + # Ensure that the key and direction length match + while len(result_dirs) < len(result_keys): + result_dirs.append(default_dir_value) + # Unless more direction are specified, which is an error + if len(result_dirs) > len(result_keys): + msg = _("Sort direction size exceeds sort key size") + raise exception.InvalidInput(reason=msg) + + # Ensure defaults are included + for key in default_keys: + if key not in result_keys: + result_keys.append(key) + result_dirs.append(default_dir_value) + + return result_keys, result_dirs diff --git a/nova/objects/aggregate.py b/nova/objects/aggregate.py index 5880165cf355..2aa802cf9b8f 100644 --- a/nova/objects/aggregate.py +++ b/nova/objects/aggregate.py @@ -19,8 +19,8 @@ from oslo_utils import uuidutils from sqlalchemy import orm from nova.compute import utils as compute_utils +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db_api from nova import exception from nova.i18n import _ from nova import objects @@ -32,7 +32,7 @@ LOG = logging.getLogger(__name__) DEPRECATED_FIELDS = ['deleted', 'deleted_at'] -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _aggregate_get_from_db(context, aggregate_id): query = context.session.query(api_models.Aggregate).\ options(orm.joinedload('_hosts')).\ @@ -47,7 +47,7 @@ def _aggregate_get_from_db(context, aggregate_id): return aggregate -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _aggregate_get_from_db_by_uuid(context, aggregate_uuid): query = context.session.query(api_models.Aggregate).\ options(orm.joinedload('_hosts')).\ @@ -64,7 +64,7 @@ def _aggregate_get_from_db_by_uuid(context, aggregate_uuid): def _host_add_to_db(context, aggregate_id, host): try: - with db_api.api_context_manager.writer.using(context): + with api_db_api.context_manager.writer.using(context): # Check to see if the aggregate exists _aggregate_get_from_db(context, aggregate_id) @@ -79,7 +79,7 @@ def _host_add_to_db(context, aggregate_id, host): def _host_delete_from_db(context, aggregate_id, host): count = 0 - with db_api.api_context_manager.writer.using(context): + with api_db_api.context_manager.writer.using(context): # Check to see if the aggregate exists _aggregate_get_from_db(context, aggregate_id) @@ -98,7 +98,7 @@ def _metadata_add_to_db(context, aggregate_id, metadata, max_retries=10, all_keys = metadata.keys() for attempt in range(max_retries): try: - with db_api.api_context_manager.writer.using(context): + with api_db_api.context_manager.writer.using(context): query = context.session.query(api_models.AggregateMetadata).\ filter_by(aggregate_id=aggregate_id) @@ -142,7 +142,7 @@ def _metadata_add_to_db(context, aggregate_id, metadata, max_retries=10, LOG.warning(msg) -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _metadata_delete_from_db(context, aggregate_id, key): # Check to see if the aggregate exists _aggregate_get_from_db(context, aggregate_id) @@ -157,7 +157,7 @@ def _metadata_delete_from_db(context, aggregate_id, key): aggregate_id=aggregate_id, metadata_key=key) -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _aggregate_create_in_db(context, values, metadata=None): query = context.session.query(api_models.Aggregate) query = query.filter(api_models.Aggregate.name == values['name']) @@ -181,7 +181,7 @@ def _aggregate_create_in_db(context, values, metadata=None): return aggregate -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _aggregate_delete_from_db(context, aggregate_id): # Delete Metadata first context.session.query(api_models.AggregateMetadata).\ @@ -196,7 +196,7 @@ def _aggregate_delete_from_db(context, aggregate_id): raise exception.AggregateNotFound(aggregate_id=aggregate_id) -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _aggregate_update_to_db(context, aggregate_id, values): aggregate = _aggregate_get_from_db(context, aggregate_id) @@ -411,7 +411,7 @@ class Aggregate(base.NovaPersistentObject, base.NovaObject): return self.metadata.get('availability_zone', None) -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _get_all_from_db(context): query = context.session.query(api_models.Aggregate).\ options(orm.joinedload('_hosts')).\ @@ -420,7 +420,7 @@ def _get_all_from_db(context): return query.all() -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _get_by_host_from_db(context, host, key=None): query = context.session.query(api_models.Aggregate).\ options(orm.joinedload('_hosts')).\ @@ -435,7 +435,7 @@ def _get_by_host_from_db(context, host, key=None): return query.all() -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _get_by_metadata_from_db(context, key=None, value=None): assert(key is not None or value is not None) query = context.session.query(api_models.Aggregate) @@ -450,7 +450,7 @@ def _get_by_metadata_from_db(context, key=None, value=None): return query.all() -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _get_non_matching_by_metadata_keys_from_db(context, ignored_keys, key_prefix, value): """Filter aggregates based on non matching metadata. diff --git a/nova/objects/build_request.py b/nova/objects/build_request.py index 67f271cb94d6..11ca05cbc70a 100644 --- a/nova/objects/build_request.py +++ b/nova/objects/build_request.py @@ -18,8 +18,9 @@ from oslo_serialization import jsonutils from oslo_utils import versionutils from oslo_versionedobjects import exception as ovoo_exc +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db +from nova.db import utils as db_utils from nova import exception from nova import objects from nova.objects import base @@ -163,7 +164,7 @@ class BuildRequest(base.NovaObject): return req @staticmethod - @db.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_instance_uuid_from_db(context, instance_uuid): db_req = context.session.query(api_models.BuildRequest).filter_by( instance_uuid=instance_uuid).first() @@ -177,7 +178,7 @@ class BuildRequest(base.NovaObject): return cls._from_db_object(context, cls(), db_req) @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _create_in_db(context, updates): db_req = api_models.BuildRequest() db_req.update(updates) @@ -206,7 +207,7 @@ class BuildRequest(base.NovaObject): self._from_db_object(self._context, self, db_req) @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_in_db(context, instance_uuid): result = context.session.query(api_models.BuildRequest).filter_by( instance_uuid=instance_uuid).delete() @@ -217,7 +218,7 @@ class BuildRequest(base.NovaObject): def destroy(self): self._destroy_in_db(self._context, self.instance_uuid) - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _save_in_db(self, context, req_id, updates): db_req = context.session.query( api_models.BuildRequest).filter_by(id=req_id).first() @@ -262,7 +263,7 @@ class BuildRequestList(base.ObjectListBase, base.NovaObject): } @staticmethod - @db.api_context_manager.reader + @api_db_api.context_manager.reader def _get_all_from_db(context): query = context.session.query(api_models.BuildRequest) @@ -396,8 +397,8 @@ class BuildRequestList(base.ObjectListBase, base.NovaObject): # exists. So it can be ignored. # 'deleted' and 'cleaned' are handled above. - sort_keys, sort_dirs = db.process_sort_params(sort_keys, sort_dirs, - default_dir='desc') + sort_keys, sort_dirs = db_utils.process_sort_params( + sort_keys, sort_dirs, default_dir='desc') # For other filters that don't match this, we will do regexp matching # Taken from db/sqlalchemy/api.py diff --git a/nova/objects/cell_mapping.py b/nova/objects/cell_mapping.py index f19f2f190842..595ec43e480e 100644 --- a/nova/objects/cell_mapping.py +++ b/nova/objects/cell_mapping.py @@ -18,8 +18,8 @@ from sqlalchemy import sql from sqlalchemy.sql import expression import nova.conf -from nova.db.api import models as api_models -from nova.db.main import api as db_api +from nova.db.api import api as api_db_api +from nova.db.api import models as api_db_models from nova import exception from nova.objects import base from nova.objects import fields @@ -168,11 +168,11 @@ class CellMapping(base.NovaTimestampObject, base.NovaObject): return cell_mapping @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_uuid_from_db(context, uuid): - db_mapping = context.session.query(api_models.CellMapping).filter_by( - uuid=uuid).first() + db_mapping = context.session\ + .query(api_db_models.CellMapping).filter_by(uuid=uuid).first() if not db_mapping: raise exception.CellMappingNotFound(uuid=uuid) @@ -185,10 +185,10 @@ class CellMapping(base.NovaTimestampObject, base.NovaObject): return cls._from_db_object(context, cls(), db_mapping) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _create_in_db(context, updates): - db_mapping = api_models.CellMapping() + db_mapping = api_db_models.CellMapping() db_mapping.update(updates) db_mapping.save(context.session) return db_mapping @@ -199,11 +199,11 @@ class CellMapping(base.NovaTimestampObject, base.NovaObject): self._from_db_object(self._context, self, db_mapping) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _save_in_db(context, uuid, updates): db_mapping = context.session.query( - api_models.CellMapping).filter_by(uuid=uuid).first() + api_db_models.CellMapping).filter_by(uuid=uuid).first() if not db_mapping: raise exception.CellMappingNotFound(uuid=uuid) @@ -219,10 +219,10 @@ class CellMapping(base.NovaTimestampObject, base.NovaObject): self.obj_reset_changes() @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_in_db(context, uuid): - result = context.session.query(api_models.CellMapping).filter_by( + result = context.session.query(api_db_models.CellMapping).filter_by( uuid=uuid).delete() if not result: raise exception.CellMappingNotFound(uuid=uuid) @@ -246,10 +246,10 @@ class CellMappingList(base.ObjectListBase, base.NovaObject): } @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_all_from_db(context): - return context.session.query(api_models.CellMapping).order_by( - expression.asc(api_models.CellMapping.id)).all() + return context.session.query(api_db_models.CellMapping).order_by( + expression.asc(api_db_models.CellMapping.id)).all() @base.remotable_classmethod def get_all(cls, context): @@ -257,16 +257,16 @@ class CellMappingList(base.ObjectListBase, base.NovaObject): return base.obj_make_list(context, cls(), CellMapping, db_mappings) @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_disabled_from_db(context, disabled): if disabled: - return context.session.query(api_models.CellMapping)\ + return context.session.query(api_db_models.CellMapping)\ .filter_by(disabled=sql.true())\ - .order_by(expression.asc(api_models.CellMapping.id)).all() + .order_by(expression.asc(api_db_models.CellMapping.id)).all() else: - return context.session.query(api_models.CellMapping)\ + return context.session.query(api_db_models.CellMapping)\ .filter_by(disabled=sql.false())\ - .order_by(expression.asc(api_models.CellMapping.id)).all() + .order_by(expression.asc(api_db_models.CellMapping.id)).all() @base.remotable_classmethod def get_by_disabled(cls, context, disabled): @@ -274,16 +274,16 @@ class CellMappingList(base.ObjectListBase, base.NovaObject): return base.obj_make_list(context, cls(), CellMapping, db_mappings) @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_project_id_from_db(context, project_id): # SELECT DISTINCT cell_id FROM instance_mappings \ # WHERE project_id = $project_id; cell_ids = context.session.query( - api_models.InstanceMapping.cell_id).filter_by( + api_db_models.InstanceMapping.cell_id).filter_by( project_id=project_id).distinct().subquery() # SELECT cell_mappings WHERE cell_id IN ($cell_ids); - return context.session.query(api_models.CellMapping).filter( - api_models.CellMapping.id.in_(cell_ids)).all() + return context.session.query(api_db_models.CellMapping).filter( + api_db_models.CellMapping.id.in_(cell_ids)).all() @classmethod def get_by_project_id(cls, context, project_id): diff --git a/nova/objects/flavor.py b/nova/objects/flavor.py index ae27b6389288..20378b0dae53 100644 --- a/nova/objects/flavor.py +++ b/nova/objects/flavor.py @@ -21,8 +21,9 @@ from sqlalchemy import sql from sqlalchemy.sql import expression import nova.conf +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db_api +from nova.db import utils as db_utils from nova import exception from nova.notifications.objects import base as notification from nova.notifications.objects import flavor as flavor_notification @@ -51,7 +52,7 @@ def _dict_with_extra_specs(flavor_model): # decorators with static methods. We pull these out for now and can # move them back into the actual staticmethods on the object when those # issues are resolved. -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _get_projects_from_db(context, flavorid): db_flavor = context.session.query(api_models.Flavors).\ filter_by(flavorid=flavorid).\ @@ -62,7 +63,7 @@ def _get_projects_from_db(context, flavorid): return [x['project_id'] for x in db_flavor['projects']] -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _flavor_add_project(context, flavor_id, project_id): project = api_models.FlavorProjects() project.update({'flavor_id': flavor_id, @@ -74,7 +75,7 @@ def _flavor_add_project(context, flavor_id, project_id): project_id=project_id) -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _flavor_del_project(context, flavor_id, project_id): result = context.session.query(api_models.FlavorProjects).\ filter_by(project_id=project_id).\ @@ -85,9 +86,9 @@ def _flavor_del_project(context, flavor_id, project_id): project_id=project_id) -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _flavor_extra_specs_add(context, flavor_id, specs, max_retries=10): - writer = db_api.api_context_manager.writer + writer = api_db_api.context_manager.writer for attempt in range(max_retries): try: spec_refs = context.session.query( @@ -122,7 +123,7 @@ def _flavor_extra_specs_add(context, flavor_id, specs, max_retries=10): id=flavor_id, retries=max_retries) -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _flavor_extra_specs_del(context, flavor_id, key): result = context.session.query(api_models.FlavorExtraSpecs).\ filter_by(flavor_id=flavor_id).\ @@ -133,7 +134,7 @@ def _flavor_extra_specs_del(context, flavor_id, key): extra_specs_key=key, flavor_id=flavor_id) -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _flavor_create(context, values): specs = values.get('extra_specs') db_specs = [] @@ -169,7 +170,7 @@ def _flavor_create(context, values): return _dict_with_extra_specs(db_flavor) -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _flavor_destroy(context, flavor_id=None, flavorid=None): query = context.session.query(api_models.Flavors) @@ -268,7 +269,7 @@ class Flavor(base.NovaPersistentObject, base.NovaObject, return flavor @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _flavor_get_query_from_db(context): query = context.session.query(api_models.Flavors).\ options(orm.joinedload('extra_specs')) @@ -281,7 +282,7 @@ class Flavor(base.NovaPersistentObject, base.NovaObject, return query @staticmethod - @db_api.require_context + @db_utils.require_context def _flavor_get_from_db(context, id): """Returns a dict describing specific flavor.""" result = Flavor._flavor_get_query_from_db(context).\ @@ -292,7 +293,7 @@ class Flavor(base.NovaPersistentObject, base.NovaObject, return _dict_with_extra_specs(result) @staticmethod - @db_api.require_context + @db_utils.require_context def _flavor_get_by_name_from_db(context, name): """Returns a dict describing specific flavor.""" result = Flavor._flavor_get_query_from_db(context).\ @@ -303,7 +304,7 @@ class Flavor(base.NovaPersistentObject, base.NovaObject, return _dict_with_extra_specs(result) @staticmethod - @db_api.require_context + @db_utils.require_context def _flavor_get_by_flavor_id_from_db(context, flavor_id): """Returns a dict describing specific flavor_id.""" result = Flavor._flavor_get_query_from_db(context).\ @@ -485,7 +486,7 @@ class Flavor(base.NovaPersistentObject, base.NovaObject, # NOTE(mriedem): This method is not remotable since we only expect the API # to be able to make updates to a flavor. - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _save(self, context, values): db_flavor = context.session.query(api_models.Flavors).\ filter_by(id=self.id).first() @@ -581,7 +582,7 @@ class Flavor(base.NovaPersistentObject, base.NovaObject, payload=payload).emit(self._context) -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _flavor_get_all_from_db(context, inactive, filters, sort_key, sort_dir, limit, marker): """Returns all flavors. diff --git a/nova/objects/host_mapping.py b/nova/objects/host_mapping.py index 98e36f3edfdf..09dfe81354b3 100644 --- a/nova/objects/host_mapping.py +++ b/nova/objects/host_mapping.py @@ -14,8 +14,8 @@ from oslo_db import exception as db_exc from sqlalchemy import orm from nova import context +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db_api from nova import exception from nova.i18n import _ from nova.objects import base @@ -52,7 +52,7 @@ class HostMapping(base.NovaTimestampObject, base.NovaObject): } def _get_cell_mapping(self): - with db_api.api_context_manager.reader.using(self._context) as session: + with api_db_api.context_manager.reader.using(self._context) as session: cell_map = (session.query(api_models.CellMapping) .join(api_models.HostMapping) .filter(api_models.HostMapping.host == self.host) @@ -87,7 +87,7 @@ class HostMapping(base.NovaTimestampObject, base.NovaObject): return host_mapping @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_host_from_db(context, host): db_mapping = context.session.query(api_models.HostMapping)\ .options(orm.joinedload('cell_mapping'))\ @@ -102,7 +102,7 @@ class HostMapping(base.NovaTimestampObject, base.NovaObject): return cls._from_db_object(context, cls(), db_mapping) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _create_in_db(context, updates): db_mapping = api_models.HostMapping() return _apply_updates(context, db_mapping, updates) @@ -116,7 +116,7 @@ class HostMapping(base.NovaTimestampObject, base.NovaObject): self._from_db_object(self._context, self, db_mapping) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _save_in_db(context, obj, updates): db_mapping = context.session.query(api_models.HostMapping).filter_by( id=obj.id).first() @@ -134,7 +134,7 @@ class HostMapping(base.NovaTimestampObject, base.NovaObject): self.obj_reset_changes() @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_in_db(context, host): result = context.session.query(api_models.HostMapping).filter_by( host=host).delete() @@ -157,7 +157,7 @@ class HostMappingList(base.ObjectListBase, base.NovaObject): } @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_from_db(context, cell_id=None): query = (context.session.query(api_models.HostMapping) .options(orm.joinedload('cell_mapping'))) diff --git a/nova/objects/instance_group.py b/nova/objects/instance_group.py index 26df07df92b9..34cf40b7fe4a 100644 --- a/nova/objects/instance_group.py +++ b/nova/objects/instance_group.py @@ -22,8 +22,8 @@ from oslo_utils import versionutils from sqlalchemy import orm from nova.compute import utils as compute_utils +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db_api from nova import exception from nova import objects from nova.objects import base @@ -213,7 +213,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject, return instance_group @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_from_db_by_uuid(context, uuid): grp = _instance_group_get_query(context, id_field=api_models.InstanceGroup.uuid, @@ -223,7 +223,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject, return grp @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_from_db_by_id(context, id): grp = _instance_group_get_query(context, id_field=api_models.InstanceGroup.id, @@ -233,7 +233,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject, return grp @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_from_db_by_name(context, name): grp = _instance_group_get_query(context).filter_by(name=name).first() if not grp: @@ -241,7 +241,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject, return grp @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_from_db_by_instance(context, instance_uuid): grp_member = context.session.query(api_models.InstanceGroupMember).\ filter_by(instance_uuid=instance_uuid).first() @@ -251,7 +251,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject, return grp @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _save_in_db(context, group_uuid, values): grp = InstanceGroup._get_from_db_by_uuid(context, group_uuid) values_copy = copy.copy(values) @@ -265,7 +265,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject, return grp @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _create_in_db(context, values, policies=None, members=None, policy=None, rules=None): try: @@ -301,7 +301,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject, return group @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_in_db(context, group_uuid): qry = _instance_group_get_query(context, id_field=api_models.InstanceGroup.uuid, @@ -319,13 +319,13 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject, qry.delete() @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _add_members_in_db(context, group_uuid, members): return _instance_group_members_add_by_uuid(context, group_uuid, members) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _remove_members_in_db(context, group_id, instance_uuids): # There is no public method provided for removing members because the # user-facing API doesn't allow removal of instance group members. We @@ -337,7 +337,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject, delete(synchronize_session=False) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_members_bulk_in_db(context, instance_uuids): return context.session.query(api_models.InstanceGroupMember).filter( api_models.InstanceGroupMember.instance_uuid.in_(instance_uuids)).\ @@ -537,7 +537,7 @@ class InstanceGroupList(base.ObjectListBase, base.NovaObject): } @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_from_db(context, project_id=None): query = _instance_group_get_query(context) if project_id is not None: @@ -545,7 +545,7 @@ class InstanceGroupList(base.ObjectListBase, base.NovaObject): return query.all() @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_counts_from_db(context, project_id, user_id=None): query = context.session.query(api_models.InstanceGroup.id).\ filter_by(project_id=project_id) diff --git a/nova/objects/instance_mapping.py b/nova/objects/instance_mapping.py index 43fa5333ba3b..68f45cd8cc3c 100644 --- a/nova/objects/instance_mapping.py +++ b/nova/objects/instance_mapping.py @@ -20,8 +20,8 @@ from sqlalchemy import sql from sqlalchemy.sql import func from nova import context as nova_context +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db_api from nova import exception from nova.i18n import _ from nova import objects @@ -96,7 +96,7 @@ class InstanceMapping(base.NovaTimestampObject, base.NovaObject): return instance_mapping @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_instance_uuid_from_db(context, instance_uuid): db_mapping = context.session.query(api_models.InstanceMapping)\ .options(orm.joinedload('cell_mapping'))\ @@ -113,7 +113,7 @@ class InstanceMapping(base.NovaTimestampObject, base.NovaObject): return cls._from_db_object(context, cls(), db_mapping) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _create_in_db(context, updates): db_mapping = api_models.InstanceMapping() db_mapping.update(updates) @@ -138,7 +138,7 @@ class InstanceMapping(base.NovaTimestampObject, base.NovaObject): self._from_db_object(self._context, self, db_mapping) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _save_in_db(context, instance_uuid, updates): db_mapping = context.session.query( api_models.InstanceMapping).filter_by( @@ -173,7 +173,7 @@ class InstanceMapping(base.NovaTimestampObject, base.NovaObject): self.obj_reset_changes() @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_in_db(context, instance_uuid): result = context.session.query(api_models.InstanceMapping).filter_by( instance_uuid=instance_uuid).delete() @@ -185,7 +185,7 @@ class InstanceMapping(base.NovaTimestampObject, base.NovaObject): self._destroy_in_db(self._context, self.instance_uuid) -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def populate_queued_for_delete(context, max_count): cells = objects.CellMappingList.get_all(context) processed = 0 @@ -229,7 +229,7 @@ def populate_queued_for_delete(context, max_count): return processed, processed -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def populate_user_id(context, max_count): cells = objects.CellMappingList.get_all(context) cms_by_id = {cell.id: cell for cell in cells} @@ -309,7 +309,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject): } @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_project_id_from_db(context, project_id): return context.session.query(api_models.InstanceMapping)\ .options(orm.joinedload('cell_mapping'))\ @@ -323,7 +323,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject): db_mappings) @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_cell_id_from_db(context, cell_id): return context.session.query(api_models.InstanceMapping)\ .options(orm.joinedload('cell_mapping'))\ @@ -336,7 +336,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject): db_mappings) @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_instance_uuids_from_db(context, uuids): return context.session.query(api_models.InstanceMapping)\ .options(orm.joinedload('cell_mapping'))\ @@ -350,7 +350,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject): db_mappings) @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_bulk_in_db(context, instance_uuids): return context.session.query(api_models.InstanceMapping).filter( api_models.InstanceMapping.instance_uuid.in_(instance_uuids)).\ @@ -361,7 +361,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject): return cls._destroy_bulk_in_db(context, instance_uuids) @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_not_deleted_by_cell_and_project_from_db(context, cell_uuid, project_id, limit): query = context.session.query(api_models.InstanceMapping) @@ -400,7 +400,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject): db_mappings) @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_counts_in_db(context, project_id, user_id=None): project_query = context.session.query( func.count(api_models.InstanceMapping.id)).\ @@ -435,7 +435,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject): return cls._get_counts_in_db(context, project_id, user_id=user_id) @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _get_count_by_uuids_and_user_in_db(context, uuids, user_id): query = (context.session.query( func.count(api_models.InstanceMapping.id)) diff --git a/nova/objects/keypair.py b/nova/objects/keypair.py index aa6075c5f961..0b3c4315d5c9 100644 --- a/nova/objects/keypair.py +++ b/nova/objects/keypair.py @@ -17,8 +17,9 @@ from oslo_db.sqlalchemy import utils as sqlalchemyutils from oslo_log import log as logging from oslo_utils import versionutils +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db +from nova.db.main import api as main_db_api from nova import exception from nova import objects from nova.objects import base @@ -29,7 +30,7 @@ KEYPAIR_TYPE_X509 = 'x509' LOG = logging.getLogger(__name__) -@db.api_context_manager.reader +@api_db_api.context_manager.reader def _get_from_db(context, user_id, name=None, limit=None, marker=None): query = context.session.query(api_models.KeyPair).\ filter(api_models.KeyPair.user_id == user_id) @@ -54,14 +55,14 @@ def _get_from_db(context, user_id, name=None, limit=None, marker=None): return query.all() -@db.api_context_manager.reader +@api_db_api.context_manager.reader def _get_count_from_db(context, user_id): return context.session.query(api_models.KeyPair).\ filter(api_models.KeyPair.user_id == user_id).\ count() -@db.api_context_manager.writer +@api_db_api.context_manager.writer def _create_in_db(context, values): kp = api_models.KeyPair() kp.update(values) @@ -72,7 +73,7 @@ def _create_in_db(context, values): return kp -@db.api_context_manager.writer +@api_db_api.context_manager.writer def _destroy_in_db(context, user_id, name): result = context.session.query(api_models.KeyPair).\ filter_by(user_id=user_id).\ @@ -143,7 +144,7 @@ class KeyPair(base.NovaPersistentObject, base.NovaObject, except exception.KeypairNotFound: pass if db_keypair is None: - db_keypair = db.key_pair_get(context, user_id, name) + db_keypair = main_db_api.key_pair_get(context, user_id, name) return cls._from_db_object(context, cls(), db_keypair) @base.remotable_classmethod @@ -151,7 +152,7 @@ class KeyPair(base.NovaPersistentObject, base.NovaObject, try: cls._destroy_in_db(context, user_id, name) except exception.KeypairNotFound: - db.key_pair_destroy(context, user_id, name) + main_db_api.key_pair_destroy(context, user_id, name) @base.remotable def create(self): @@ -163,7 +164,7 @@ class KeyPair(base.NovaPersistentObject, base.NovaObject, # letting them create in the API DB, since we won't get protection # from the UC. try: - db.key_pair_get(self._context, self.user_id, self.name) + main_db_api.key_pair_get(self._context, self.user_id, self.name) raise exception.KeyPairExists(key_name=self.name) except exception.KeypairNotFound: pass @@ -180,7 +181,8 @@ class KeyPair(base.NovaPersistentObject, base.NovaObject, try: self._destroy_in_db(self._context, self.user_id, self.name) except exception.KeypairNotFound: - db.key_pair_destroy(self._context, self.user_id, self.name) + main_db_api.key_pair_destroy( + self._context, self.user_id, self.name) @base.NovaObjectRegistry.register @@ -222,7 +224,7 @@ class KeyPairList(base.ObjectListBase, base.NovaObject): limit_more = None if limit_more is None or limit_more > 0: - main_db_keypairs = db.key_pair_get_all_by_user( + main_db_keypairs = main_db_api.key_pair_get_all_by_user( context, user_id, limit=limit_more, marker=marker) else: main_db_keypairs = [] @@ -233,4 +235,4 @@ class KeyPairList(base.ObjectListBase, base.NovaObject): @base.remotable_classmethod def get_count_by_user(cls, context, user_id): return (cls._get_count_from_db(context, user_id) + - db.key_pair_count_by_user(context, user_id)) + main_db_api.key_pair_count_by_user(context, user_id)) diff --git a/nova/objects/quotas.py b/nova/objects/quotas.py index 39126dbb6de4..4bbb64fdf2bd 100644 --- a/nova/objects/quotas.py +++ b/nova/objects/quotas.py @@ -16,9 +16,11 @@ import collections from oslo_db import exception as db_exc +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db +from nova.db.main import api as main_db_api from nova.db.main import models as main_models +from nova.db import utils as db_utils from nova import exception from nova.objects import base from nova.objects import fields @@ -82,7 +84,7 @@ class Quotas(base.NovaObject): self.obj_reset_changes(fields=[attr]) @staticmethod - @db.api_context_manager.reader + @api_db_api.context_manager.reader def _get_from_db(context, project_id, resource, user_id=None): model = api_models.ProjectUserQuota if user_id else api_models.Quota query = context.session.query(model).\ @@ -100,14 +102,14 @@ class Quotas(base.NovaObject): return result @staticmethod - @db.api_context_manager.reader + @api_db_api.context_manager.reader def _get_all_from_db(context, project_id): return context.session.query(api_models.ProjectUserQuota).\ filter_by(project_id=project_id).\ all() @staticmethod - @db.api_context_manager.reader + @api_db_api.context_manager.reader def _get_all_from_db_by_project(context, project_id): # by_project refers to the returned dict that has a 'project_id' key rows = context.session.query(api_models.Quota).\ @@ -119,7 +121,7 @@ class Quotas(base.NovaObject): return result @staticmethod - @db.api_context_manager.reader + @api_db_api.context_manager.reader def _get_all_from_db_by_project_and_user(context, project_id, user_id): # by_project_and_user refers to the returned dict that has # 'project_id' and 'user_id' keys @@ -135,7 +137,7 @@ class Quotas(base.NovaObject): return result @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_all_in_db_by_project(context, project_id): per_project = context.session.query(api_models.Quota).\ filter_by(project_id=project_id).\ @@ -147,7 +149,7 @@ class Quotas(base.NovaObject): raise exception.ProjectQuotaNotFound(project_id=project_id) @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_all_in_db_by_project_and_user(context, project_id, user_id): result = context.session.query(api_models.ProjectUserQuota).\ filter_by(project_id=project_id).\ @@ -158,7 +160,7 @@ class Quotas(base.NovaObject): user_id=user_id) @staticmethod - @db.api_context_manager.reader + @api_db_api.context_manager.reader def _get_class_from_db(context, class_name, resource): result = context.session.query(api_models.QuotaClass).\ filter_by(class_name=class_name).\ @@ -169,7 +171,7 @@ class Quotas(base.NovaObject): return result @staticmethod - @db.api_context_manager.reader + @api_db_api.context_manager.reader def _get_all_class_from_db_by_name(context, class_name): # by_name refers to the returned dict that has a 'class_name' key rows = context.session.query(api_models.QuotaClass).\ @@ -181,14 +183,16 @@ class Quotas(base.NovaObject): return result @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _create_limit_in_db(context, project_id, resource, limit, user_id=None): # TODO(melwitt): We won't have per project resources after nova-network # is removed. # TODO(stephenfin): We need to do something here now...but what? - per_user = (user_id and - resource not in db.quota_get_per_project_resources()) + per_user = ( + user_id and + resource not in main_db_api.quota_get_per_project_resources() + ) quota_ref = (api_models.ProjectUserQuota() if per_user else api_models.Quota()) if per_user: @@ -204,14 +208,16 @@ class Quotas(base.NovaObject): return quota_ref @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _update_limit_in_db(context, project_id, resource, limit, user_id=None): # TODO(melwitt): We won't have per project resources after nova-network # is removed. # TODO(stephenfin): We need to do something here now...but what? - per_user = (user_id and - resource not in db.quota_get_per_project_resources()) + per_user = ( + user_id and + resource not in main_db_api.quota_get_per_project_resources() + ) model = api_models.ProjectUserQuota if per_user else api_models.Quota query = context.session.query(model).\ filter_by(project_id=project_id).\ @@ -228,7 +234,7 @@ class Quotas(base.NovaObject): raise exception.ProjectQuotaNotFound(project_id=project_id) @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _create_class_in_db(context, class_name, resource, limit): # NOTE(melwitt): There's no unique constraint on the QuotaClass model, # so check for duplicate manually. @@ -247,7 +253,7 @@ class Quotas(base.NovaObject): return quota_class_ref @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _update_class_in_db(context, class_name, resource, limit): result = context.session.query(api_models.QuotaClass).\ filter_by(class_name=class_name).\ @@ -366,7 +372,8 @@ class Quotas(base.NovaObject): @base.remotable_classmethod def create_limit(cls, context, project_id, resource, limit, user_id=None): try: - db.quota_get(context, project_id, resource, user_id=user_id) + main_db_api.quota_get( + context, project_id, resource, user_id=user_id) except exception.QuotaNotFound: cls._create_limit_in_db(context, project_id, resource, limit, user_id=user_id) @@ -380,13 +387,13 @@ class Quotas(base.NovaObject): cls._update_limit_in_db(context, project_id, resource, limit, user_id=user_id) except exception.QuotaNotFound: - db.quota_update(context, project_id, resource, limit, + main_db_api.quota_update(context, project_id, resource, limit, user_id=user_id) @classmethod def create_class(cls, context, class_name, resource, limit): try: - db.quota_class_get(context, class_name, resource) + main_db_api.quota_class_get(context, class_name, resource) except exception.QuotaClassNotFound: cls._create_class_in_db(context, class_name, resource, limit) else: @@ -398,7 +405,8 @@ class Quotas(base.NovaObject): try: cls._update_class_in_db(context, class_name, resource, limit) except exception.QuotaClassNotFound: - db.quota_class_update(context, class_name, resource, limit) + main_db_api.quota_class_update( + context, class_name, resource, limit) # NOTE(melwitt): The following methods are not remotable and return # dict-like database model objects. We are using classmethods to provide @@ -409,21 +417,22 @@ class Quotas(base.NovaObject): quota = cls._get_from_db(context, project_id, resource, user_id=user_id) except exception.QuotaNotFound: - quota = db.quota_get(context, project_id, resource, + quota = main_db_api.quota_get(context, project_id, resource, user_id=user_id) return quota @classmethod def get_all(cls, context, project_id): api_db_quotas = cls._get_all_from_db(context, project_id) - main_db_quotas = db.quota_get_all(context, project_id) + main_db_quotas = main_db_api.quota_get_all(context, project_id) return api_db_quotas + main_db_quotas @classmethod def get_all_by_project(cls, context, project_id): api_db_quotas_dict = cls._get_all_from_db_by_project(context, project_id) - main_db_quotas_dict = db.quota_get_all_by_project(context, project_id) + main_db_quotas_dict = main_db_api.quota_get_all_by_project( + context, project_id) for k, v in api_db_quotas_dict.items(): main_db_quotas_dict[k] = v return main_db_quotas_dict @@ -432,7 +441,7 @@ class Quotas(base.NovaObject): def get_all_by_project_and_user(cls, context, project_id, user_id): api_db_quotas_dict = cls._get_all_from_db_by_project_and_user( context, project_id, user_id) - main_db_quotas_dict = db.quota_get_all_by_project_and_user( + main_db_quotas_dict = main_db_api.quota_get_all_by_project_and_user( context, project_id, user_id) for k, v in api_db_quotas_dict.items(): main_db_quotas_dict[k] = v @@ -443,7 +452,7 @@ class Quotas(base.NovaObject): try: cls._destroy_all_in_db_by_project(context, project_id) except exception.ProjectQuotaNotFound: - db.quota_destroy_all_by_project(context, project_id) + main_db_api.quota_destroy_all_by_project(context, project_id) @classmethod def destroy_all_by_project_and_user(cls, context, project_id, user_id): @@ -451,31 +460,31 @@ class Quotas(base.NovaObject): cls._destroy_all_in_db_by_project_and_user(context, project_id, user_id) except exception.ProjectUserQuotaNotFound: - db.quota_destroy_all_by_project_and_user(context, project_id, - user_id) + main_db_api.quota_destroy_all_by_project_and_user( + context, project_id, user_id) @classmethod def get_class(cls, context, class_name, resource): try: qclass = cls._get_class_from_db(context, class_name, resource) except exception.QuotaClassNotFound: - qclass = db.quota_class_get(context, class_name, resource) + qclass = main_db_api.quota_class_get(context, class_name, resource) return qclass @classmethod def get_default_class(cls, context): try: qclass = cls._get_all_class_from_db_by_name( - context, db._DEFAULT_QUOTA_NAME) + context, main_db_api._DEFAULT_QUOTA_NAME) except exception.QuotaClassNotFound: - qclass = db.quota_class_get_default(context) + qclass = main_db_api.quota_class_get_default(context) return qclass @classmethod def get_all_class_by_name(cls, context, class_name): api_db_quotas_dict = cls._get_all_class_from_db_by_name(context, class_name) - main_db_quotas_dict = db.quota_class_get_all_by_name(context, + main_db_quotas_dict = main_db_api.quota_class_get_all_by_name(context, class_name) for k, v in api_db_quotas_dict.items(): main_db_quotas_dict[k] = v @@ -501,8 +510,8 @@ class QuotasNoOp(Quotas): pass -@db.require_context -@db.pick_context_manager_reader +@db_utils.require_context +@main_db_api.pick_context_manager_reader def _get_main_per_project_limits(context, limit): return context.session.query(main_models.Quota).\ filter_by(deleted=0).\ @@ -510,8 +519,8 @@ def _get_main_per_project_limits(context, limit): all() -@db.require_context -@db.pick_context_manager_reader +@db_utils.require_context +@main_db_api.pick_context_manager_reader def _get_main_per_user_limits(context, limit): return context.session.query(main_models.ProjectUserQuota).\ filter_by(deleted=0).\ @@ -519,8 +528,8 @@ def _get_main_per_user_limits(context, limit): all() -@db.require_context -@db.pick_context_manager_writer +@db_utils.require_context +@main_db_api.pick_context_manager_writer def _destroy_main_per_project_limits(context, project_id, resource): context.session.query(main_models.Quota).\ filter_by(deleted=0).\ @@ -529,8 +538,8 @@ def _destroy_main_per_project_limits(context, project_id, resource): soft_delete(synchronize_session=False) -@db.require_context -@db.pick_context_manager_writer +@db_utils.require_context +@main_db_api.pick_context_manager_writer def _destroy_main_per_user_limits(context, project_id, resource, user_id): context.session.query(main_models.ProjectUserQuota).\ filter_by(deleted=0).\ @@ -540,7 +549,7 @@ def _destroy_main_per_user_limits(context, project_id, resource, user_id): soft_delete(synchronize_session=False) -@db.api_context_manager.writer +@api_db_api.context_manager.writer def _create_limits_in_api_db(context, db_limits, per_user=False): for db_limit in db_limits: user_id = db_limit.user_id if per_user else None @@ -587,8 +596,8 @@ def migrate_quota_limits_to_api_db(context, count): return len(main_per_project_limits) + len(main_per_user_limits), done -@db.require_context -@db.pick_context_manager_reader +@db_utils.require_context +@main_db_api.pick_context_manager_reader def _get_main_quota_classes(context, limit): return context.session.query(main_models.QuotaClass).\ filter_by(deleted=0).\ @@ -596,7 +605,7 @@ def _get_main_quota_classes(context, limit): all() -@db.pick_context_manager_writer +@main_db_api.pick_context_manager_writer def _destroy_main_quota_classes(context, db_classes): for db_class in db_classes: context.session.query(main_models.QuotaClass).\ @@ -605,7 +614,7 @@ def _destroy_main_quota_classes(context, db_classes): soft_delete(synchronize_session=False) -@db.api_context_manager.writer +@api_db_api.context_manager.writer def _create_classes_in_api_db(context, db_classes): for db_class in db_classes: Quotas._create_class_in_db(context, db_class.class_name, diff --git a/nova/objects/request_spec.py b/nova/objects/request_spec.py index cef1c6633484..7d57903078af 100644 --- a/nova/objects/request_spec.py +++ b/nova/objects/request_spec.py @@ -20,8 +20,8 @@ from oslo_log import log as logging from oslo_serialization import jsonutils from oslo_utils import versionutils +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db from nova import exception from nova import objects from nova.objects import base @@ -630,7 +630,7 @@ class RequestSpec(base.NovaObject): return spec @staticmethod - @db.api_context_manager.reader + @api_db_api.context_manager.reader def _get_by_instance_uuid_from_db(context, instance_uuid): db_spec = context.session.query(api_models.RequestSpec).filter_by( instance_uuid=instance_uuid).first() @@ -645,7 +645,7 @@ class RequestSpec(base.NovaObject): return cls._from_db_object(context, cls(), db_spec) @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _create_in_db(context, updates): db_spec = api_models.RequestSpec() db_spec.update(updates) @@ -709,7 +709,7 @@ class RequestSpec(base.NovaObject): self._from_db_object(self._context, self, db_spec) @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _save_in_db(context, instance_uuid, updates): # FIXME(sbauza): Provide a classmethod when oslo.db bug #1520195 is # fixed and released @@ -729,7 +729,7 @@ class RequestSpec(base.NovaObject): self.obj_reset_changes() @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_in_db(context, instance_uuid): result = context.session.query(api_models.RequestSpec).filter_by( instance_uuid=instance_uuid).delete() @@ -741,7 +741,7 @@ class RequestSpec(base.NovaObject): self._destroy_in_db(self._context, self.instance_uuid) @staticmethod - @db.api_context_manager.writer + @api_db_api.context_manager.writer def _destroy_bulk_in_db(context, instance_uuids): return context.session.query(api_models.RequestSpec).filter( api_models.RequestSpec.instance_uuid.in_(instance_uuids)).\ diff --git a/nova/objects/virtual_interface.py b/nova/objects/virtual_interface.py index e9a769a7b513..7ce418aca202 100644 --- a/nova/objects/virtual_interface.py +++ b/nova/objects/virtual_interface.py @@ -16,8 +16,9 @@ from oslo_log import log as logging from oslo_utils import versionutils from nova import context as nova_context -from nova.db.main import api as db -from nova.db.main import models +from nova.db.api import api as api_db_api +from nova.db.main import api as main_db_api +from nova.db.main import models as main_db_models from nova import exception from nova import objects from nova.objects import base @@ -70,26 +71,26 @@ class VirtualInterface(base.NovaPersistentObject, base.NovaObject): @base.remotable_classmethod def get_by_id(cls, context, vif_id): - db_vif = db.virtual_interface_get(context, vif_id) + db_vif = main_db_api.virtual_interface_get(context, vif_id) if db_vif: return cls._from_db_object(context, cls(), db_vif) @base.remotable_classmethod def get_by_uuid(cls, context, vif_uuid): - db_vif = db.virtual_interface_get_by_uuid(context, vif_uuid) + db_vif = main_db_api.virtual_interface_get_by_uuid(context, vif_uuid) if db_vif: return cls._from_db_object(context, cls(), db_vif) @base.remotable_classmethod def get_by_address(cls, context, address): - db_vif = db.virtual_interface_get_by_address(context, address) + db_vif = main_db_api.virtual_interface_get_by_address(context, address) if db_vif: return cls._from_db_object(context, cls(), db_vif) @base.remotable_classmethod def get_by_instance_and_network(cls, context, instance_uuid, network_id): - db_vif = db.virtual_interface_get_by_instance_and_network(context, - instance_uuid, network_id) + db_vif = main_db_api.virtual_interface_get_by_instance_and_network( + context, instance_uuid, network_id) if db_vif: return cls._from_db_object(context, cls(), db_vif) @@ -99,7 +100,7 @@ class VirtualInterface(base.NovaPersistentObject, base.NovaObject): raise exception.ObjectActionError(action='create', reason='already created') updates = self.obj_get_changes() - db_vif = db.virtual_interface_create(self._context, updates) + db_vif = main_db_api.virtual_interface_create(self._context, updates) self._from_db_object(self._context, self, db_vif) @base.remotable @@ -108,17 +109,18 @@ class VirtualInterface(base.NovaPersistentObject, base.NovaObject): if 'address' in updates: raise exception.ObjectActionError(action='save', reason='address is not mutable') - db_vif = db.virtual_interface_update(self._context, self.address, - updates) + db_vif = main_db_api.virtual_interface_update( + self._context, self.address, updates) return self._from_db_object(self._context, self, db_vif) @base.remotable_classmethod def delete_by_instance_uuid(cls, context, instance_uuid): - db.virtual_interface_delete_by_instance(context, instance_uuid) + main_db_api.virtual_interface_delete_by_instance( + context, instance_uuid) @base.remotable def destroy(self): - db.virtual_interface_delete(self._context, self.id) + main_db_api.virtual_interface_delete(self._context, self.id) @base.NovaObjectRegistry.register @@ -131,15 +133,16 @@ class VirtualInterfaceList(base.ObjectListBase, base.NovaObject): @base.remotable_classmethod def get_all(cls, context): - db_vifs = db.virtual_interface_get_all(context) + db_vifs = main_db_api.virtual_interface_get_all(context) return base.obj_make_list(context, cls(context), objects.VirtualInterface, db_vifs) @staticmethod - @db.select_db_reader_mode + @main_db_api.select_db_reader_mode def _db_virtual_interface_get_by_instance(context, instance_uuid, use_slave=False): - return db.virtual_interface_get_by_instance(context, instance_uuid) + return main_db_api.virtual_interface_get_by_instance( + context, instance_uuid) @base.remotable_classmethod def get_by_instance_uuid(cls, context, instance_uuid, use_slave=False): @@ -149,7 +152,7 @@ class VirtualInterfaceList(base.ObjectListBase, base.NovaObject): objects.VirtualInterface, db_vifs) -@db.api_context_manager.writer +@api_db_api.context_manager.writer def fill_virtual_interface_list(context, max_count): """This fills missing VirtualInterface Objects in Nova DB""" count_hit = 0 @@ -287,14 +290,14 @@ def fill_virtual_interface_list(context, max_count): # we checked. # Please notice that because of virtual_interfaces_instance_uuid_fkey # we need to have FAKE_UUID instance object, even deleted one. -@db.pick_context_manager_writer +@main_db_api.pick_context_manager_writer def _set_or_delete_marker_for_migrate_instances(context, marker=None): - context.session.query(models.VirtualInterface).filter_by( + context.session.query(main_db_models.VirtualInterface).filter_by( instance_uuid=FAKE_UUID).delete() # Create FAKE_UUID instance objects, only for marker, if doesn't exist. # It is needed due constraint: virtual_interfaces_instance_uuid_fkey - instance = context.session.query(models.Instance).filter_by( + instance = context.session.query(main_db_models.Instance).filter_by( uuid=FAKE_UUID).first() if not instance: instance = objects.Instance(context) @@ -316,9 +319,9 @@ def _set_or_delete_marker_for_migrate_instances(context, marker=None): db_mapping.create() -@db.pick_context_manager_reader +@main_db_api.pick_context_manager_reader def _get_marker_for_migrate_instances(context): - vif = (context.session.query(models.VirtualInterface).filter_by( + vif = (context.session.query(main_db_models.VirtualInterface).filter_by( instance_uuid=FAKE_UUID)).first() marker = vif['tag'] if vif else None return marker diff --git a/nova/quota.py b/nova/quota.py index e3950e9d99c2..a311ecc87fae 100644 --- a/nova/quota.py +++ b/nova/quota.py @@ -24,8 +24,9 @@ from sqlalchemy import sql import nova.conf from nova import context as nova_context +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db +from nova.db.main import api as main_db_api from nova import exception from nova import objects from nova.scheduler.client import report @@ -177,7 +178,10 @@ class DbQuotaDriver(object): # displaying used limits. They are always zero. usages[resource.name] = {'in_use': 0} else: - if resource.name in db.quota_get_per_project_resources(): + if ( + resource.name in + main_db_api.quota_get_per_project_resources() + ): count = resource.count_as_dict(context, project_id) key = 'project' else: @@ -1045,7 +1049,7 @@ class QuotaEngine(object): return 0 -@db.api_context_manager.reader +@api_db_api.context_manager.reader def _user_id_queued_for_delete_populated(context, project_id=None): """Determine whether user_id and queued_for_delete are set. diff --git a/nova/tests/fixtures/nova.py b/nova/tests/fixtures/nova.py index 048adbd1b9d8..d5d96f5122e9 100644 --- a/nova/tests/fixtures/nova.py +++ b/nova/tests/fixtures/nova.py @@ -43,7 +43,8 @@ from nova.api import wsgi from nova.compute import multi_cell_list from nova.compute import rpcapi as compute_rpcapi from nova import context -from nova.db.main import api as session +from nova.db.api import api as api_db_api +from nova.db.main import api as main_db_api from nova.db import migration from nova import exception from nova import objects @@ -61,7 +62,6 @@ LOG = logging.getLogger(__name__) DB_SCHEMA = collections.defaultdict(str) SESSION_CONFIGURED = False - PROJECT_ID = '6f70656e737461636b20342065766572' @@ -532,7 +532,7 @@ class CellDatabases(fixtures.Fixture): # will house the sqlite:// connection for this cell's in-memory # database. Store/index it by the connection string, which is # how we identify cells in CellMapping. - ctxt_mgr = session.create_context_manager() + ctxt_mgr = main_db_api.create_context_manager() self._ctxt_mgrs[connection_str] = ctxt_mgr # NOTE(melwitt): The first DB access through service start is @@ -595,30 +595,45 @@ class CellDatabases(fixtures.Fixture): class Database(fixtures.Fixture): + + # TODO(stephenfin): The 'version' argument is unused and can be removed def __init__(self, database='main', version=None, connection=None): """Create a database fixture. :param database: The type of database, 'main', or 'api' :param connection: The connection string to use """ - super(Database, self).__init__() - # NOTE(pkholkin): oslo_db.enginefacade is configured in tests the same - # way as it is done for any other service that uses db + super().__init__() + + # NOTE(pkholkin): oslo_db.enginefacade is configured in tests the + # same way as it is done for any other service that uses DB global SESSION_CONFIGURED if not SESSION_CONFIGURED: - session.configure(CONF) + main_db_api.configure(CONF) + api_db_api.configure(CONF) SESSION_CONFIGURED = True + + assert database in {'main', 'api'}, f'Unrecognized database {database}' + self.database = database self.version = version + if database == 'main': if connection is not None: - ctxt_mgr = session.create_context_manager( - connection=connection) + ctxt_mgr = main_db_api.create_context_manager( + connection=connection) self.get_engine = ctxt_mgr.writer.get_engine else: - self.get_engine = session.get_engine + self.get_engine = main_db_api.get_engine elif database == 'api': - self.get_engine = session.get_api_engine + assert connection is None, 'Not supported for the API database' + + self.get_engine = api_db_api.get_engine + + def setUp(self): + super(Database, self).setUp() + self.reset() + self.addCleanup(self.cleanup) def _cache_schema(self): global DB_SCHEMA @@ -642,11 +657,6 @@ class Database(fixtures.Fixture): conn.connection.executescript( DB_SCHEMA[(self.database, self.version)]) - def setUp(self): - super(Database, self).setUp() - self.reset() - self.addCleanup(self.cleanup) - class DefaultFlavorsFixture(fixtures.Fixture): def setUp(self): diff --git a/nova/tests/functional/db/test_aggregate.py b/nova/tests/functional/db/test_aggregate.py index 3c1cf8ded083..35d9024576e3 100644 --- a/nova/tests/functional/db/test_aggregate.py +++ b/nova/tests/functional/db/test_aggregate.py @@ -18,8 +18,8 @@ from oslo_utils.fixture import uuidsentinel from oslo_utils import timeutils from nova import context +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db_api from nova import exception import nova.objects.aggregate as aggregate_obj from nova import test @@ -59,7 +59,7 @@ def _get_fake_metadata(db_id): 'unique_key': 'unique_value_' + str(db_id)} -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _create_aggregate(context, values=_get_fake_aggregate(1, result=False), metadata=_get_fake_metadata(1)): aggregate = api_models.Aggregate() @@ -77,7 +77,7 @@ def _create_aggregate(context, values=_get_fake_aggregate(1, result=False), return aggregate -@db_api.api_context_manager.writer +@api_db_api.context_manager.writer def _create_aggregate_with_hosts(context, values=_get_fake_aggregate(1, result=False), metadata=_get_fake_metadata(1), @@ -92,13 +92,13 @@ def _create_aggregate_with_hosts(context, return aggregate -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _aggregate_host_get_all(context, aggregate_id): return context.session.query(api_models.AggregateHost).\ filter_by(aggregate_id=aggregate_id).all() -@db_api.api_context_manager.reader +@api_db_api.context_manager.reader def _aggregate_metadata_get_all(context, aggregate_id): results = context.session.query(api_models.AggregateMetadata).\ filter_by(aggregate_id=aggregate_id).all() diff --git a/nova/tests/functional/db/test_flavor.py b/nova/tests/functional/db/test_flavor.py index 3ce192fec754..c0435b5cbc64 100644 --- a/nova/tests/functional/db/test_flavor.py +++ b/nova/tests/functional/db/test_flavor.py @@ -11,8 +11,8 @@ # under the License. from nova import context +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db_api from nova import exception from nova import objects from nova import test @@ -111,7 +111,7 @@ class FlavorObjectTestCase(test.NoDBTestCase): self.assertEqual(set(projects), set(flavor2.projects)) @staticmethod - @db_api.api_context_manager.reader + @api_db_api.context_manager.reader def _collect_flavor_residue_api(context, flavor): flavors = context.session.query(api_models.Flavors).\ filter_by(id=flavor.id).all() diff --git a/nova/tests/unit/api/openstack/test_wsgi_app.py b/nova/tests/unit/api/openstack/test_wsgi_app.py index b6306af0e12a..d2bd7c4bd6e1 100644 --- a/nova/tests/unit/api/openstack/test_wsgi_app.py +++ b/nova/tests/unit/api/openstack/test_wsgi_app.py @@ -46,11 +46,14 @@ document_root = /tmp self.useFixture(config_fixture.Config()) @mock.patch('sys.argv', return_value=mock.sentinel.argv) + @mock.patch('nova.db.api.api.configure') @mock.patch('nova.db.main.api.configure') @mock.patch('nova.api.openstack.wsgi_app._setup_service') @mock.patch('nova.api.openstack.wsgi_app._get_config_files') - def test_init_application_called_twice(self, mock_get_files, mock_setup, - mock_db_configure, mock_argv): + def test_init_application_called_twice( + self, mock_get_files, mock_setup, mock_main_db_configure, + mock_api_db_configure, mock_argv, + ): """Test that init_application can tolerate being called twice in a single python interpreter instance. @@ -65,14 +68,15 @@ document_root = /tmp """ mock_get_files.return_value = [self.conf.name] mock_setup.side_effect = [test.TestingException, None] - # We need to mock the global database configure() method, else we will + # We need to mock the global database configure() methods, else we will # be affected by global database state altered by other tests that ran # before this test, causing this test to fail with # oslo_db.sqlalchemy.enginefacade.AlreadyStartedError. We can instead # mock the method to raise an exception if it's called a second time in # this test to simulate the fact that the database does not tolerate # re-init [after a database query has been made]. - mock_db_configure.side_effect = [None, test.TestingException] + mock_main_db_configure.side_effect = [None, test.TestingException] + mock_api_db_configure.side_effect = [None, test.TestingException] # Run init_application the first time, simulating an exception being # raised during it. self.assertRaises(test.TestingException, wsgi_app.init_application, diff --git a/nova/tests/unit/conductor/test_conductor.py b/nova/tests/unit/conductor/test_conductor.py index 152a625286c7..ccf84e0929f6 100644 --- a/nova/tests/unit/conductor/test_conductor.py +++ b/nova/tests/unit/conductor/test_conductor.py @@ -39,8 +39,9 @@ from nova.conductor.tasks import live_migrate from nova.conductor.tasks import migrate from nova import conf from nova import context +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db +from nova.db.main import api as main_db_api from nova import exception as exc from nova.image import glance as image_api from nova import objects @@ -440,7 +441,7 @@ class _BaseTaskTestCase(object): @mock.patch.object(conductor_manager.ComputeTaskManager, '_create_and_bind_arqs') @mock.patch.object(compute_rpcapi.ComputeAPI, 'build_and_run_instance') - @mock.patch.object(db, 'block_device_mapping_get_all_by_instance', + @mock.patch.object(main_db_api, 'block_device_mapping_get_all_by_instance', return_value=[]) @mock.patch.object(conductor_manager.ComputeTaskManager, '_schedule_instances') @@ -547,7 +548,7 @@ class _BaseTaskTestCase(object): @mock.patch.object(conductor_manager.ComputeTaskManager, '_create_and_bind_arqs') @mock.patch.object(compute_rpcapi.ComputeAPI, 'build_and_run_instance') - @mock.patch.object(db, 'block_device_mapping_get_all_by_instance', + @mock.patch.object(main_db_api, 'block_device_mapping_get_all_by_instance', return_value=[]) @mock.patch.object(conductor_manager.ComputeTaskManager, '_schedule_instances') @@ -2740,7 +2741,7 @@ class ConductorTaskTestCase(_BaseTaskTestCase, test_compute.BaseTestCase): self.assertEqual(0, len(build_requests)) - @db.api_context_manager.reader + @api_db_api.context_manager.reader def request_spec_get_all(context): return context.session.query(api_models.RequestSpec).all() diff --git a/nova/tests/unit/db/api/__init__.py b/nova/tests/unit/db/api/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nova/tests/unit/db/api/test_api.py b/nova/tests/unit/db/api/test_api.py new file mode 100644 index 000000000000..251407612fad --- /dev/null +++ b/nova/tests/unit/db/api/test_api.py @@ -0,0 +1,24 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import mock + +from nova.db.api import api as db_api +from nova import test + + +class SqlAlchemyDbApiNoDbTestCase(test.NoDBTestCase): + + @mock.patch.object(db_api, 'context_manager') + def test_get_engine(self, mock_ctxt_mgr): + db_api.get_engine() + mock_ctxt_mgr.writer.get_engine.assert_called_once_with() diff --git a/nova/tests/unit/db/main/test_api.py b/nova/tests/unit/db/main/test_api.py index 7254f9262b82..de92f5feb7c8 100644 --- a/nova/tests/unit/db/main/test_api.py +++ b/nova/tests/unit/db/main/test_api.py @@ -1,5 +1,3 @@ -# encoding=UTF8 - # Copyright 2010 United States Government as represented by the # Administrator of the National Aeronautics and Space Administration. # All Rights Reserved. @@ -52,6 +50,7 @@ from nova import context from nova.db.main import api as db from nova.db.main import models from nova.db import types as col_types +from nova.db import utils as db_utils from nova import exception from nova.objects import fields from nova import test @@ -207,7 +206,7 @@ class DecoratorTestCase(test.TestCase): self.assertEqual(test_func.__module__, decorated_func.__module__) def test_require_context_decorator_wraps_functions_properly(self): - self._test_decorator_wraps_helper(db.require_context) + self._test_decorator_wraps_helper(db_utils.require_context) def test_require_deadlock_retry_wraps_functions_properly(self): self._test_decorator_wraps_helper( @@ -628,7 +627,7 @@ class ModelQueryTestCase(DbTestCase): @mock.patch.object(sqlalchemyutils, 'model_query') def test_model_query_use_context_session(self, mock_model_query): - @db.main_context_manager.reader + @db.context_manager.reader def fake_method(context): session = context.session db.model_query(context, models.Instance) @@ -642,15 +641,15 @@ class ModelQueryTestCase(DbTestCase): class EngineFacadeTestCase(DbTestCase): def test_use_single_context_session_writer(self): # Checks that session in context would not be overwritten by - # annotation @db.main_context_manager.writer if annotation + # annotation @db.context_manager.writer if annotation # is used twice. - @db.main_context_manager.writer + @db.context_manager.writer def fake_parent_method(context): session = context.session return fake_child_method(context), session - @db.main_context_manager.writer + @db.context_manager.writer def fake_child_method(context): session = context.session db.model_query(context, models.Instance) @@ -661,15 +660,15 @@ class EngineFacadeTestCase(DbTestCase): def test_use_single_context_session_reader(self): # Checks that session in context would not be overwritten by - # annotation @db.main_context_manager.reader if annotation + # annotation @db.context_manager.reader if annotation # is used twice. - @db.main_context_manager.reader + @db.context_manager.reader def fake_parent_method(context): session = context.session return fake_child_method(context), session - @db.main_context_manager.reader + @db.context_manager.reader def fake_child_method(context): session = context.session db.model_query(context, models.Instance) @@ -757,12 +756,12 @@ class SqlAlchemyDbApiNoDbTestCase(test.NoDBTestCase): self.assertEqual('|', filter('|')) self.assertEqual('LIKE', op) - @mock.patch.object(db, 'main_context_manager') + @mock.patch.object(db, 'context_manager') def test_get_engine(self, mock_ctxt_mgr): db.get_engine() mock_ctxt_mgr.writer.get_engine.assert_called_once_with() - @mock.patch.object(db, 'main_context_manager') + @mock.patch.object(db, 'context_manager') def test_get_engine_use_slave(self, mock_ctxt_mgr): db.get_engine(use_slave=True) mock_ctxt_mgr.reader.get_engine.assert_called_once_with() @@ -774,11 +773,6 @@ class SqlAlchemyDbApiNoDbTestCase(test.NoDBTestCase): connection='fake://') self.assertEqual('fake://', db_conf['connection']) - @mock.patch.object(db, 'api_context_manager') - def test_get_api_engine(self, mock_ctxt_mgr): - db.get_api_engine() - mock_ctxt_mgr.writer.get_engine.assert_called_once_with() - @mock.patch.object(db, '_instance_get_by_uuid') @mock.patch.object(db, '_instances_fill_metadata') @mock.patch('oslo_db.sqlalchemy.utils.paginate_query') @@ -1013,114 +1007,6 @@ class SqlAlchemyDbApiTestCase(DbTestCase): self.assertEqual(1, len(instances)) -class ProcessSortParamTestCase(test.TestCase): - - def test_process_sort_params_defaults(self): - '''Verifies default sort parameters.''' - sort_keys, sort_dirs = db.process_sort_params([], []) - self.assertEqual(['created_at', 'id'], sort_keys) - self.assertEqual(['asc', 'asc'], sort_dirs) - - sort_keys, sort_dirs = db.process_sort_params(None, None) - self.assertEqual(['created_at', 'id'], sort_keys) - self.assertEqual(['asc', 'asc'], sort_dirs) - - def test_process_sort_params_override_default_keys(self): - '''Verifies that the default keys can be overridden.''' - sort_keys, sort_dirs = db.process_sort_params( - [], [], default_keys=['key1', 'key2', 'key3']) - self.assertEqual(['key1', 'key2', 'key3'], sort_keys) - self.assertEqual(['asc', 'asc', 'asc'], sort_dirs) - - def test_process_sort_params_override_default_dir(self): - '''Verifies that the default direction can be overridden.''' - sort_keys, sort_dirs = db.process_sort_params( - [], [], default_dir='dir1') - self.assertEqual(['created_at', 'id'], sort_keys) - self.assertEqual(['dir1', 'dir1'], sort_dirs) - - def test_process_sort_params_override_default_key_and_dir(self): - '''Verifies that the default key and dir can be overridden.''' - sort_keys, sort_dirs = db.process_sort_params( - [], [], default_keys=['key1', 'key2', 'key3'], - default_dir='dir1') - self.assertEqual(['key1', 'key2', 'key3'], sort_keys) - self.assertEqual(['dir1', 'dir1', 'dir1'], sort_dirs) - - sort_keys, sort_dirs = db.process_sort_params( - [], [], default_keys=[], default_dir='dir1') - self.assertEqual([], sort_keys) - self.assertEqual([], sort_dirs) - - def test_process_sort_params_non_default(self): - '''Verifies that non-default keys are added correctly.''' - sort_keys, sort_dirs = db.process_sort_params( - ['key1', 'key2'], ['asc', 'desc']) - self.assertEqual(['key1', 'key2', 'created_at', 'id'], sort_keys) - # First sort_dir in list is used when adding the default keys - self.assertEqual(['asc', 'desc', 'asc', 'asc'], sort_dirs) - - def test_process_sort_params_default(self): - '''Verifies that default keys are added correctly.''' - sort_keys, sort_dirs = db.process_sort_params( - ['id', 'key2'], ['asc', 'desc']) - self.assertEqual(['id', 'key2', 'created_at'], sort_keys) - self.assertEqual(['asc', 'desc', 'asc'], sort_dirs) - - # Include default key value, rely on default direction - sort_keys, sort_dirs = db.process_sort_params( - ['id', 'key2'], []) - self.assertEqual(['id', 'key2', 'created_at'], sort_keys) - self.assertEqual(['asc', 'asc', 'asc'], sort_dirs) - - def test_process_sort_params_default_dir(self): - '''Verifies that the default dir is applied to all keys.''' - # Direction is set, ignore default dir - sort_keys, sort_dirs = db.process_sort_params( - ['id', 'key2'], ['desc'], default_dir='dir') - self.assertEqual(['id', 'key2', 'created_at'], sort_keys) - self.assertEqual(['desc', 'desc', 'desc'], sort_dirs) - - # But should be used if no direction is set - sort_keys, sort_dirs = db.process_sort_params( - ['id', 'key2'], [], default_dir='dir') - self.assertEqual(['id', 'key2', 'created_at'], sort_keys) - self.assertEqual(['dir', 'dir', 'dir'], sort_dirs) - - def test_process_sort_params_unequal_length(self): - '''Verifies that a sort direction list is applied correctly.''' - sort_keys, sort_dirs = db.process_sort_params( - ['id', 'key2', 'key3'], ['desc']) - self.assertEqual(['id', 'key2', 'key3', 'created_at'], sort_keys) - self.assertEqual(['desc', 'desc', 'desc', 'desc'], sort_dirs) - - # Default direction is the first key in the list - sort_keys, sort_dirs = db.process_sort_params( - ['id', 'key2', 'key3'], ['desc', 'asc']) - self.assertEqual(['id', 'key2', 'key3', 'created_at'], sort_keys) - self.assertEqual(['desc', 'asc', 'desc', 'desc'], sort_dirs) - - sort_keys, sort_dirs = db.process_sort_params( - ['id', 'key2', 'key3'], ['desc', 'asc', 'asc']) - self.assertEqual(['id', 'key2', 'key3', 'created_at'], sort_keys) - self.assertEqual(['desc', 'asc', 'asc', 'desc'], sort_dirs) - - def test_process_sort_params_extra_dirs_lengths(self): - '''InvalidInput raised if more directions are given.''' - self.assertRaises(exception.InvalidInput, - db.process_sort_params, - ['key1', 'key2'], - ['asc', 'desc', 'desc']) - - def test_process_sort_params_invalid_sort_dir(self): - '''InvalidInput raised if invalid directions are given.''' - for dirs in [['foo'], ['asc', 'foo'], ['asc', 'desc', 'foo']]: - self.assertRaises(exception.InvalidInput, - db.process_sort_params, - ['key'], - dirs) - - class MigrationTestCase(test.TestCase): def setUp(self): @@ -1877,7 +1763,7 @@ class InstanceTestCase(test.TestCase, ModelsObjectComparatorMixin): @mock.patch.object(query.Query, 'filter') def test_instance_metadata_get_multi_no_uuids(self, mock_query_filter): - with db.main_context_manager.reader.using(self.ctxt): + with db.context_manager.reader.using(self.ctxt): db._instance_metadata_get_multi(self.ctxt, []) self.assertFalse(mock_query_filter.called) diff --git a/nova/tests/unit/db/test_migration.py b/nova/tests/unit/db/test_migration.py index f9c4d7e72e1e..66469e42a596 100644 --- a/nova/tests/unit/db/test_migration.py +++ b/nova/tests/unit/db/test_migration.py @@ -17,7 +17,8 @@ from migrate.versioning import api as versioning_api import mock import sqlalchemy -from nova.db.main import api as db_api +from nova.db.api import api as api_db_api +from nova.db.main import api as main_db_api from nova.db import migration from nova import test @@ -144,15 +145,17 @@ class TestDbVersionControl(test.NoDBTestCase): class TestGetEngine(test.NoDBTestCase): def test_get_main_engine(self): - with mock.patch.object(db_api, 'get_engine', - return_value='engine') as mock_get_engine: + with mock.patch.object( + main_db_api, 'get_engine', return_value='engine', + ) as mock_get_engine: engine = migration.get_engine() self.assertEqual('engine', engine) mock_get_engine.assert_called_once_with(context=None) def test_get_api_engine(self): - with mock.patch.object(db_api, 'get_api_engine', - return_value='api_engine') as mock_get_engine: + with mock.patch.object( + api_db_api, 'get_engine', return_value='engine', + ) as mock_get_engine: engine = migration.get_engine('api') - self.assertEqual('api_engine', engine) + self.assertEqual('engine', engine) mock_get_engine.assert_called_once_with() diff --git a/nova/tests/unit/db/test_utils.py b/nova/tests/unit/db/test_utils.py new file mode 100644 index 000000000000..5f293723f8c7 --- /dev/null +++ b/nova/tests/unit/db/test_utils.py @@ -0,0 +1,123 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from nova.db import utils +from nova import exception +from nova import test + + +class ProcessSortParamTestCase(test.TestCase): + + def test_process_sort_params_defaults(self): + """Verifies default sort parameters.""" + sort_keys, sort_dirs = utils.process_sort_params([], []) + self.assertEqual(['created_at', 'id'], sort_keys) + self.assertEqual(['asc', 'asc'], sort_dirs) + + sort_keys, sort_dirs = utils.process_sort_params(None, None) + self.assertEqual(['created_at', 'id'], sort_keys) + self.assertEqual(['asc', 'asc'], sort_dirs) + + def test_process_sort_params_override_default_keys(self): + """Verifies that the default keys can be overridden.""" + sort_keys, sort_dirs = utils.process_sort_params( + [], [], default_keys=['key1', 'key2', 'key3']) + self.assertEqual(['key1', 'key2', 'key3'], sort_keys) + self.assertEqual(['asc', 'asc', 'asc'], sort_dirs) + + def test_process_sort_params_override_default_dir(self): + """Verifies that the default direction can be overridden.""" + sort_keys, sort_dirs = utils.process_sort_params( + [], [], default_dir='dir1') + self.assertEqual(['created_at', 'id'], sort_keys) + self.assertEqual(['dir1', 'dir1'], sort_dirs) + + def test_process_sort_params_override_default_key_and_dir(self): + """Verifies that the default key and dir can be overridden.""" + sort_keys, sort_dirs = utils.process_sort_params( + [], [], default_keys=['key1', 'key2', 'key3'], + default_dir='dir1') + self.assertEqual(['key1', 'key2', 'key3'], sort_keys) + self.assertEqual(['dir1', 'dir1', 'dir1'], sort_dirs) + + sort_keys, sort_dirs = utils.process_sort_params( + [], [], default_keys=[], default_dir='dir1') + self.assertEqual([], sort_keys) + self.assertEqual([], sort_dirs) + + def test_process_sort_params_non_default(self): + """Verifies that non-default keys are added correctly.""" + sort_keys, sort_dirs = utils.process_sort_params( + ['key1', 'key2'], ['asc', 'desc']) + self.assertEqual(['key1', 'key2', 'created_at', 'id'], sort_keys) + # First sort_dir in list is used when adding the default keys + self.assertEqual(['asc', 'desc', 'asc', 'asc'], sort_dirs) + + def test_process_sort_params_default(self): + """Verifies that default keys are added correctly.""" + sort_keys, sort_dirs = utils.process_sort_params( + ['id', 'key2'], ['asc', 'desc']) + self.assertEqual(['id', 'key2', 'created_at'], sort_keys) + self.assertEqual(['asc', 'desc', 'asc'], sort_dirs) + + # Include default key value, rely on default direction + sort_keys, sort_dirs = utils.process_sort_params( + ['id', 'key2'], []) + self.assertEqual(['id', 'key2', 'created_at'], sort_keys) + self.assertEqual(['asc', 'asc', 'asc'], sort_dirs) + + def test_process_sort_params_default_dir(self): + """Verifies that the default dir is applied to all keys.""" + # Direction is set, ignore default dir + sort_keys, sort_dirs = utils.process_sort_params( + ['id', 'key2'], ['desc'], default_dir='dir') + self.assertEqual(['id', 'key2', 'created_at'], sort_keys) + self.assertEqual(['desc', 'desc', 'desc'], sort_dirs) + + # But should be used if no direction is set + sort_keys, sort_dirs = utils.process_sort_params( + ['id', 'key2'], [], default_dir='dir') + self.assertEqual(['id', 'key2', 'created_at'], sort_keys) + self.assertEqual(['dir', 'dir', 'dir'], sort_dirs) + + def test_process_sort_params_unequal_length(self): + """Verifies that a sort direction list is applied correctly.""" + sort_keys, sort_dirs = utils.process_sort_params( + ['id', 'key2', 'key3'], ['desc']) + self.assertEqual(['id', 'key2', 'key3', 'created_at'], sort_keys) + self.assertEqual(['desc', 'desc', 'desc', 'desc'], sort_dirs) + + # Default direction is the first key in the list + sort_keys, sort_dirs = utils.process_sort_params( + ['id', 'key2', 'key3'], ['desc', 'asc']) + self.assertEqual(['id', 'key2', 'key3', 'created_at'], sort_keys) + self.assertEqual(['desc', 'asc', 'desc', 'desc'], sort_dirs) + + sort_keys, sort_dirs = utils.process_sort_params( + ['id', 'key2', 'key3'], ['desc', 'asc', 'asc']) + self.assertEqual(['id', 'key2', 'key3', 'created_at'], sort_keys) + self.assertEqual(['desc', 'asc', 'asc', 'desc'], sort_dirs) + + def test_process_sort_params_extra_dirs_lengths(self): + """InvalidInput raised if more directions are given.""" + self.assertRaises( + exception.InvalidInput, + utils.process_sort_params, + ['key1', 'key2'], + ['asc', 'desc', 'desc']) + + def test_process_sort_params_invalid_sort_dir(self): + """InvalidInput raised if invalid directions are given.""" + for dirs in [['foo'], ['asc', 'foo'], ['asc', 'desc', 'foo']]: + self.assertRaises( + exception.InvalidInput, + utils.process_sort_params, ['key'], dirs) diff --git a/nova/tests/unit/objects/test_flavor.py b/nova/tests/unit/objects/test_flavor.py index 92adbc6df904..3f61a8fafbc0 100644 --- a/nova/tests/unit/objects/test_flavor.py +++ b/nova/tests/unit/objects/test_flavor.py @@ -19,8 +19,8 @@ from oslo_db import exception as db_exc from oslo_utils import uuidutils from nova import context as nova_context +from nova.db.api import api as api_db_api from nova.db.api import models as api_models -from nova.db.main import api as db_api from nova import exception from nova import objects from nova.objects import fields @@ -89,7 +89,7 @@ class _TestFlavor(object): mock_get.assert_called_once_with(self.context, 'm1.foo') @staticmethod - @db_api.api_context_manager.writer + @api_db_api.context_manager.writer def _create_api_flavor(context, altid=None): fake_db_flavor = dict(fake_flavor) del fake_db_flavor['extra_specs'] diff --git a/nova/tests/unit/test_conf.py b/nova/tests/unit/test_conf.py index 722f32d4c186..95a7c45114cc 100644 --- a/nova/tests/unit/test_conf.py +++ b/nova/tests/unit/test_conf.py @@ -87,19 +87,24 @@ class TestParseArgs(test.NoDBTestCase): def setUp(self): super(TestParseArgs, self).setUp() m = mock.patch('nova.db.main.api.configure') - self.nova_db_config_mock = m.start() - self.addCleanup(self.nova_db_config_mock.stop) + self.main_db_config_mock = m.start() + self.addCleanup(self.main_db_config_mock.stop) + m = mock.patch('nova.db.api.api.configure') + self.api_db_config_mock = m.start() + self.addCleanup(self.api_db_config_mock.stop) @mock.patch.object(config.log, 'register_options') def test_parse_args_glance_debug_false(self, register_options): self.flags(debug=False, group='glance') config.parse_args([], configure_db=False, init_rpc=False) self.assertIn('glanceclient=WARN', config.CONF.default_log_levels) - self.nova_db_config_mock.assert_not_called() + self.main_db_config_mock.assert_not_called() + self.api_db_config_mock.assert_not_called() @mock.patch.object(config.log, 'register_options') def test_parse_args_glance_debug_true(self, register_options): self.flags(debug=True, group='glance') config.parse_args([], configure_db=True, init_rpc=False) self.assertIn('glanceclient=DEBUG', config.CONF.default_log_levels) - self.nova_db_config_mock.assert_called_once_with(config.CONF) + self.main_db_config_mock.assert_called_once_with(config.CONF) + self.api_db_config_mock.assert_called_once_with(config.CONF) diff --git a/nova/tests/unit/test_fixtures.py b/nova/tests/unit/test_fixtures.py index 48913d05ad69..212b8107394d 100644 --- a/nova/tests/unit/test_fixtures.py +++ b/nova/tests/unit/test_fixtures.py @@ -34,7 +34,8 @@ import testtools from nova.compute import rpcapi as compute_rpcapi from nova import conductor from nova import context -from nova.db.main import api as session +from nova.db.api import api as api_db_api +from nova.db.main import api as main_db_api from nova import exception from nova.network import neutron as neutron_api from nova import objects @@ -121,7 +122,7 @@ class TestDatabaseFixture(testtools.TestCase): # because this sets up reasonable db connection strings self.useFixture(fixtures.ConfFixture()) self.useFixture(fixtures.Database()) - engine = session.get_engine() + engine = main_db_api.get_engine() conn = engine.connect() result = conn.execute("select * from instance_types") rows = result.fetchall() @@ -152,7 +153,7 @@ class TestDatabaseFixture(testtools.TestCase): # This sets up reasonable db connection strings self.useFixture(fixtures.ConfFixture()) self.useFixture(fixtures.Database(database='api')) - engine = session.get_api_engine() + engine = api_db_api.get_engine() conn = engine.connect() result = conn.execute("select * from cell_mappings") rows = result.fetchall() @@ -186,7 +187,7 @@ class TestDatabaseFixture(testtools.TestCase): fix.cleanup() # ensure the db contains nothing - engine = session.get_engine() + engine = main_db_api.get_engine() conn = engine.connect() schema = "".join(line for line in conn.connection.iterdump()) self.assertEqual(schema, "BEGIN TRANSACTION;COMMIT;") @@ -198,7 +199,7 @@ class TestDatabaseFixture(testtools.TestCase): self.useFixture(fix) # No data inserted by migrations so we need to add a row - engine = session.get_api_engine() + engine = api_db_api.get_engine() conn = engine.connect() uuid = uuidutils.generate_uuid() conn.execute("insert into cell_mappings (uuid, name) VALUES " @@ -211,7 +212,7 @@ class TestDatabaseFixture(testtools.TestCase): fix.cleanup() # Ensure the db contains nothing - engine = session.get_api_engine() + engine = api_db_api.get_engine() conn = engine.connect() schema = "".join(line for line in conn.connection.iterdump()) self.assertEqual("BEGIN TRANSACTION;COMMIT;", schema) @@ -224,7 +225,7 @@ class TestDefaultFlavorsFixture(testtools.TestCase): self.useFixture(fixtures.Database()) self.useFixture(fixtures.Database(database='api')) - engine = session.get_api_engine() + engine = api_db_api.get_engine() conn = engine.connect() result = conn.execute("select * from flavors") rows = result.fetchall()