db: Post reshuffle cleanup

Introduce a new 'nova.db.api.api' module to hold API database-specific
helpers, plus a generic 'nova.db.utils' module to hold code suitable for
both main and API databases. This highlights a level of complexity
around connection management that is present for the main database but
not for the API database. This is because we need to handle the
complexity of cells for the former but not the latter.

Change-Id: Ia5304c552ce552ae3c5223a2bfb3a9cd543ec57c
Signed-off-by: Stephen Finucane <stephenfin@redhat.com>
This commit is contained in:
Stephen Finucane 2021-04-01 12:14:33 +01:00
parent bf8b5fc7d0
commit 43b253cd60
32 changed files with 625 additions and 484 deletions

View File

@ -32,7 +32,8 @@ from nova.cmd import common as cmd_common
import nova.conf import nova.conf
from nova import config from nova import config
from nova import context as nova_context from nova import context as nova_context
from nova.db.main import api as db_session from nova.db.api import api as api_db_api
from nova.db.main import api as main_db_api
from nova import exception from nova import exception
from nova.i18n import _ from nova.i18n import _
from nova.objects import cell_mapping as cell_mapping_obj from nova.objects import cell_mapping as cell_mapping_obj
@ -86,7 +87,7 @@ class UpgradeCommands(upgradecheck.UpgradeCommands):
# table, or by only counting compute nodes with a service version of at # table, or by only counting compute nodes with a service version of at
# least 15 which was the highest service version when Newton was # least 15 which was the highest service version when Newton was
# released. # released.
meta = sa.MetaData(bind=db_session.get_engine(context=context)) meta = sa.MetaData(bind=main_db_api.get_engine(context=context))
compute_nodes = sa.Table('compute_nodes', meta, autoload=True) compute_nodes = sa.Table('compute_nodes', meta, autoload=True)
return sa.select([sqlfunc.count()]).select_from(compute_nodes).where( return sa.select([sqlfunc.count()]).select_from(compute_nodes).where(
compute_nodes.c.deleted == 0).scalar() compute_nodes.c.deleted == 0).scalar()
@ -103,7 +104,7 @@ class UpgradeCommands(upgradecheck.UpgradeCommands):
for compute nodes if there are no host mappings on a fresh install. for compute nodes if there are no host mappings on a fresh install.
""" """
meta = sa.MetaData() meta = sa.MetaData()
meta.bind = db_session.get_api_engine() meta.bind = api_db_api.get_engine()
cell_mappings = self._get_cell_mappings() cell_mappings = self._get_cell_mappings()
count = len(cell_mappings) count = len(cell_mappings)

View File

@ -53,7 +53,8 @@ from nova import conductor
import nova.conf import nova.conf
from nova import context as nova_context from nova import context as nova_context
from nova import crypto from nova import crypto
from nova.db.main import api as db_api from nova.db.api import api as api_db_api
from nova.db.main import api as main_db_api
from nova import exception from nova import exception
from nova import exception_wrapper from nova import exception_wrapper
from nova.i18n import _ from nova.i18n import _
@ -1081,7 +1082,7 @@ class API:
network_metadata) network_metadata)
@staticmethod @staticmethod
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _create_reqspec_buildreq_instmapping(context, rs, br, im): def _create_reqspec_buildreq_instmapping(context, rs, br, im):
"""Create the request spec, build request, and instance mapping in a """Create the request spec, build request, and instance mapping in a
single database transaction. single database transaction.
@ -5082,7 +5083,7 @@ class API:
def get_instance_metadata(self, context, instance): def get_instance_metadata(self, context, instance):
"""Get all metadata associated with an instance.""" """Get all metadata associated with an instance."""
return db_api.instance_metadata_get(context, instance.uuid) return main_db_api.instance_metadata_get(context, instance.uuid)
@check_instance_lock @check_instance_lock
@check_instance_state(vm_state=[vm_states.ACTIVE, vm_states.PAUSED, @check_instance_state(vm_state=[vm_states.ACTIVE, vm_states.PAUSED,
@ -5962,7 +5963,7 @@ class HostAPI:
"""Return the task logs within a given range, optionally """Return the task logs within a given range, optionally
filtering by host and/or state. filtering by host and/or state.
""" """
return db_api.task_log_get_all( return main_db_api.task_log_get_all(
context, task_name, period_beginning, period_ending, host=host, context, task_name, period_beginning, period_ending, host=host,
state=state) state=state)
@ -6055,7 +6056,7 @@ class HostAPI:
if cell.uuid == objects.CellMapping.CELL0_UUID: if cell.uuid == objects.CellMapping.CELL0_UUID:
continue continue
with nova_context.target_cell(context, cell) as cctxt: with nova_context.target_cell(context, cell) as cctxt:
cell_stats.append(db_api.compute_node_statistics(cctxt)) cell_stats.append(main_db_api.compute_node_statistics(cctxt))
if cell_stats: if cell_stats:
keys = cell_stats[0].keys() keys = cell_stats[0].keys()

View File

@ -22,7 +22,8 @@ from oslo_policy import opts as policy_opts
from oslo_utils import importutils from oslo_utils import importutils
import nova.conf import nova.conf
from nova.db.main import api as db_api from nova.db.api import api as api_db_api
from nova.db.main import api as main_db_api
from nova import middleware from nova import middleware
from nova import policy from nova import policy
from nova import rpc from nova import rpc
@ -100,4 +101,5 @@ def parse_args(argv, default_config_files=None, configure_db=True,
rpc.init(CONF) rpc.init(CONF)
if configure_db: if configure_db:
db_api.configure(CONF) main_db_api.configure(CONF)
api_db_api.configure(CONF)

50
nova/db/api/api.py Normal file
View File

@ -0,0 +1,50 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from oslo_db.sqlalchemy import enginefacade
from oslo_utils import importutils
import sqlalchemy as sa
import nova.conf
profiler_sqlalchemy = importutils.try_import('osprofiler.sqlalchemy')
CONF = nova.conf.CONF
context_manager = enginefacade.transaction_context()
# NOTE(stephenfin): We don't need equivalents of the 'get_context_manager' or
# 'create_context_manager' APIs found in 'nova.db.main.api' since we don't need
# to be cell-aware here
def _get_db_conf(conf_group, connection=None):
kw = dict(conf_group.items())
if connection is not None:
kw['connection'] = connection
return kw
def configure(conf):
context_manager.configure(**_get_db_conf(conf.api_database))
if (
profiler_sqlalchemy and
CONF.profiler.enabled and
CONF.profiler.trace_sqlalchemy
):
context_manager.append_on_engine_create(
lambda eng: profiler_sqlalchemy.add_tracing(sa, eng, "db"))
def get_engine():
return context_manager.writer.get_engine()

View File

@ -22,7 +22,6 @@ import copy
import datetime import datetime
import functools import functools
import inspect import inspect
import sys
import traceback import traceback
from oslo_db import api as oslo_db_api from oslo_db import api as oslo_db_api
@ -49,6 +48,8 @@ from nova.compute import vm_states
import nova.conf import nova.conf
import nova.context import nova.context
from nova.db.main import models from nova.db.main import models
from nova.db import utils as db_utils
from nova.db.utils import require_context
from nova import exception from nova import exception
from nova.i18n import _ from nova.i18n import _
from nova import safe_utils from nova import safe_utils
@ -60,8 +61,7 @@ LOG = logging.getLogger(__name__)
DISABLE_DB_ACCESS = False DISABLE_DB_ACCESS = False
main_context_manager = enginefacade.transaction_context() context_manager = enginefacade.transaction_context()
api_context_manager = enginefacade.transaction_context()
def _get_db_conf(conf_group, connection=None): def _get_db_conf(conf_group, connection=None):
@ -89,15 +89,14 @@ def _joinedload_all(column):
def configure(conf): def configure(conf):
main_context_manager.configure(**_get_db_conf(conf.database)) context_manager.configure(**_get_db_conf(conf.database))
api_context_manager.configure(**_get_db_conf(conf.api_database))
if profiler_sqlalchemy and CONF.profiler.enabled \ if (
and CONF.profiler.trace_sqlalchemy: profiler_sqlalchemy and
CONF.profiler.enabled and
main_context_manager.append_on_engine_create( CONF.profiler.trace_sqlalchemy
lambda eng: profiler_sqlalchemy.add_tracing(sa, eng, "db")) ):
api_context_manager.append_on_engine_create( context_manager.append_on_engine_create(
lambda eng: profiler_sqlalchemy.add_tracing(sa, eng, "db")) lambda eng: profiler_sqlalchemy.add_tracing(sa, eng, "db"))
@ -116,7 +115,7 @@ def get_context_manager(context):
:param context: The request context that can contain a context manager :param context: The request context that can contain a context manager
""" """
return _context_manager_from_context(context) or main_context_manager return _context_manager_from_context(context) or context_manager
def get_engine(use_slave=False, context=None): def get_engine(use_slave=False, context=None):
@ -131,40 +130,11 @@ def get_engine(use_slave=False, context=None):
return ctxt_mgr.writer.get_engine() return ctxt_mgr.writer.get_engine()
def get_api_engine():
return api_context_manager.writer.get_engine()
_SHADOW_TABLE_PREFIX = 'shadow_' _SHADOW_TABLE_PREFIX = 'shadow_'
_DEFAULT_QUOTA_NAME = 'default' _DEFAULT_QUOTA_NAME = 'default'
PER_PROJECT_QUOTAS = ['fixed_ips', 'floating_ips', 'networks'] PER_PROJECT_QUOTAS = ['fixed_ips', 'floating_ips', 'networks']
# NOTE(stephenfin): This is required and used by oslo.db
def get_backend():
"""The backend is this module itself."""
return sys.modules[__name__]
def require_context(f):
"""Decorator to require *any* user or admin context.
This does no authorization for user or project access matching, see
:py:func:`nova.context.authorize_project_context` and
:py:func:`nova.context.authorize_user_context`.
The first argument to the wrapped function must be the context.
"""
@functools.wraps(f)
def wrapper(*args, **kwargs):
nova.context.require_context(args[0])
return f(*args, **kwargs)
wrapper.__signature__ = inspect.signature(f)
return wrapper
def select_db_reader_mode(f): def select_db_reader_mode(f):
"""Decorator to select synchronous or asynchronous reader mode. """Decorator to select synchronous or asynchronous reader mode.
@ -1662,9 +1632,8 @@ def instance_get_all_by_filters_sort(context, filters, limit=None, marker=None,
if limit == 0: if limit == 0:
return [] return []
sort_keys, sort_dirs = process_sort_params(sort_keys, sort_keys, sort_dirs = db_utils.process_sort_params(
sort_dirs, sort_keys, sort_dirs, default_dir='desc')
default_dir='desc')
if columns_to_join is None: if columns_to_join is None:
columns_to_join_new = ['info_cache', 'security_groups'] columns_to_join_new = ['info_cache', 'security_groups']
@ -2043,75 +2012,6 @@ def _exact_instance_filter(query, filters, legal_keys):
return query return query
def process_sort_params(sort_keys, sort_dirs,
default_keys=['created_at', 'id'],
default_dir='asc'):
"""Process the sort parameters to include default keys.
Creates a list of sort keys and a list of sort directions. Adds the default
keys to the end of the list if they are not already included.
When adding the default keys to the sort keys list, the associated
direction is:
1) The first element in the 'sort_dirs' list (if specified), else
2) 'default_dir' value (Note that 'asc' is the default value since this is
the default in sqlalchemy.utils.paginate_query)
:param sort_keys: List of sort keys to include in the processed list
:param sort_dirs: List of sort directions to include in the processed list
:param default_keys: List of sort keys that need to be included in the
processed list, they are added at the end of the list
if not already specified.
:param default_dir: Sort direction associated with each of the default
keys that are not supplied, used when they are added
to the processed list
:returns: list of sort keys, list of sort directions
:raise exception.InvalidInput: If more sort directions than sort keys
are specified or if an invalid sort
direction is specified
"""
# Determine direction to use for when adding default keys
if sort_dirs and len(sort_dirs) != 0:
default_dir_value = sort_dirs[0]
else:
default_dir_value = default_dir
# Create list of keys (do not modify the input list)
if sort_keys:
result_keys = list(sort_keys)
else:
result_keys = []
# If a list of directions is not provided, use the default sort direction
# for all provided keys
if sort_dirs:
result_dirs = []
# Verify sort direction
for sort_dir in sort_dirs:
if sort_dir not in ('asc', 'desc'):
msg = _("Unknown sort direction, must be 'desc' or 'asc'")
raise exception.InvalidInput(reason=msg)
result_dirs.append(sort_dir)
else:
result_dirs = [default_dir_value for _sort_key in result_keys]
# Ensure that the key and direction length match
while len(result_dirs) < len(result_keys):
result_dirs.append(default_dir_value)
# Unless more direction are specified, which is an error
if len(result_dirs) > len(result_keys):
msg = _("Sort direction size exceeds sort key size")
raise exception.InvalidInput(reason=msg)
# Ensure defaults are included
for key in default_keys:
if key not in result_keys:
result_keys.append(key)
result_dirs.append(default_dir_value)
return result_keys, result_dirs
@require_context @require_context
@pick_context_manager_reader_allow_async @pick_context_manager_reader_allow_async
def instance_get_active_by_window_joined(context, begin, end=None, def instance_get_active_by_window_joined(context, begin, end=None,
@ -3507,8 +3407,8 @@ def migration_get_all_by_filters(context, filters,
raise exception.MarkerNotFound(marker=marker) raise exception.MarkerNotFound(marker=marker)
if limit or marker or sort_keys or sort_dirs: if limit or marker or sort_keys or sort_dirs:
# Default sort by desc(['created_at', 'id']) # Default sort by desc(['created_at', 'id'])
sort_keys, sort_dirs = process_sort_params(sort_keys, sort_dirs, sort_keys, sort_dirs = db_utils.process_sort_params(
default_dir='desc') sort_keys, sort_dirs, default_dir='desc')
return sqlalchemyutils.paginate_query(query, return sqlalchemyutils.paginate_query(query,
models.Migration, models.Migration,
limit=limit, limit=limit,

View File

@ -22,7 +22,8 @@ from migrate.versioning.repository import Repository
from oslo_log import log as logging from oslo_log import log as logging
import sqlalchemy import sqlalchemy
from nova.db.main import api as db_session from nova.db.api import api as api_db_api
from nova.db.main import api as main_db_api
from nova import exception from nova import exception
from nova.i18n import _ from nova.i18n import _
@ -36,10 +37,10 @@ LOG = logging.getLogger(__name__)
def get_engine(database='main', context=None): def get_engine(database='main', context=None):
if database == 'main': if database == 'main':
return db_session.get_engine(context=context) return main_db_api.get_engine(context=context)
if database == 'api': if database == 'api':
return db_session.get_api_engine() return api_db_api.get_engine()
def find_migrate_repo(database='main'): def find_migrate_repo(database='main'):

109
nova/db/utils.py Normal file
View File

@ -0,0 +1,109 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import functools
import inspect
import nova.context
from nova import exception
from nova.i18n import _
def require_context(f):
"""Decorator to require *any* user or admin context.
This does no authorization for user or project access matching, see
:py:func:`nova.context.authorize_project_context` and
:py:func:`nova.context.authorize_user_context`.
The first argument to the wrapped function must be the context.
"""
@functools.wraps(f)
def wrapper(*args, **kwargs):
nova.context.require_context(args[0])
return f(*args, **kwargs)
wrapper.__signature__ = inspect.signature(f)
return wrapper
def process_sort_params(
sort_keys,
sort_dirs,
default_keys=['created_at', 'id'],
default_dir='asc',
):
"""Process the sort parameters to include default keys.
Creates a list of sort keys and a list of sort directions. Adds the default
keys to the end of the list if they are not already included.
When adding the default keys to the sort keys list, the associated
direction is:
1. The first element in the 'sort_dirs' list (if specified), else
2. 'default_dir' value (Note that 'asc' is the default value since this is
the default in sqlalchemy.utils.paginate_query)
:param sort_keys: List of sort keys to include in the processed list
:param sort_dirs: List of sort directions to include in the processed list
:param default_keys: List of sort keys that need to be included in the
processed list, they are added at the end of the list if not already
specified.
:param default_dir: Sort direction associated with each of the default
keys that are not supplied, used when they are added to the processed
list
:returns: list of sort keys, list of sort directions
:raise exception.InvalidInput: If more sort directions than sort keys
are specified or if an invalid sort direction is specified
"""
# Determine direction to use for when adding default keys
if sort_dirs and len(sort_dirs) != 0:
default_dir_value = sort_dirs[0]
else:
default_dir_value = default_dir
# Create list of keys (do not modify the input list)
if sort_keys:
result_keys = list(sort_keys)
else:
result_keys = []
# If a list of directions is not provided, use the default sort direction
# for all provided keys
if sort_dirs:
result_dirs = []
# Verify sort direction
for sort_dir in sort_dirs:
if sort_dir not in ('asc', 'desc'):
msg = _("Unknown sort direction, must be 'desc' or 'asc'")
raise exception.InvalidInput(reason=msg)
result_dirs.append(sort_dir)
else:
result_dirs = [default_dir_value for _sort_key in result_keys]
# Ensure that the key and direction length match
while len(result_dirs) < len(result_keys):
result_dirs.append(default_dir_value)
# Unless more direction are specified, which is an error
if len(result_dirs) > len(result_keys):
msg = _("Sort direction size exceeds sort key size")
raise exception.InvalidInput(reason=msg)
# Ensure defaults are included
for key in default_keys:
if key not in result_keys:
result_keys.append(key)
result_dirs.append(default_dir_value)
return result_keys, result_dirs

View File

@ -19,8 +19,8 @@ from oslo_utils import uuidutils
from sqlalchemy import orm from sqlalchemy import orm
from nova.compute import utils as compute_utils from nova.compute import utils as compute_utils
from nova.db.api import api as api_db_api
from nova.db.api import models as api_models from nova.db.api import models as api_models
from nova.db.main import api as db_api
from nova import exception from nova import exception
from nova.i18n import _ from nova.i18n import _
from nova import objects from nova import objects
@ -32,7 +32,7 @@ LOG = logging.getLogger(__name__)
DEPRECATED_FIELDS = ['deleted', 'deleted_at'] DEPRECATED_FIELDS = ['deleted', 'deleted_at']
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _aggregate_get_from_db(context, aggregate_id): def _aggregate_get_from_db(context, aggregate_id):
query = context.session.query(api_models.Aggregate).\ query = context.session.query(api_models.Aggregate).\
options(orm.joinedload('_hosts')).\ options(orm.joinedload('_hosts')).\
@ -47,7 +47,7 @@ def _aggregate_get_from_db(context, aggregate_id):
return aggregate return aggregate
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _aggregate_get_from_db_by_uuid(context, aggregate_uuid): def _aggregate_get_from_db_by_uuid(context, aggregate_uuid):
query = context.session.query(api_models.Aggregate).\ query = context.session.query(api_models.Aggregate).\
options(orm.joinedload('_hosts')).\ options(orm.joinedload('_hosts')).\
@ -64,7 +64,7 @@ def _aggregate_get_from_db_by_uuid(context, aggregate_uuid):
def _host_add_to_db(context, aggregate_id, host): def _host_add_to_db(context, aggregate_id, host):
try: try:
with db_api.api_context_manager.writer.using(context): with api_db_api.context_manager.writer.using(context):
# Check to see if the aggregate exists # Check to see if the aggregate exists
_aggregate_get_from_db(context, aggregate_id) _aggregate_get_from_db(context, aggregate_id)
@ -79,7 +79,7 @@ def _host_add_to_db(context, aggregate_id, host):
def _host_delete_from_db(context, aggregate_id, host): def _host_delete_from_db(context, aggregate_id, host):
count = 0 count = 0
with db_api.api_context_manager.writer.using(context): with api_db_api.context_manager.writer.using(context):
# Check to see if the aggregate exists # Check to see if the aggregate exists
_aggregate_get_from_db(context, aggregate_id) _aggregate_get_from_db(context, aggregate_id)
@ -98,7 +98,7 @@ def _metadata_add_to_db(context, aggregate_id, metadata, max_retries=10,
all_keys = metadata.keys() all_keys = metadata.keys()
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
with db_api.api_context_manager.writer.using(context): with api_db_api.context_manager.writer.using(context):
query = context.session.query(api_models.AggregateMetadata).\ query = context.session.query(api_models.AggregateMetadata).\
filter_by(aggregate_id=aggregate_id) filter_by(aggregate_id=aggregate_id)
@ -142,7 +142,7 @@ def _metadata_add_to_db(context, aggregate_id, metadata, max_retries=10,
LOG.warning(msg) LOG.warning(msg)
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _metadata_delete_from_db(context, aggregate_id, key): def _metadata_delete_from_db(context, aggregate_id, key):
# Check to see if the aggregate exists # Check to see if the aggregate exists
_aggregate_get_from_db(context, aggregate_id) _aggregate_get_from_db(context, aggregate_id)
@ -157,7 +157,7 @@ def _metadata_delete_from_db(context, aggregate_id, key):
aggregate_id=aggregate_id, metadata_key=key) aggregate_id=aggregate_id, metadata_key=key)
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _aggregate_create_in_db(context, values, metadata=None): def _aggregate_create_in_db(context, values, metadata=None):
query = context.session.query(api_models.Aggregate) query = context.session.query(api_models.Aggregate)
query = query.filter(api_models.Aggregate.name == values['name']) query = query.filter(api_models.Aggregate.name == values['name'])
@ -181,7 +181,7 @@ def _aggregate_create_in_db(context, values, metadata=None):
return aggregate return aggregate
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _aggregate_delete_from_db(context, aggregate_id): def _aggregate_delete_from_db(context, aggregate_id):
# Delete Metadata first # Delete Metadata first
context.session.query(api_models.AggregateMetadata).\ context.session.query(api_models.AggregateMetadata).\
@ -196,7 +196,7 @@ def _aggregate_delete_from_db(context, aggregate_id):
raise exception.AggregateNotFound(aggregate_id=aggregate_id) raise exception.AggregateNotFound(aggregate_id=aggregate_id)
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _aggregate_update_to_db(context, aggregate_id, values): def _aggregate_update_to_db(context, aggregate_id, values):
aggregate = _aggregate_get_from_db(context, aggregate_id) aggregate = _aggregate_get_from_db(context, aggregate_id)
@ -411,7 +411,7 @@ class Aggregate(base.NovaPersistentObject, base.NovaObject):
return self.metadata.get('availability_zone', None) return self.metadata.get('availability_zone', None)
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_all_from_db(context): def _get_all_from_db(context):
query = context.session.query(api_models.Aggregate).\ query = context.session.query(api_models.Aggregate).\
options(orm.joinedload('_hosts')).\ options(orm.joinedload('_hosts')).\
@ -420,7 +420,7 @@ def _get_all_from_db(context):
return query.all() return query.all()
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_by_host_from_db(context, host, key=None): def _get_by_host_from_db(context, host, key=None):
query = context.session.query(api_models.Aggregate).\ query = context.session.query(api_models.Aggregate).\
options(orm.joinedload('_hosts')).\ options(orm.joinedload('_hosts')).\
@ -435,7 +435,7 @@ def _get_by_host_from_db(context, host, key=None):
return query.all() return query.all()
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_by_metadata_from_db(context, key=None, value=None): def _get_by_metadata_from_db(context, key=None, value=None):
assert(key is not None or value is not None) assert(key is not None or value is not None)
query = context.session.query(api_models.Aggregate) query = context.session.query(api_models.Aggregate)
@ -450,7 +450,7 @@ def _get_by_metadata_from_db(context, key=None, value=None):
return query.all() return query.all()
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_non_matching_by_metadata_keys_from_db(context, ignored_keys, def _get_non_matching_by_metadata_keys_from_db(context, ignored_keys,
key_prefix, value): key_prefix, value):
"""Filter aggregates based on non matching metadata. """Filter aggregates based on non matching metadata.

View File

@ -18,8 +18,9 @@ from oslo_serialization import jsonutils
from oslo_utils import versionutils from oslo_utils import versionutils
from oslo_versionedobjects import exception as ovoo_exc from oslo_versionedobjects import exception as ovoo_exc
from nova.db.api import api as api_db_api
from nova.db.api import models as api_models from nova.db.api import models as api_models
from nova.db.main import api as db from nova.db import utils as db_utils
from nova import exception from nova import exception
from nova import objects from nova import objects
from nova.objects import base from nova.objects import base
@ -163,7 +164,7 @@ class BuildRequest(base.NovaObject):
return req return req
@staticmethod @staticmethod
@db.api_context_manager.reader @api_db_api.context_manager.reader
def _get_by_instance_uuid_from_db(context, instance_uuid): def _get_by_instance_uuid_from_db(context, instance_uuid):
db_req = context.session.query(api_models.BuildRequest).filter_by( db_req = context.session.query(api_models.BuildRequest).filter_by(
instance_uuid=instance_uuid).first() instance_uuid=instance_uuid).first()
@ -177,7 +178,7 @@ class BuildRequest(base.NovaObject):
return cls._from_db_object(context, cls(), db_req) return cls._from_db_object(context, cls(), db_req)
@staticmethod @staticmethod
@db.api_context_manager.writer @api_db_api.context_manager.writer
def _create_in_db(context, updates): def _create_in_db(context, updates):
db_req = api_models.BuildRequest() db_req = api_models.BuildRequest()
db_req.update(updates) db_req.update(updates)
@ -206,7 +207,7 @@ class BuildRequest(base.NovaObject):
self._from_db_object(self._context, self, db_req) self._from_db_object(self._context, self, db_req)
@staticmethod @staticmethod
@db.api_context_manager.writer @api_db_api.context_manager.writer
def _destroy_in_db(context, instance_uuid): def _destroy_in_db(context, instance_uuid):
result = context.session.query(api_models.BuildRequest).filter_by( result = context.session.query(api_models.BuildRequest).filter_by(
instance_uuid=instance_uuid).delete() instance_uuid=instance_uuid).delete()
@ -217,7 +218,7 @@ class BuildRequest(base.NovaObject):
def destroy(self): def destroy(self):
self._destroy_in_db(self._context, self.instance_uuid) self._destroy_in_db(self._context, self.instance_uuid)
@db.api_context_manager.writer @api_db_api.context_manager.writer
def _save_in_db(self, context, req_id, updates): def _save_in_db(self, context, req_id, updates):
db_req = context.session.query( db_req = context.session.query(
api_models.BuildRequest).filter_by(id=req_id).first() api_models.BuildRequest).filter_by(id=req_id).first()
@ -262,7 +263,7 @@ class BuildRequestList(base.ObjectListBase, base.NovaObject):
} }
@staticmethod @staticmethod
@db.api_context_manager.reader @api_db_api.context_manager.reader
def _get_all_from_db(context): def _get_all_from_db(context):
query = context.session.query(api_models.BuildRequest) query = context.session.query(api_models.BuildRequest)
@ -396,8 +397,8 @@ class BuildRequestList(base.ObjectListBase, base.NovaObject):
# exists. So it can be ignored. # exists. So it can be ignored.
# 'deleted' and 'cleaned' are handled above. # 'deleted' and 'cleaned' are handled above.
sort_keys, sort_dirs = db.process_sort_params(sort_keys, sort_dirs, sort_keys, sort_dirs = db_utils.process_sort_params(
default_dir='desc') sort_keys, sort_dirs, default_dir='desc')
# For other filters that don't match this, we will do regexp matching # For other filters that don't match this, we will do regexp matching
# Taken from db/sqlalchemy/api.py # Taken from db/sqlalchemy/api.py

View File

@ -18,8 +18,8 @@ from sqlalchemy import sql
from sqlalchemy.sql import expression from sqlalchemy.sql import expression
import nova.conf import nova.conf
from nova.db.api import models as api_models from nova.db.api import api as api_db_api
from nova.db.main import api as db_api from nova.db.api import models as api_db_models
from nova import exception from nova import exception
from nova.objects import base from nova.objects import base
from nova.objects import fields from nova.objects import fields
@ -168,11 +168,11 @@ class CellMapping(base.NovaTimestampObject, base.NovaObject):
return cell_mapping return cell_mapping
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_by_uuid_from_db(context, uuid): def _get_by_uuid_from_db(context, uuid):
db_mapping = context.session.query(api_models.CellMapping).filter_by( db_mapping = context.session\
uuid=uuid).first() .query(api_db_models.CellMapping).filter_by(uuid=uuid).first()
if not db_mapping: if not db_mapping:
raise exception.CellMappingNotFound(uuid=uuid) raise exception.CellMappingNotFound(uuid=uuid)
@ -185,10 +185,10 @@ class CellMapping(base.NovaTimestampObject, base.NovaObject):
return cls._from_db_object(context, cls(), db_mapping) return cls._from_db_object(context, cls(), db_mapping)
@staticmethod @staticmethod
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _create_in_db(context, updates): def _create_in_db(context, updates):
db_mapping = api_models.CellMapping() db_mapping = api_db_models.CellMapping()
db_mapping.update(updates) db_mapping.update(updates)
db_mapping.save(context.session) db_mapping.save(context.session)
return db_mapping return db_mapping
@ -199,11 +199,11 @@ class CellMapping(base.NovaTimestampObject, base.NovaObject):
self._from_db_object(self._context, self, db_mapping) self._from_db_object(self._context, self, db_mapping)
@staticmethod @staticmethod
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _save_in_db(context, uuid, updates): def _save_in_db(context, uuid, updates):
db_mapping = context.session.query( db_mapping = context.session.query(
api_models.CellMapping).filter_by(uuid=uuid).first() api_db_models.CellMapping).filter_by(uuid=uuid).first()
if not db_mapping: if not db_mapping:
raise exception.CellMappingNotFound(uuid=uuid) raise exception.CellMappingNotFound(uuid=uuid)
@ -219,10 +219,10 @@ class CellMapping(base.NovaTimestampObject, base.NovaObject):
self.obj_reset_changes() self.obj_reset_changes()
@staticmethod @staticmethod
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _destroy_in_db(context, uuid): def _destroy_in_db(context, uuid):
result = context.session.query(api_models.CellMapping).filter_by( result = context.session.query(api_db_models.CellMapping).filter_by(
uuid=uuid).delete() uuid=uuid).delete()
if not result: if not result:
raise exception.CellMappingNotFound(uuid=uuid) raise exception.CellMappingNotFound(uuid=uuid)
@ -246,10 +246,10 @@ class CellMappingList(base.ObjectListBase, base.NovaObject):
} }
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_all_from_db(context): def _get_all_from_db(context):
return context.session.query(api_models.CellMapping).order_by( return context.session.query(api_db_models.CellMapping).order_by(
expression.asc(api_models.CellMapping.id)).all() expression.asc(api_db_models.CellMapping.id)).all()
@base.remotable_classmethod @base.remotable_classmethod
def get_all(cls, context): def get_all(cls, context):
@ -257,16 +257,16 @@ class CellMappingList(base.ObjectListBase, base.NovaObject):
return base.obj_make_list(context, cls(), CellMapping, db_mappings) return base.obj_make_list(context, cls(), CellMapping, db_mappings)
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_by_disabled_from_db(context, disabled): def _get_by_disabled_from_db(context, disabled):
if disabled: if disabled:
return context.session.query(api_models.CellMapping)\ return context.session.query(api_db_models.CellMapping)\
.filter_by(disabled=sql.true())\ .filter_by(disabled=sql.true())\
.order_by(expression.asc(api_models.CellMapping.id)).all() .order_by(expression.asc(api_db_models.CellMapping.id)).all()
else: else:
return context.session.query(api_models.CellMapping)\ return context.session.query(api_db_models.CellMapping)\
.filter_by(disabled=sql.false())\ .filter_by(disabled=sql.false())\
.order_by(expression.asc(api_models.CellMapping.id)).all() .order_by(expression.asc(api_db_models.CellMapping.id)).all()
@base.remotable_classmethod @base.remotable_classmethod
def get_by_disabled(cls, context, disabled): def get_by_disabled(cls, context, disabled):
@ -274,16 +274,16 @@ class CellMappingList(base.ObjectListBase, base.NovaObject):
return base.obj_make_list(context, cls(), CellMapping, db_mappings) return base.obj_make_list(context, cls(), CellMapping, db_mappings)
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_by_project_id_from_db(context, project_id): def _get_by_project_id_from_db(context, project_id):
# SELECT DISTINCT cell_id FROM instance_mappings \ # SELECT DISTINCT cell_id FROM instance_mappings \
# WHERE project_id = $project_id; # WHERE project_id = $project_id;
cell_ids = context.session.query( cell_ids = context.session.query(
api_models.InstanceMapping.cell_id).filter_by( api_db_models.InstanceMapping.cell_id).filter_by(
project_id=project_id).distinct().subquery() project_id=project_id).distinct().subquery()
# SELECT cell_mappings WHERE cell_id IN ($cell_ids); # SELECT cell_mappings WHERE cell_id IN ($cell_ids);
return context.session.query(api_models.CellMapping).filter( return context.session.query(api_db_models.CellMapping).filter(
api_models.CellMapping.id.in_(cell_ids)).all() api_db_models.CellMapping.id.in_(cell_ids)).all()
@classmethod @classmethod
def get_by_project_id(cls, context, project_id): def get_by_project_id(cls, context, project_id):

View File

@ -21,8 +21,9 @@ from sqlalchemy import sql
from sqlalchemy.sql import expression from sqlalchemy.sql import expression
import nova.conf import nova.conf
from nova.db.api import api as api_db_api
from nova.db.api import models as api_models from nova.db.api import models as api_models
from nova.db.main import api as db_api from nova.db import utils as db_utils
from nova import exception from nova import exception
from nova.notifications.objects import base as notification from nova.notifications.objects import base as notification
from nova.notifications.objects import flavor as flavor_notification from nova.notifications.objects import flavor as flavor_notification
@ -51,7 +52,7 @@ def _dict_with_extra_specs(flavor_model):
# decorators with static methods. We pull these out for now and can # decorators with static methods. We pull these out for now and can
# move them back into the actual staticmethods on the object when those # move them back into the actual staticmethods on the object when those
# issues are resolved. # issues are resolved.
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_projects_from_db(context, flavorid): def _get_projects_from_db(context, flavorid):
db_flavor = context.session.query(api_models.Flavors).\ db_flavor = context.session.query(api_models.Flavors).\
filter_by(flavorid=flavorid).\ filter_by(flavorid=flavorid).\
@ -62,7 +63,7 @@ def _get_projects_from_db(context, flavorid):
return [x['project_id'] for x in db_flavor['projects']] return [x['project_id'] for x in db_flavor['projects']]
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _flavor_add_project(context, flavor_id, project_id): def _flavor_add_project(context, flavor_id, project_id):
project = api_models.FlavorProjects() project = api_models.FlavorProjects()
project.update({'flavor_id': flavor_id, project.update({'flavor_id': flavor_id,
@ -74,7 +75,7 @@ def _flavor_add_project(context, flavor_id, project_id):
project_id=project_id) project_id=project_id)
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _flavor_del_project(context, flavor_id, project_id): def _flavor_del_project(context, flavor_id, project_id):
result = context.session.query(api_models.FlavorProjects).\ result = context.session.query(api_models.FlavorProjects).\
filter_by(project_id=project_id).\ filter_by(project_id=project_id).\
@ -85,9 +86,9 @@ def _flavor_del_project(context, flavor_id, project_id):
project_id=project_id) project_id=project_id)
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _flavor_extra_specs_add(context, flavor_id, specs, max_retries=10): def _flavor_extra_specs_add(context, flavor_id, specs, max_retries=10):
writer = db_api.api_context_manager.writer writer = api_db_api.context_manager.writer
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
spec_refs = context.session.query( spec_refs = context.session.query(
@ -122,7 +123,7 @@ def _flavor_extra_specs_add(context, flavor_id, specs, max_retries=10):
id=flavor_id, retries=max_retries) id=flavor_id, retries=max_retries)
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _flavor_extra_specs_del(context, flavor_id, key): def _flavor_extra_specs_del(context, flavor_id, key):
result = context.session.query(api_models.FlavorExtraSpecs).\ result = context.session.query(api_models.FlavorExtraSpecs).\
filter_by(flavor_id=flavor_id).\ filter_by(flavor_id=flavor_id).\
@ -133,7 +134,7 @@ def _flavor_extra_specs_del(context, flavor_id, key):
extra_specs_key=key, flavor_id=flavor_id) extra_specs_key=key, flavor_id=flavor_id)
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _flavor_create(context, values): def _flavor_create(context, values):
specs = values.get('extra_specs') specs = values.get('extra_specs')
db_specs = [] db_specs = []
@ -169,7 +170,7 @@ def _flavor_create(context, values):
return _dict_with_extra_specs(db_flavor) return _dict_with_extra_specs(db_flavor)
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _flavor_destroy(context, flavor_id=None, flavorid=None): def _flavor_destroy(context, flavor_id=None, flavorid=None):
query = context.session.query(api_models.Flavors) query = context.session.query(api_models.Flavors)
@ -268,7 +269,7 @@ class Flavor(base.NovaPersistentObject, base.NovaObject,
return flavor return flavor
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _flavor_get_query_from_db(context): def _flavor_get_query_from_db(context):
query = context.session.query(api_models.Flavors).\ query = context.session.query(api_models.Flavors).\
options(orm.joinedload('extra_specs')) options(orm.joinedload('extra_specs'))
@ -281,7 +282,7 @@ class Flavor(base.NovaPersistentObject, base.NovaObject,
return query return query
@staticmethod @staticmethod
@db_api.require_context @db_utils.require_context
def _flavor_get_from_db(context, id): def _flavor_get_from_db(context, id):
"""Returns a dict describing specific flavor.""" """Returns a dict describing specific flavor."""
result = Flavor._flavor_get_query_from_db(context).\ result = Flavor._flavor_get_query_from_db(context).\
@ -292,7 +293,7 @@ class Flavor(base.NovaPersistentObject, base.NovaObject,
return _dict_with_extra_specs(result) return _dict_with_extra_specs(result)
@staticmethod @staticmethod
@db_api.require_context @db_utils.require_context
def _flavor_get_by_name_from_db(context, name): def _flavor_get_by_name_from_db(context, name):
"""Returns a dict describing specific flavor.""" """Returns a dict describing specific flavor."""
result = Flavor._flavor_get_query_from_db(context).\ result = Flavor._flavor_get_query_from_db(context).\
@ -303,7 +304,7 @@ class Flavor(base.NovaPersistentObject, base.NovaObject,
return _dict_with_extra_specs(result) return _dict_with_extra_specs(result)
@staticmethod @staticmethod
@db_api.require_context @db_utils.require_context
def _flavor_get_by_flavor_id_from_db(context, flavor_id): def _flavor_get_by_flavor_id_from_db(context, flavor_id):
"""Returns a dict describing specific flavor_id.""" """Returns a dict describing specific flavor_id."""
result = Flavor._flavor_get_query_from_db(context).\ result = Flavor._flavor_get_query_from_db(context).\
@ -485,7 +486,7 @@ class Flavor(base.NovaPersistentObject, base.NovaObject,
# NOTE(mriedem): This method is not remotable since we only expect the API # NOTE(mriedem): This method is not remotable since we only expect the API
# to be able to make updates to a flavor. # to be able to make updates to a flavor.
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _save(self, context, values): def _save(self, context, values):
db_flavor = context.session.query(api_models.Flavors).\ db_flavor = context.session.query(api_models.Flavors).\
filter_by(id=self.id).first() filter_by(id=self.id).first()
@ -581,7 +582,7 @@ class Flavor(base.NovaPersistentObject, base.NovaObject,
payload=payload).emit(self._context) payload=payload).emit(self._context)
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _flavor_get_all_from_db(context, inactive, filters, sort_key, sort_dir, def _flavor_get_all_from_db(context, inactive, filters, sort_key, sort_dir,
limit, marker): limit, marker):
"""Returns all flavors. """Returns all flavors.

View File

@ -14,8 +14,8 @@ from oslo_db import exception as db_exc
from sqlalchemy import orm from sqlalchemy import orm
from nova import context from nova import context
from nova.db.api import api as api_db_api
from nova.db.api import models as api_models from nova.db.api import models as api_models
from nova.db.main import api as db_api
from nova import exception from nova import exception
from nova.i18n import _ from nova.i18n import _
from nova.objects import base from nova.objects import base
@ -52,7 +52,7 @@ class HostMapping(base.NovaTimestampObject, base.NovaObject):
} }
def _get_cell_mapping(self): def _get_cell_mapping(self):
with db_api.api_context_manager.reader.using(self._context) as session: with api_db_api.context_manager.reader.using(self._context) as session:
cell_map = (session.query(api_models.CellMapping) cell_map = (session.query(api_models.CellMapping)
.join(api_models.HostMapping) .join(api_models.HostMapping)
.filter(api_models.HostMapping.host == self.host) .filter(api_models.HostMapping.host == self.host)
@ -87,7 +87,7 @@ class HostMapping(base.NovaTimestampObject, base.NovaObject):
return host_mapping return host_mapping
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_by_host_from_db(context, host): def _get_by_host_from_db(context, host):
db_mapping = context.session.query(api_models.HostMapping)\ db_mapping = context.session.query(api_models.HostMapping)\
.options(orm.joinedload('cell_mapping'))\ .options(orm.joinedload('cell_mapping'))\
@ -102,7 +102,7 @@ class HostMapping(base.NovaTimestampObject, base.NovaObject):
return cls._from_db_object(context, cls(), db_mapping) return cls._from_db_object(context, cls(), db_mapping)
@staticmethod @staticmethod
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _create_in_db(context, updates): def _create_in_db(context, updates):
db_mapping = api_models.HostMapping() db_mapping = api_models.HostMapping()
return _apply_updates(context, db_mapping, updates) return _apply_updates(context, db_mapping, updates)
@ -116,7 +116,7 @@ class HostMapping(base.NovaTimestampObject, base.NovaObject):
self._from_db_object(self._context, self, db_mapping) self._from_db_object(self._context, self, db_mapping)
@staticmethod @staticmethod
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _save_in_db(context, obj, updates): def _save_in_db(context, obj, updates):
db_mapping = context.session.query(api_models.HostMapping).filter_by( db_mapping = context.session.query(api_models.HostMapping).filter_by(
id=obj.id).first() id=obj.id).first()
@ -134,7 +134,7 @@ class HostMapping(base.NovaTimestampObject, base.NovaObject):
self.obj_reset_changes() self.obj_reset_changes()
@staticmethod @staticmethod
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _destroy_in_db(context, host): def _destroy_in_db(context, host):
result = context.session.query(api_models.HostMapping).filter_by( result = context.session.query(api_models.HostMapping).filter_by(
host=host).delete() host=host).delete()
@ -157,7 +157,7 @@ class HostMappingList(base.ObjectListBase, base.NovaObject):
} }
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_from_db(context, cell_id=None): def _get_from_db(context, cell_id=None):
query = (context.session.query(api_models.HostMapping) query = (context.session.query(api_models.HostMapping)
.options(orm.joinedload('cell_mapping'))) .options(orm.joinedload('cell_mapping')))

View File

@ -22,8 +22,8 @@ from oslo_utils import versionutils
from sqlalchemy import orm from sqlalchemy import orm
from nova.compute import utils as compute_utils from nova.compute import utils as compute_utils
from nova.db.api import api as api_db_api
from nova.db.api import models as api_models from nova.db.api import models as api_models
from nova.db.main import api as db_api
from nova import exception from nova import exception
from nova import objects from nova import objects
from nova.objects import base from nova.objects import base
@ -213,7 +213,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject,
return instance_group return instance_group
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_from_db_by_uuid(context, uuid): def _get_from_db_by_uuid(context, uuid):
grp = _instance_group_get_query(context, grp = _instance_group_get_query(context,
id_field=api_models.InstanceGroup.uuid, id_field=api_models.InstanceGroup.uuid,
@ -223,7 +223,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject,
return grp return grp
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_from_db_by_id(context, id): def _get_from_db_by_id(context, id):
grp = _instance_group_get_query(context, grp = _instance_group_get_query(context,
id_field=api_models.InstanceGroup.id, id_field=api_models.InstanceGroup.id,
@ -233,7 +233,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject,
return grp return grp
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_from_db_by_name(context, name): def _get_from_db_by_name(context, name):
grp = _instance_group_get_query(context).filter_by(name=name).first() grp = _instance_group_get_query(context).filter_by(name=name).first()
if not grp: if not grp:
@ -241,7 +241,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject,
return grp return grp
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_from_db_by_instance(context, instance_uuid): def _get_from_db_by_instance(context, instance_uuid):
grp_member = context.session.query(api_models.InstanceGroupMember).\ grp_member = context.session.query(api_models.InstanceGroupMember).\
filter_by(instance_uuid=instance_uuid).first() filter_by(instance_uuid=instance_uuid).first()
@ -251,7 +251,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject,
return grp return grp
@staticmethod @staticmethod
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _save_in_db(context, group_uuid, values): def _save_in_db(context, group_uuid, values):
grp = InstanceGroup._get_from_db_by_uuid(context, group_uuid) grp = InstanceGroup._get_from_db_by_uuid(context, group_uuid)
values_copy = copy.copy(values) values_copy = copy.copy(values)
@ -265,7 +265,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject,
return grp return grp
@staticmethod @staticmethod
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _create_in_db(context, values, policies=None, members=None, def _create_in_db(context, values, policies=None, members=None,
policy=None, rules=None): policy=None, rules=None):
try: try:
@ -301,7 +301,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject,
return group return group
@staticmethod @staticmethod
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _destroy_in_db(context, group_uuid): def _destroy_in_db(context, group_uuid):
qry = _instance_group_get_query(context, qry = _instance_group_get_query(context,
id_field=api_models.InstanceGroup.uuid, id_field=api_models.InstanceGroup.uuid,
@ -319,13 +319,13 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject,
qry.delete() qry.delete()
@staticmethod @staticmethod
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _add_members_in_db(context, group_uuid, members): def _add_members_in_db(context, group_uuid, members):
return _instance_group_members_add_by_uuid(context, group_uuid, return _instance_group_members_add_by_uuid(context, group_uuid,
members) members)
@staticmethod @staticmethod
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _remove_members_in_db(context, group_id, instance_uuids): def _remove_members_in_db(context, group_id, instance_uuids):
# There is no public method provided for removing members because the # There is no public method provided for removing members because the
# user-facing API doesn't allow removal of instance group members. We # user-facing API doesn't allow removal of instance group members. We
@ -337,7 +337,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject,
delete(synchronize_session=False) delete(synchronize_session=False)
@staticmethod @staticmethod
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _destroy_members_bulk_in_db(context, instance_uuids): def _destroy_members_bulk_in_db(context, instance_uuids):
return context.session.query(api_models.InstanceGroupMember).filter( return context.session.query(api_models.InstanceGroupMember).filter(
api_models.InstanceGroupMember.instance_uuid.in_(instance_uuids)).\ api_models.InstanceGroupMember.instance_uuid.in_(instance_uuids)).\
@ -537,7 +537,7 @@ class InstanceGroupList(base.ObjectListBase, base.NovaObject):
} }
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_from_db(context, project_id=None): def _get_from_db(context, project_id=None):
query = _instance_group_get_query(context) query = _instance_group_get_query(context)
if project_id is not None: if project_id is not None:
@ -545,7 +545,7 @@ class InstanceGroupList(base.ObjectListBase, base.NovaObject):
return query.all() return query.all()
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_counts_from_db(context, project_id, user_id=None): def _get_counts_from_db(context, project_id, user_id=None):
query = context.session.query(api_models.InstanceGroup.id).\ query = context.session.query(api_models.InstanceGroup.id).\
filter_by(project_id=project_id) filter_by(project_id=project_id)

View File

@ -20,8 +20,8 @@ from sqlalchemy import sql
from sqlalchemy.sql import func from sqlalchemy.sql import func
from nova import context as nova_context from nova import context as nova_context
from nova.db.api import api as api_db_api
from nova.db.api import models as api_models from nova.db.api import models as api_models
from nova.db.main import api as db_api
from nova import exception from nova import exception
from nova.i18n import _ from nova.i18n import _
from nova import objects from nova import objects
@ -96,7 +96,7 @@ class InstanceMapping(base.NovaTimestampObject, base.NovaObject):
return instance_mapping return instance_mapping
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_by_instance_uuid_from_db(context, instance_uuid): def _get_by_instance_uuid_from_db(context, instance_uuid):
db_mapping = context.session.query(api_models.InstanceMapping)\ db_mapping = context.session.query(api_models.InstanceMapping)\
.options(orm.joinedload('cell_mapping'))\ .options(orm.joinedload('cell_mapping'))\
@ -113,7 +113,7 @@ class InstanceMapping(base.NovaTimestampObject, base.NovaObject):
return cls._from_db_object(context, cls(), db_mapping) return cls._from_db_object(context, cls(), db_mapping)
@staticmethod @staticmethod
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _create_in_db(context, updates): def _create_in_db(context, updates):
db_mapping = api_models.InstanceMapping() db_mapping = api_models.InstanceMapping()
db_mapping.update(updates) db_mapping.update(updates)
@ -138,7 +138,7 @@ class InstanceMapping(base.NovaTimestampObject, base.NovaObject):
self._from_db_object(self._context, self, db_mapping) self._from_db_object(self._context, self, db_mapping)
@staticmethod @staticmethod
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _save_in_db(context, instance_uuid, updates): def _save_in_db(context, instance_uuid, updates):
db_mapping = context.session.query( db_mapping = context.session.query(
api_models.InstanceMapping).filter_by( api_models.InstanceMapping).filter_by(
@ -173,7 +173,7 @@ class InstanceMapping(base.NovaTimestampObject, base.NovaObject):
self.obj_reset_changes() self.obj_reset_changes()
@staticmethod @staticmethod
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _destroy_in_db(context, instance_uuid): def _destroy_in_db(context, instance_uuid):
result = context.session.query(api_models.InstanceMapping).filter_by( result = context.session.query(api_models.InstanceMapping).filter_by(
instance_uuid=instance_uuid).delete() instance_uuid=instance_uuid).delete()
@ -185,7 +185,7 @@ class InstanceMapping(base.NovaTimestampObject, base.NovaObject):
self._destroy_in_db(self._context, self.instance_uuid) self._destroy_in_db(self._context, self.instance_uuid)
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def populate_queued_for_delete(context, max_count): def populate_queued_for_delete(context, max_count):
cells = objects.CellMappingList.get_all(context) cells = objects.CellMappingList.get_all(context)
processed = 0 processed = 0
@ -229,7 +229,7 @@ def populate_queued_for_delete(context, max_count):
return processed, processed return processed, processed
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def populate_user_id(context, max_count): def populate_user_id(context, max_count):
cells = objects.CellMappingList.get_all(context) cells = objects.CellMappingList.get_all(context)
cms_by_id = {cell.id: cell for cell in cells} cms_by_id = {cell.id: cell for cell in cells}
@ -309,7 +309,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject):
} }
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_by_project_id_from_db(context, project_id): def _get_by_project_id_from_db(context, project_id):
return context.session.query(api_models.InstanceMapping)\ return context.session.query(api_models.InstanceMapping)\
.options(orm.joinedload('cell_mapping'))\ .options(orm.joinedload('cell_mapping'))\
@ -323,7 +323,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject):
db_mappings) db_mappings)
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_by_cell_id_from_db(context, cell_id): def _get_by_cell_id_from_db(context, cell_id):
return context.session.query(api_models.InstanceMapping)\ return context.session.query(api_models.InstanceMapping)\
.options(orm.joinedload('cell_mapping'))\ .options(orm.joinedload('cell_mapping'))\
@ -336,7 +336,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject):
db_mappings) db_mappings)
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_by_instance_uuids_from_db(context, uuids): def _get_by_instance_uuids_from_db(context, uuids):
return context.session.query(api_models.InstanceMapping)\ return context.session.query(api_models.InstanceMapping)\
.options(orm.joinedload('cell_mapping'))\ .options(orm.joinedload('cell_mapping'))\
@ -350,7 +350,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject):
db_mappings) db_mappings)
@staticmethod @staticmethod
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _destroy_bulk_in_db(context, instance_uuids): def _destroy_bulk_in_db(context, instance_uuids):
return context.session.query(api_models.InstanceMapping).filter( return context.session.query(api_models.InstanceMapping).filter(
api_models.InstanceMapping.instance_uuid.in_(instance_uuids)).\ api_models.InstanceMapping.instance_uuid.in_(instance_uuids)).\
@ -361,7 +361,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject):
return cls._destroy_bulk_in_db(context, instance_uuids) return cls._destroy_bulk_in_db(context, instance_uuids)
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_not_deleted_by_cell_and_project_from_db(context, cell_uuid, def _get_not_deleted_by_cell_and_project_from_db(context, cell_uuid,
project_id, limit): project_id, limit):
query = context.session.query(api_models.InstanceMapping) query = context.session.query(api_models.InstanceMapping)
@ -400,7 +400,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject):
db_mappings) db_mappings)
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_counts_in_db(context, project_id, user_id=None): def _get_counts_in_db(context, project_id, user_id=None):
project_query = context.session.query( project_query = context.session.query(
func.count(api_models.InstanceMapping.id)).\ func.count(api_models.InstanceMapping.id)).\
@ -435,7 +435,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject):
return cls._get_counts_in_db(context, project_id, user_id=user_id) return cls._get_counts_in_db(context, project_id, user_id=user_id)
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _get_count_by_uuids_and_user_in_db(context, uuids, user_id): def _get_count_by_uuids_and_user_in_db(context, uuids, user_id):
query = (context.session.query( query = (context.session.query(
func.count(api_models.InstanceMapping.id)) func.count(api_models.InstanceMapping.id))

View File

@ -17,8 +17,9 @@ from oslo_db.sqlalchemy import utils as sqlalchemyutils
from oslo_log import log as logging from oslo_log import log as logging
from oslo_utils import versionutils from oslo_utils import versionutils
from nova.db.api import api as api_db_api
from nova.db.api import models as api_models from nova.db.api import models as api_models
from nova.db.main import api as db from nova.db.main import api as main_db_api
from nova import exception from nova import exception
from nova import objects from nova import objects
from nova.objects import base from nova.objects import base
@ -29,7 +30,7 @@ KEYPAIR_TYPE_X509 = 'x509'
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@db.api_context_manager.reader @api_db_api.context_manager.reader
def _get_from_db(context, user_id, name=None, limit=None, marker=None): def _get_from_db(context, user_id, name=None, limit=None, marker=None):
query = context.session.query(api_models.KeyPair).\ query = context.session.query(api_models.KeyPair).\
filter(api_models.KeyPair.user_id == user_id) filter(api_models.KeyPair.user_id == user_id)
@ -54,14 +55,14 @@ def _get_from_db(context, user_id, name=None, limit=None, marker=None):
return query.all() return query.all()
@db.api_context_manager.reader @api_db_api.context_manager.reader
def _get_count_from_db(context, user_id): def _get_count_from_db(context, user_id):
return context.session.query(api_models.KeyPair).\ return context.session.query(api_models.KeyPair).\
filter(api_models.KeyPair.user_id == user_id).\ filter(api_models.KeyPair.user_id == user_id).\
count() count()
@db.api_context_manager.writer @api_db_api.context_manager.writer
def _create_in_db(context, values): def _create_in_db(context, values):
kp = api_models.KeyPair() kp = api_models.KeyPair()
kp.update(values) kp.update(values)
@ -72,7 +73,7 @@ def _create_in_db(context, values):
return kp return kp
@db.api_context_manager.writer @api_db_api.context_manager.writer
def _destroy_in_db(context, user_id, name): def _destroy_in_db(context, user_id, name):
result = context.session.query(api_models.KeyPair).\ result = context.session.query(api_models.KeyPair).\
filter_by(user_id=user_id).\ filter_by(user_id=user_id).\
@ -143,7 +144,7 @@ class KeyPair(base.NovaPersistentObject, base.NovaObject,
except exception.KeypairNotFound: except exception.KeypairNotFound:
pass pass
if db_keypair is None: if db_keypair is None:
db_keypair = db.key_pair_get(context, user_id, name) db_keypair = main_db_api.key_pair_get(context, user_id, name)
return cls._from_db_object(context, cls(), db_keypair) return cls._from_db_object(context, cls(), db_keypair)
@base.remotable_classmethod @base.remotable_classmethod
@ -151,7 +152,7 @@ class KeyPair(base.NovaPersistentObject, base.NovaObject,
try: try:
cls._destroy_in_db(context, user_id, name) cls._destroy_in_db(context, user_id, name)
except exception.KeypairNotFound: except exception.KeypairNotFound:
db.key_pair_destroy(context, user_id, name) main_db_api.key_pair_destroy(context, user_id, name)
@base.remotable @base.remotable
def create(self): def create(self):
@ -163,7 +164,7 @@ class KeyPair(base.NovaPersistentObject, base.NovaObject,
# letting them create in the API DB, since we won't get protection # letting them create in the API DB, since we won't get protection
# from the UC. # from the UC.
try: try:
db.key_pair_get(self._context, self.user_id, self.name) main_db_api.key_pair_get(self._context, self.user_id, self.name)
raise exception.KeyPairExists(key_name=self.name) raise exception.KeyPairExists(key_name=self.name)
except exception.KeypairNotFound: except exception.KeypairNotFound:
pass pass
@ -180,7 +181,8 @@ class KeyPair(base.NovaPersistentObject, base.NovaObject,
try: try:
self._destroy_in_db(self._context, self.user_id, self.name) self._destroy_in_db(self._context, self.user_id, self.name)
except exception.KeypairNotFound: except exception.KeypairNotFound:
db.key_pair_destroy(self._context, self.user_id, self.name) main_db_api.key_pair_destroy(
self._context, self.user_id, self.name)
@base.NovaObjectRegistry.register @base.NovaObjectRegistry.register
@ -222,7 +224,7 @@ class KeyPairList(base.ObjectListBase, base.NovaObject):
limit_more = None limit_more = None
if limit_more is None or limit_more > 0: if limit_more is None or limit_more > 0:
main_db_keypairs = db.key_pair_get_all_by_user( main_db_keypairs = main_db_api.key_pair_get_all_by_user(
context, user_id, limit=limit_more, marker=marker) context, user_id, limit=limit_more, marker=marker)
else: else:
main_db_keypairs = [] main_db_keypairs = []
@ -233,4 +235,4 @@ class KeyPairList(base.ObjectListBase, base.NovaObject):
@base.remotable_classmethod @base.remotable_classmethod
def get_count_by_user(cls, context, user_id): def get_count_by_user(cls, context, user_id):
return (cls._get_count_from_db(context, user_id) + return (cls._get_count_from_db(context, user_id) +
db.key_pair_count_by_user(context, user_id)) main_db_api.key_pair_count_by_user(context, user_id))

View File

@ -16,9 +16,11 @@ import collections
from oslo_db import exception as db_exc from oslo_db import exception as db_exc
from nova.db.api import api as api_db_api
from nova.db.api import models as api_models from nova.db.api import models as api_models
from nova.db.main import api as db from nova.db.main import api as main_db_api
from nova.db.main import models as main_models from nova.db.main import models as main_models
from nova.db import utils as db_utils
from nova import exception from nova import exception
from nova.objects import base from nova.objects import base
from nova.objects import fields from nova.objects import fields
@ -82,7 +84,7 @@ class Quotas(base.NovaObject):
self.obj_reset_changes(fields=[attr]) self.obj_reset_changes(fields=[attr])
@staticmethod @staticmethod
@db.api_context_manager.reader @api_db_api.context_manager.reader
def _get_from_db(context, project_id, resource, user_id=None): def _get_from_db(context, project_id, resource, user_id=None):
model = api_models.ProjectUserQuota if user_id else api_models.Quota model = api_models.ProjectUserQuota if user_id else api_models.Quota
query = context.session.query(model).\ query = context.session.query(model).\
@ -100,14 +102,14 @@ class Quotas(base.NovaObject):
return result return result
@staticmethod @staticmethod
@db.api_context_manager.reader @api_db_api.context_manager.reader
def _get_all_from_db(context, project_id): def _get_all_from_db(context, project_id):
return context.session.query(api_models.ProjectUserQuota).\ return context.session.query(api_models.ProjectUserQuota).\
filter_by(project_id=project_id).\ filter_by(project_id=project_id).\
all() all()
@staticmethod @staticmethod
@db.api_context_manager.reader @api_db_api.context_manager.reader
def _get_all_from_db_by_project(context, project_id): def _get_all_from_db_by_project(context, project_id):
# by_project refers to the returned dict that has a 'project_id' key # by_project refers to the returned dict that has a 'project_id' key
rows = context.session.query(api_models.Quota).\ rows = context.session.query(api_models.Quota).\
@ -119,7 +121,7 @@ class Quotas(base.NovaObject):
return result return result
@staticmethod @staticmethod
@db.api_context_manager.reader @api_db_api.context_manager.reader
def _get_all_from_db_by_project_and_user(context, project_id, user_id): def _get_all_from_db_by_project_and_user(context, project_id, user_id):
# by_project_and_user refers to the returned dict that has # by_project_and_user refers to the returned dict that has
# 'project_id' and 'user_id' keys # 'project_id' and 'user_id' keys
@ -135,7 +137,7 @@ class Quotas(base.NovaObject):
return result return result
@staticmethod @staticmethod
@db.api_context_manager.writer @api_db_api.context_manager.writer
def _destroy_all_in_db_by_project(context, project_id): def _destroy_all_in_db_by_project(context, project_id):
per_project = context.session.query(api_models.Quota).\ per_project = context.session.query(api_models.Quota).\
filter_by(project_id=project_id).\ filter_by(project_id=project_id).\
@ -147,7 +149,7 @@ class Quotas(base.NovaObject):
raise exception.ProjectQuotaNotFound(project_id=project_id) raise exception.ProjectQuotaNotFound(project_id=project_id)
@staticmethod @staticmethod
@db.api_context_manager.writer @api_db_api.context_manager.writer
def _destroy_all_in_db_by_project_and_user(context, project_id, user_id): def _destroy_all_in_db_by_project_and_user(context, project_id, user_id):
result = context.session.query(api_models.ProjectUserQuota).\ result = context.session.query(api_models.ProjectUserQuota).\
filter_by(project_id=project_id).\ filter_by(project_id=project_id).\
@ -158,7 +160,7 @@ class Quotas(base.NovaObject):
user_id=user_id) user_id=user_id)
@staticmethod @staticmethod
@db.api_context_manager.reader @api_db_api.context_manager.reader
def _get_class_from_db(context, class_name, resource): def _get_class_from_db(context, class_name, resource):
result = context.session.query(api_models.QuotaClass).\ result = context.session.query(api_models.QuotaClass).\
filter_by(class_name=class_name).\ filter_by(class_name=class_name).\
@ -169,7 +171,7 @@ class Quotas(base.NovaObject):
return result return result
@staticmethod @staticmethod
@db.api_context_manager.reader @api_db_api.context_manager.reader
def _get_all_class_from_db_by_name(context, class_name): def _get_all_class_from_db_by_name(context, class_name):
# by_name refers to the returned dict that has a 'class_name' key # by_name refers to the returned dict that has a 'class_name' key
rows = context.session.query(api_models.QuotaClass).\ rows = context.session.query(api_models.QuotaClass).\
@ -181,14 +183,16 @@ class Quotas(base.NovaObject):
return result return result
@staticmethod @staticmethod
@db.api_context_manager.writer @api_db_api.context_manager.writer
def _create_limit_in_db(context, project_id, resource, limit, def _create_limit_in_db(context, project_id, resource, limit,
user_id=None): user_id=None):
# TODO(melwitt): We won't have per project resources after nova-network # TODO(melwitt): We won't have per project resources after nova-network
# is removed. # is removed.
# TODO(stephenfin): We need to do something here now...but what? # TODO(stephenfin): We need to do something here now...but what?
per_user = (user_id and per_user = (
resource not in db.quota_get_per_project_resources()) user_id and
resource not in main_db_api.quota_get_per_project_resources()
)
quota_ref = (api_models.ProjectUserQuota() if per_user quota_ref = (api_models.ProjectUserQuota() if per_user
else api_models.Quota()) else api_models.Quota())
if per_user: if per_user:
@ -204,14 +208,16 @@ class Quotas(base.NovaObject):
return quota_ref return quota_ref
@staticmethod @staticmethod
@db.api_context_manager.writer @api_db_api.context_manager.writer
def _update_limit_in_db(context, project_id, resource, limit, def _update_limit_in_db(context, project_id, resource, limit,
user_id=None): user_id=None):
# TODO(melwitt): We won't have per project resources after nova-network # TODO(melwitt): We won't have per project resources after nova-network
# is removed. # is removed.
# TODO(stephenfin): We need to do something here now...but what? # TODO(stephenfin): We need to do something here now...but what?
per_user = (user_id and per_user = (
resource not in db.quota_get_per_project_resources()) user_id and
resource not in main_db_api.quota_get_per_project_resources()
)
model = api_models.ProjectUserQuota if per_user else api_models.Quota model = api_models.ProjectUserQuota if per_user else api_models.Quota
query = context.session.query(model).\ query = context.session.query(model).\
filter_by(project_id=project_id).\ filter_by(project_id=project_id).\
@ -228,7 +234,7 @@ class Quotas(base.NovaObject):
raise exception.ProjectQuotaNotFound(project_id=project_id) raise exception.ProjectQuotaNotFound(project_id=project_id)
@staticmethod @staticmethod
@db.api_context_manager.writer @api_db_api.context_manager.writer
def _create_class_in_db(context, class_name, resource, limit): def _create_class_in_db(context, class_name, resource, limit):
# NOTE(melwitt): There's no unique constraint on the QuotaClass model, # NOTE(melwitt): There's no unique constraint on the QuotaClass model,
# so check for duplicate manually. # so check for duplicate manually.
@ -247,7 +253,7 @@ class Quotas(base.NovaObject):
return quota_class_ref return quota_class_ref
@staticmethod @staticmethod
@db.api_context_manager.writer @api_db_api.context_manager.writer
def _update_class_in_db(context, class_name, resource, limit): def _update_class_in_db(context, class_name, resource, limit):
result = context.session.query(api_models.QuotaClass).\ result = context.session.query(api_models.QuotaClass).\
filter_by(class_name=class_name).\ filter_by(class_name=class_name).\
@ -366,7 +372,8 @@ class Quotas(base.NovaObject):
@base.remotable_classmethod @base.remotable_classmethod
def create_limit(cls, context, project_id, resource, limit, user_id=None): def create_limit(cls, context, project_id, resource, limit, user_id=None):
try: try:
db.quota_get(context, project_id, resource, user_id=user_id) main_db_api.quota_get(
context, project_id, resource, user_id=user_id)
except exception.QuotaNotFound: except exception.QuotaNotFound:
cls._create_limit_in_db(context, project_id, resource, limit, cls._create_limit_in_db(context, project_id, resource, limit,
user_id=user_id) user_id=user_id)
@ -380,13 +387,13 @@ class Quotas(base.NovaObject):
cls._update_limit_in_db(context, project_id, resource, limit, cls._update_limit_in_db(context, project_id, resource, limit,
user_id=user_id) user_id=user_id)
except exception.QuotaNotFound: except exception.QuotaNotFound:
db.quota_update(context, project_id, resource, limit, main_db_api.quota_update(context, project_id, resource, limit,
user_id=user_id) user_id=user_id)
@classmethod @classmethod
def create_class(cls, context, class_name, resource, limit): def create_class(cls, context, class_name, resource, limit):
try: try:
db.quota_class_get(context, class_name, resource) main_db_api.quota_class_get(context, class_name, resource)
except exception.QuotaClassNotFound: except exception.QuotaClassNotFound:
cls._create_class_in_db(context, class_name, resource, limit) cls._create_class_in_db(context, class_name, resource, limit)
else: else:
@ -398,7 +405,8 @@ class Quotas(base.NovaObject):
try: try:
cls._update_class_in_db(context, class_name, resource, limit) cls._update_class_in_db(context, class_name, resource, limit)
except exception.QuotaClassNotFound: except exception.QuotaClassNotFound:
db.quota_class_update(context, class_name, resource, limit) main_db_api.quota_class_update(
context, class_name, resource, limit)
# NOTE(melwitt): The following methods are not remotable and return # NOTE(melwitt): The following methods are not remotable and return
# dict-like database model objects. We are using classmethods to provide # dict-like database model objects. We are using classmethods to provide
@ -409,21 +417,22 @@ class Quotas(base.NovaObject):
quota = cls._get_from_db(context, project_id, resource, quota = cls._get_from_db(context, project_id, resource,
user_id=user_id) user_id=user_id)
except exception.QuotaNotFound: except exception.QuotaNotFound:
quota = db.quota_get(context, project_id, resource, quota = main_db_api.quota_get(context, project_id, resource,
user_id=user_id) user_id=user_id)
return quota return quota
@classmethod @classmethod
def get_all(cls, context, project_id): def get_all(cls, context, project_id):
api_db_quotas = cls._get_all_from_db(context, project_id) api_db_quotas = cls._get_all_from_db(context, project_id)
main_db_quotas = db.quota_get_all(context, project_id) main_db_quotas = main_db_api.quota_get_all(context, project_id)
return api_db_quotas + main_db_quotas return api_db_quotas + main_db_quotas
@classmethod @classmethod
def get_all_by_project(cls, context, project_id): def get_all_by_project(cls, context, project_id):
api_db_quotas_dict = cls._get_all_from_db_by_project(context, api_db_quotas_dict = cls._get_all_from_db_by_project(context,
project_id) project_id)
main_db_quotas_dict = db.quota_get_all_by_project(context, project_id) main_db_quotas_dict = main_db_api.quota_get_all_by_project(
context, project_id)
for k, v in api_db_quotas_dict.items(): for k, v in api_db_quotas_dict.items():
main_db_quotas_dict[k] = v main_db_quotas_dict[k] = v
return main_db_quotas_dict return main_db_quotas_dict
@ -432,7 +441,7 @@ class Quotas(base.NovaObject):
def get_all_by_project_and_user(cls, context, project_id, user_id): def get_all_by_project_and_user(cls, context, project_id, user_id):
api_db_quotas_dict = cls._get_all_from_db_by_project_and_user( api_db_quotas_dict = cls._get_all_from_db_by_project_and_user(
context, project_id, user_id) context, project_id, user_id)
main_db_quotas_dict = db.quota_get_all_by_project_and_user( main_db_quotas_dict = main_db_api.quota_get_all_by_project_and_user(
context, project_id, user_id) context, project_id, user_id)
for k, v in api_db_quotas_dict.items(): for k, v in api_db_quotas_dict.items():
main_db_quotas_dict[k] = v main_db_quotas_dict[k] = v
@ -443,7 +452,7 @@ class Quotas(base.NovaObject):
try: try:
cls._destroy_all_in_db_by_project(context, project_id) cls._destroy_all_in_db_by_project(context, project_id)
except exception.ProjectQuotaNotFound: except exception.ProjectQuotaNotFound:
db.quota_destroy_all_by_project(context, project_id) main_db_api.quota_destroy_all_by_project(context, project_id)
@classmethod @classmethod
def destroy_all_by_project_and_user(cls, context, project_id, user_id): def destroy_all_by_project_and_user(cls, context, project_id, user_id):
@ -451,31 +460,31 @@ class Quotas(base.NovaObject):
cls._destroy_all_in_db_by_project_and_user(context, project_id, cls._destroy_all_in_db_by_project_and_user(context, project_id,
user_id) user_id)
except exception.ProjectUserQuotaNotFound: except exception.ProjectUserQuotaNotFound:
db.quota_destroy_all_by_project_and_user(context, project_id, main_db_api.quota_destroy_all_by_project_and_user(
user_id) context, project_id, user_id)
@classmethod @classmethod
def get_class(cls, context, class_name, resource): def get_class(cls, context, class_name, resource):
try: try:
qclass = cls._get_class_from_db(context, class_name, resource) qclass = cls._get_class_from_db(context, class_name, resource)
except exception.QuotaClassNotFound: except exception.QuotaClassNotFound:
qclass = db.quota_class_get(context, class_name, resource) qclass = main_db_api.quota_class_get(context, class_name, resource)
return qclass return qclass
@classmethod @classmethod
def get_default_class(cls, context): def get_default_class(cls, context):
try: try:
qclass = cls._get_all_class_from_db_by_name( qclass = cls._get_all_class_from_db_by_name(
context, db._DEFAULT_QUOTA_NAME) context, main_db_api._DEFAULT_QUOTA_NAME)
except exception.QuotaClassNotFound: except exception.QuotaClassNotFound:
qclass = db.quota_class_get_default(context) qclass = main_db_api.quota_class_get_default(context)
return qclass return qclass
@classmethod @classmethod
def get_all_class_by_name(cls, context, class_name): def get_all_class_by_name(cls, context, class_name):
api_db_quotas_dict = cls._get_all_class_from_db_by_name(context, api_db_quotas_dict = cls._get_all_class_from_db_by_name(context,
class_name) class_name)
main_db_quotas_dict = db.quota_class_get_all_by_name(context, main_db_quotas_dict = main_db_api.quota_class_get_all_by_name(context,
class_name) class_name)
for k, v in api_db_quotas_dict.items(): for k, v in api_db_quotas_dict.items():
main_db_quotas_dict[k] = v main_db_quotas_dict[k] = v
@ -501,8 +510,8 @@ class QuotasNoOp(Quotas):
pass pass
@db.require_context @db_utils.require_context
@db.pick_context_manager_reader @main_db_api.pick_context_manager_reader
def _get_main_per_project_limits(context, limit): def _get_main_per_project_limits(context, limit):
return context.session.query(main_models.Quota).\ return context.session.query(main_models.Quota).\
filter_by(deleted=0).\ filter_by(deleted=0).\
@ -510,8 +519,8 @@ def _get_main_per_project_limits(context, limit):
all() all()
@db.require_context @db_utils.require_context
@db.pick_context_manager_reader @main_db_api.pick_context_manager_reader
def _get_main_per_user_limits(context, limit): def _get_main_per_user_limits(context, limit):
return context.session.query(main_models.ProjectUserQuota).\ return context.session.query(main_models.ProjectUserQuota).\
filter_by(deleted=0).\ filter_by(deleted=0).\
@ -519,8 +528,8 @@ def _get_main_per_user_limits(context, limit):
all() all()
@db.require_context @db_utils.require_context
@db.pick_context_manager_writer @main_db_api.pick_context_manager_writer
def _destroy_main_per_project_limits(context, project_id, resource): def _destroy_main_per_project_limits(context, project_id, resource):
context.session.query(main_models.Quota).\ context.session.query(main_models.Quota).\
filter_by(deleted=0).\ filter_by(deleted=0).\
@ -529,8 +538,8 @@ def _destroy_main_per_project_limits(context, project_id, resource):
soft_delete(synchronize_session=False) soft_delete(synchronize_session=False)
@db.require_context @db_utils.require_context
@db.pick_context_manager_writer @main_db_api.pick_context_manager_writer
def _destroy_main_per_user_limits(context, project_id, resource, user_id): def _destroy_main_per_user_limits(context, project_id, resource, user_id):
context.session.query(main_models.ProjectUserQuota).\ context.session.query(main_models.ProjectUserQuota).\
filter_by(deleted=0).\ filter_by(deleted=0).\
@ -540,7 +549,7 @@ def _destroy_main_per_user_limits(context, project_id, resource, user_id):
soft_delete(synchronize_session=False) soft_delete(synchronize_session=False)
@db.api_context_manager.writer @api_db_api.context_manager.writer
def _create_limits_in_api_db(context, db_limits, per_user=False): def _create_limits_in_api_db(context, db_limits, per_user=False):
for db_limit in db_limits: for db_limit in db_limits:
user_id = db_limit.user_id if per_user else None user_id = db_limit.user_id if per_user else None
@ -587,8 +596,8 @@ def migrate_quota_limits_to_api_db(context, count):
return len(main_per_project_limits) + len(main_per_user_limits), done return len(main_per_project_limits) + len(main_per_user_limits), done
@db.require_context @db_utils.require_context
@db.pick_context_manager_reader @main_db_api.pick_context_manager_reader
def _get_main_quota_classes(context, limit): def _get_main_quota_classes(context, limit):
return context.session.query(main_models.QuotaClass).\ return context.session.query(main_models.QuotaClass).\
filter_by(deleted=0).\ filter_by(deleted=0).\
@ -596,7 +605,7 @@ def _get_main_quota_classes(context, limit):
all() all()
@db.pick_context_manager_writer @main_db_api.pick_context_manager_writer
def _destroy_main_quota_classes(context, db_classes): def _destroy_main_quota_classes(context, db_classes):
for db_class in db_classes: for db_class in db_classes:
context.session.query(main_models.QuotaClass).\ context.session.query(main_models.QuotaClass).\
@ -605,7 +614,7 @@ def _destroy_main_quota_classes(context, db_classes):
soft_delete(synchronize_session=False) soft_delete(synchronize_session=False)
@db.api_context_manager.writer @api_db_api.context_manager.writer
def _create_classes_in_api_db(context, db_classes): def _create_classes_in_api_db(context, db_classes):
for db_class in db_classes: for db_class in db_classes:
Quotas._create_class_in_db(context, db_class.class_name, Quotas._create_class_in_db(context, db_class.class_name,

View File

@ -20,8 +20,8 @@ from oslo_log import log as logging
from oslo_serialization import jsonutils from oslo_serialization import jsonutils
from oslo_utils import versionutils from oslo_utils import versionutils
from nova.db.api import api as api_db_api
from nova.db.api import models as api_models from nova.db.api import models as api_models
from nova.db.main import api as db
from nova import exception from nova import exception
from nova import objects from nova import objects
from nova.objects import base from nova.objects import base
@ -630,7 +630,7 @@ class RequestSpec(base.NovaObject):
return spec return spec
@staticmethod @staticmethod
@db.api_context_manager.reader @api_db_api.context_manager.reader
def _get_by_instance_uuid_from_db(context, instance_uuid): def _get_by_instance_uuid_from_db(context, instance_uuid):
db_spec = context.session.query(api_models.RequestSpec).filter_by( db_spec = context.session.query(api_models.RequestSpec).filter_by(
instance_uuid=instance_uuid).first() instance_uuid=instance_uuid).first()
@ -645,7 +645,7 @@ class RequestSpec(base.NovaObject):
return cls._from_db_object(context, cls(), db_spec) return cls._from_db_object(context, cls(), db_spec)
@staticmethod @staticmethod
@db.api_context_manager.writer @api_db_api.context_manager.writer
def _create_in_db(context, updates): def _create_in_db(context, updates):
db_spec = api_models.RequestSpec() db_spec = api_models.RequestSpec()
db_spec.update(updates) db_spec.update(updates)
@ -709,7 +709,7 @@ class RequestSpec(base.NovaObject):
self._from_db_object(self._context, self, db_spec) self._from_db_object(self._context, self, db_spec)
@staticmethod @staticmethod
@db.api_context_manager.writer @api_db_api.context_manager.writer
def _save_in_db(context, instance_uuid, updates): def _save_in_db(context, instance_uuid, updates):
# FIXME(sbauza): Provide a classmethod when oslo.db bug #1520195 is # FIXME(sbauza): Provide a classmethod when oslo.db bug #1520195 is
# fixed and released # fixed and released
@ -729,7 +729,7 @@ class RequestSpec(base.NovaObject):
self.obj_reset_changes() self.obj_reset_changes()
@staticmethod @staticmethod
@db.api_context_manager.writer @api_db_api.context_manager.writer
def _destroy_in_db(context, instance_uuid): def _destroy_in_db(context, instance_uuid):
result = context.session.query(api_models.RequestSpec).filter_by( result = context.session.query(api_models.RequestSpec).filter_by(
instance_uuid=instance_uuid).delete() instance_uuid=instance_uuid).delete()
@ -741,7 +741,7 @@ class RequestSpec(base.NovaObject):
self._destroy_in_db(self._context, self.instance_uuid) self._destroy_in_db(self._context, self.instance_uuid)
@staticmethod @staticmethod
@db.api_context_manager.writer @api_db_api.context_manager.writer
def _destroy_bulk_in_db(context, instance_uuids): def _destroy_bulk_in_db(context, instance_uuids):
return context.session.query(api_models.RequestSpec).filter( return context.session.query(api_models.RequestSpec).filter(
api_models.RequestSpec.instance_uuid.in_(instance_uuids)).\ api_models.RequestSpec.instance_uuid.in_(instance_uuids)).\

View File

@ -16,8 +16,9 @@ from oslo_log import log as logging
from oslo_utils import versionutils from oslo_utils import versionutils
from nova import context as nova_context from nova import context as nova_context
from nova.db.main import api as db from nova.db.api import api as api_db_api
from nova.db.main import models from nova.db.main import api as main_db_api
from nova.db.main import models as main_db_models
from nova import exception from nova import exception
from nova import objects from nova import objects
from nova.objects import base from nova.objects import base
@ -70,26 +71,26 @@ class VirtualInterface(base.NovaPersistentObject, base.NovaObject):
@base.remotable_classmethod @base.remotable_classmethod
def get_by_id(cls, context, vif_id): def get_by_id(cls, context, vif_id):
db_vif = db.virtual_interface_get(context, vif_id) db_vif = main_db_api.virtual_interface_get(context, vif_id)
if db_vif: if db_vif:
return cls._from_db_object(context, cls(), db_vif) return cls._from_db_object(context, cls(), db_vif)
@base.remotable_classmethod @base.remotable_classmethod
def get_by_uuid(cls, context, vif_uuid): def get_by_uuid(cls, context, vif_uuid):
db_vif = db.virtual_interface_get_by_uuid(context, vif_uuid) db_vif = main_db_api.virtual_interface_get_by_uuid(context, vif_uuid)
if db_vif: if db_vif:
return cls._from_db_object(context, cls(), db_vif) return cls._from_db_object(context, cls(), db_vif)
@base.remotable_classmethod @base.remotable_classmethod
def get_by_address(cls, context, address): def get_by_address(cls, context, address):
db_vif = db.virtual_interface_get_by_address(context, address) db_vif = main_db_api.virtual_interface_get_by_address(context, address)
if db_vif: if db_vif:
return cls._from_db_object(context, cls(), db_vif) return cls._from_db_object(context, cls(), db_vif)
@base.remotable_classmethod @base.remotable_classmethod
def get_by_instance_and_network(cls, context, instance_uuid, network_id): def get_by_instance_and_network(cls, context, instance_uuid, network_id):
db_vif = db.virtual_interface_get_by_instance_and_network(context, db_vif = main_db_api.virtual_interface_get_by_instance_and_network(
instance_uuid, network_id) context, instance_uuid, network_id)
if db_vif: if db_vif:
return cls._from_db_object(context, cls(), db_vif) return cls._from_db_object(context, cls(), db_vif)
@ -99,7 +100,7 @@ class VirtualInterface(base.NovaPersistentObject, base.NovaObject):
raise exception.ObjectActionError(action='create', raise exception.ObjectActionError(action='create',
reason='already created') reason='already created')
updates = self.obj_get_changes() updates = self.obj_get_changes()
db_vif = db.virtual_interface_create(self._context, updates) db_vif = main_db_api.virtual_interface_create(self._context, updates)
self._from_db_object(self._context, self, db_vif) self._from_db_object(self._context, self, db_vif)
@base.remotable @base.remotable
@ -108,17 +109,18 @@ class VirtualInterface(base.NovaPersistentObject, base.NovaObject):
if 'address' in updates: if 'address' in updates:
raise exception.ObjectActionError(action='save', raise exception.ObjectActionError(action='save',
reason='address is not mutable') reason='address is not mutable')
db_vif = db.virtual_interface_update(self._context, self.address, db_vif = main_db_api.virtual_interface_update(
updates) self._context, self.address, updates)
return self._from_db_object(self._context, self, db_vif) return self._from_db_object(self._context, self, db_vif)
@base.remotable_classmethod @base.remotable_classmethod
def delete_by_instance_uuid(cls, context, instance_uuid): def delete_by_instance_uuid(cls, context, instance_uuid):
db.virtual_interface_delete_by_instance(context, instance_uuid) main_db_api.virtual_interface_delete_by_instance(
context, instance_uuid)
@base.remotable @base.remotable
def destroy(self): def destroy(self):
db.virtual_interface_delete(self._context, self.id) main_db_api.virtual_interface_delete(self._context, self.id)
@base.NovaObjectRegistry.register @base.NovaObjectRegistry.register
@ -131,15 +133,16 @@ class VirtualInterfaceList(base.ObjectListBase, base.NovaObject):
@base.remotable_classmethod @base.remotable_classmethod
def get_all(cls, context): def get_all(cls, context):
db_vifs = db.virtual_interface_get_all(context) db_vifs = main_db_api.virtual_interface_get_all(context)
return base.obj_make_list(context, cls(context), return base.obj_make_list(context, cls(context),
objects.VirtualInterface, db_vifs) objects.VirtualInterface, db_vifs)
@staticmethod @staticmethod
@db.select_db_reader_mode @main_db_api.select_db_reader_mode
def _db_virtual_interface_get_by_instance(context, instance_uuid, def _db_virtual_interface_get_by_instance(context, instance_uuid,
use_slave=False): use_slave=False):
return db.virtual_interface_get_by_instance(context, instance_uuid) return main_db_api.virtual_interface_get_by_instance(
context, instance_uuid)
@base.remotable_classmethod @base.remotable_classmethod
def get_by_instance_uuid(cls, context, instance_uuid, use_slave=False): def get_by_instance_uuid(cls, context, instance_uuid, use_slave=False):
@ -149,7 +152,7 @@ class VirtualInterfaceList(base.ObjectListBase, base.NovaObject):
objects.VirtualInterface, db_vifs) objects.VirtualInterface, db_vifs)
@db.api_context_manager.writer @api_db_api.context_manager.writer
def fill_virtual_interface_list(context, max_count): def fill_virtual_interface_list(context, max_count):
"""This fills missing VirtualInterface Objects in Nova DB""" """This fills missing VirtualInterface Objects in Nova DB"""
count_hit = 0 count_hit = 0
@ -287,14 +290,14 @@ def fill_virtual_interface_list(context, max_count):
# we checked. # we checked.
# Please notice that because of virtual_interfaces_instance_uuid_fkey # Please notice that because of virtual_interfaces_instance_uuid_fkey
# we need to have FAKE_UUID instance object, even deleted one. # we need to have FAKE_UUID instance object, even deleted one.
@db.pick_context_manager_writer @main_db_api.pick_context_manager_writer
def _set_or_delete_marker_for_migrate_instances(context, marker=None): def _set_or_delete_marker_for_migrate_instances(context, marker=None):
context.session.query(models.VirtualInterface).filter_by( context.session.query(main_db_models.VirtualInterface).filter_by(
instance_uuid=FAKE_UUID).delete() instance_uuid=FAKE_UUID).delete()
# Create FAKE_UUID instance objects, only for marker, if doesn't exist. # Create FAKE_UUID instance objects, only for marker, if doesn't exist.
# It is needed due constraint: virtual_interfaces_instance_uuid_fkey # It is needed due constraint: virtual_interfaces_instance_uuid_fkey
instance = context.session.query(models.Instance).filter_by( instance = context.session.query(main_db_models.Instance).filter_by(
uuid=FAKE_UUID).first() uuid=FAKE_UUID).first()
if not instance: if not instance:
instance = objects.Instance(context) instance = objects.Instance(context)
@ -316,9 +319,9 @@ def _set_or_delete_marker_for_migrate_instances(context, marker=None):
db_mapping.create() db_mapping.create()
@db.pick_context_manager_reader @main_db_api.pick_context_manager_reader
def _get_marker_for_migrate_instances(context): def _get_marker_for_migrate_instances(context):
vif = (context.session.query(models.VirtualInterface).filter_by( vif = (context.session.query(main_db_models.VirtualInterface).filter_by(
instance_uuid=FAKE_UUID)).first() instance_uuid=FAKE_UUID)).first()
marker = vif['tag'] if vif else None marker = vif['tag'] if vif else None
return marker return marker

View File

@ -24,8 +24,9 @@ from sqlalchemy import sql
import nova.conf import nova.conf
from nova import context as nova_context from nova import context as nova_context
from nova.db.api import api as api_db_api
from nova.db.api import models as api_models from nova.db.api import models as api_models
from nova.db.main import api as db from nova.db.main import api as main_db_api
from nova import exception from nova import exception
from nova import objects from nova import objects
from nova.scheduler.client import report from nova.scheduler.client import report
@ -177,7 +178,10 @@ class DbQuotaDriver(object):
# displaying used limits. They are always zero. # displaying used limits. They are always zero.
usages[resource.name] = {'in_use': 0} usages[resource.name] = {'in_use': 0}
else: else:
if resource.name in db.quota_get_per_project_resources(): if (
resource.name in
main_db_api.quota_get_per_project_resources()
):
count = resource.count_as_dict(context, project_id) count = resource.count_as_dict(context, project_id)
key = 'project' key = 'project'
else: else:
@ -1045,7 +1049,7 @@ class QuotaEngine(object):
return 0 return 0
@db.api_context_manager.reader @api_db_api.context_manager.reader
def _user_id_queued_for_delete_populated(context, project_id=None): def _user_id_queued_for_delete_populated(context, project_id=None):
"""Determine whether user_id and queued_for_delete are set. """Determine whether user_id and queued_for_delete are set.

View File

@ -43,7 +43,8 @@ from nova.api import wsgi
from nova.compute import multi_cell_list from nova.compute import multi_cell_list
from nova.compute import rpcapi as compute_rpcapi from nova.compute import rpcapi as compute_rpcapi
from nova import context from nova import context
from nova.db.main import api as session from nova.db.api import api as api_db_api
from nova.db.main import api as main_db_api
from nova.db import migration from nova.db import migration
from nova import exception from nova import exception
from nova import objects from nova import objects
@ -61,7 +62,6 @@ LOG = logging.getLogger(__name__)
DB_SCHEMA = collections.defaultdict(str) DB_SCHEMA = collections.defaultdict(str)
SESSION_CONFIGURED = False SESSION_CONFIGURED = False
PROJECT_ID = '6f70656e737461636b20342065766572' PROJECT_ID = '6f70656e737461636b20342065766572'
@ -532,7 +532,7 @@ class CellDatabases(fixtures.Fixture):
# will house the sqlite:// connection for this cell's in-memory # will house the sqlite:// connection for this cell's in-memory
# database. Store/index it by the connection string, which is # database. Store/index it by the connection string, which is
# how we identify cells in CellMapping. # how we identify cells in CellMapping.
ctxt_mgr = session.create_context_manager() ctxt_mgr = main_db_api.create_context_manager()
self._ctxt_mgrs[connection_str] = ctxt_mgr self._ctxt_mgrs[connection_str] = ctxt_mgr
# NOTE(melwitt): The first DB access through service start is # NOTE(melwitt): The first DB access through service start is
@ -595,30 +595,45 @@ class CellDatabases(fixtures.Fixture):
class Database(fixtures.Fixture): class Database(fixtures.Fixture):
# TODO(stephenfin): The 'version' argument is unused and can be removed
def __init__(self, database='main', version=None, connection=None): def __init__(self, database='main', version=None, connection=None):
"""Create a database fixture. """Create a database fixture.
:param database: The type of database, 'main', or 'api' :param database: The type of database, 'main', or 'api'
:param connection: The connection string to use :param connection: The connection string to use
""" """
super(Database, self).__init__() super().__init__()
# NOTE(pkholkin): oslo_db.enginefacade is configured in tests the same
# way as it is done for any other service that uses db # NOTE(pkholkin): oslo_db.enginefacade is configured in tests the
# same way as it is done for any other service that uses DB
global SESSION_CONFIGURED global SESSION_CONFIGURED
if not SESSION_CONFIGURED: if not SESSION_CONFIGURED:
session.configure(CONF) main_db_api.configure(CONF)
api_db_api.configure(CONF)
SESSION_CONFIGURED = True SESSION_CONFIGURED = True
assert database in {'main', 'api'}, f'Unrecognized database {database}'
self.database = database self.database = database
self.version = version self.version = version
if database == 'main': if database == 'main':
if connection is not None: if connection is not None:
ctxt_mgr = session.create_context_manager( ctxt_mgr = main_db_api.create_context_manager(
connection=connection) connection=connection)
self.get_engine = ctxt_mgr.writer.get_engine self.get_engine = ctxt_mgr.writer.get_engine
else: else:
self.get_engine = session.get_engine self.get_engine = main_db_api.get_engine
elif database == 'api': elif database == 'api':
self.get_engine = session.get_api_engine assert connection is None, 'Not supported for the API database'
self.get_engine = api_db_api.get_engine
def setUp(self):
super(Database, self).setUp()
self.reset()
self.addCleanup(self.cleanup)
def _cache_schema(self): def _cache_schema(self):
global DB_SCHEMA global DB_SCHEMA
@ -642,11 +657,6 @@ class Database(fixtures.Fixture):
conn.connection.executescript( conn.connection.executescript(
DB_SCHEMA[(self.database, self.version)]) DB_SCHEMA[(self.database, self.version)])
def setUp(self):
super(Database, self).setUp()
self.reset()
self.addCleanup(self.cleanup)
class DefaultFlavorsFixture(fixtures.Fixture): class DefaultFlavorsFixture(fixtures.Fixture):
def setUp(self): def setUp(self):

View File

@ -18,8 +18,8 @@ from oslo_utils.fixture import uuidsentinel
from oslo_utils import timeutils from oslo_utils import timeutils
from nova import context from nova import context
from nova.db.api import api as api_db_api
from nova.db.api import models as api_models from nova.db.api import models as api_models
from nova.db.main import api as db_api
from nova import exception from nova import exception
import nova.objects.aggregate as aggregate_obj import nova.objects.aggregate as aggregate_obj
from nova import test from nova import test
@ -59,7 +59,7 @@ def _get_fake_metadata(db_id):
'unique_key': 'unique_value_' + str(db_id)} 'unique_key': 'unique_value_' + str(db_id)}
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _create_aggregate(context, values=_get_fake_aggregate(1, result=False), def _create_aggregate(context, values=_get_fake_aggregate(1, result=False),
metadata=_get_fake_metadata(1)): metadata=_get_fake_metadata(1)):
aggregate = api_models.Aggregate() aggregate = api_models.Aggregate()
@ -77,7 +77,7 @@ def _create_aggregate(context, values=_get_fake_aggregate(1, result=False),
return aggregate return aggregate
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _create_aggregate_with_hosts(context, def _create_aggregate_with_hosts(context,
values=_get_fake_aggregate(1, result=False), values=_get_fake_aggregate(1, result=False),
metadata=_get_fake_metadata(1), metadata=_get_fake_metadata(1),
@ -92,13 +92,13 @@ def _create_aggregate_with_hosts(context,
return aggregate return aggregate
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _aggregate_host_get_all(context, aggregate_id): def _aggregate_host_get_all(context, aggregate_id):
return context.session.query(api_models.AggregateHost).\ return context.session.query(api_models.AggregateHost).\
filter_by(aggregate_id=aggregate_id).all() filter_by(aggregate_id=aggregate_id).all()
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _aggregate_metadata_get_all(context, aggregate_id): def _aggregate_metadata_get_all(context, aggregate_id):
results = context.session.query(api_models.AggregateMetadata).\ results = context.session.query(api_models.AggregateMetadata).\
filter_by(aggregate_id=aggregate_id).all() filter_by(aggregate_id=aggregate_id).all()

View File

@ -11,8 +11,8 @@
# under the License. # under the License.
from nova import context from nova import context
from nova.db.api import api as api_db_api
from nova.db.api import models as api_models from nova.db.api import models as api_models
from nova.db.main import api as db_api
from nova import exception from nova import exception
from nova import objects from nova import objects
from nova import test from nova import test
@ -111,7 +111,7 @@ class FlavorObjectTestCase(test.NoDBTestCase):
self.assertEqual(set(projects), set(flavor2.projects)) self.assertEqual(set(projects), set(flavor2.projects))
@staticmethod @staticmethod
@db_api.api_context_manager.reader @api_db_api.context_manager.reader
def _collect_flavor_residue_api(context, flavor): def _collect_flavor_residue_api(context, flavor):
flavors = context.session.query(api_models.Flavors).\ flavors = context.session.query(api_models.Flavors).\
filter_by(id=flavor.id).all() filter_by(id=flavor.id).all()

View File

@ -46,11 +46,14 @@ document_root = /tmp
self.useFixture(config_fixture.Config()) self.useFixture(config_fixture.Config())
@mock.patch('sys.argv', return_value=mock.sentinel.argv) @mock.patch('sys.argv', return_value=mock.sentinel.argv)
@mock.patch('nova.db.api.api.configure')
@mock.patch('nova.db.main.api.configure') @mock.patch('nova.db.main.api.configure')
@mock.patch('nova.api.openstack.wsgi_app._setup_service') @mock.patch('nova.api.openstack.wsgi_app._setup_service')
@mock.patch('nova.api.openstack.wsgi_app._get_config_files') @mock.patch('nova.api.openstack.wsgi_app._get_config_files')
def test_init_application_called_twice(self, mock_get_files, mock_setup, def test_init_application_called_twice(
mock_db_configure, mock_argv): self, mock_get_files, mock_setup, mock_main_db_configure,
mock_api_db_configure, mock_argv,
):
"""Test that init_application can tolerate being called twice in a """Test that init_application can tolerate being called twice in a
single python interpreter instance. single python interpreter instance.
@ -65,14 +68,15 @@ document_root = /tmp
""" """
mock_get_files.return_value = [self.conf.name] mock_get_files.return_value = [self.conf.name]
mock_setup.side_effect = [test.TestingException, None] mock_setup.side_effect = [test.TestingException, None]
# We need to mock the global database configure() method, else we will # We need to mock the global database configure() methods, else we will
# be affected by global database state altered by other tests that ran # be affected by global database state altered by other tests that ran
# before this test, causing this test to fail with # before this test, causing this test to fail with
# oslo_db.sqlalchemy.enginefacade.AlreadyStartedError. We can instead # oslo_db.sqlalchemy.enginefacade.AlreadyStartedError. We can instead
# mock the method to raise an exception if it's called a second time in # mock the method to raise an exception if it's called a second time in
# this test to simulate the fact that the database does not tolerate # this test to simulate the fact that the database does not tolerate
# re-init [after a database query has been made]. # re-init [after a database query has been made].
mock_db_configure.side_effect = [None, test.TestingException] mock_main_db_configure.side_effect = [None, test.TestingException]
mock_api_db_configure.side_effect = [None, test.TestingException]
# Run init_application the first time, simulating an exception being # Run init_application the first time, simulating an exception being
# raised during it. # raised during it.
self.assertRaises(test.TestingException, wsgi_app.init_application, self.assertRaises(test.TestingException, wsgi_app.init_application,

View File

@ -39,8 +39,9 @@ from nova.conductor.tasks import live_migrate
from nova.conductor.tasks import migrate from nova.conductor.tasks import migrate
from nova import conf from nova import conf
from nova import context from nova import context
from nova.db.api import api as api_db_api
from nova.db.api import models as api_models from nova.db.api import models as api_models
from nova.db.main import api as db from nova.db.main import api as main_db_api
from nova import exception as exc from nova import exception as exc
from nova.image import glance as image_api from nova.image import glance as image_api
from nova import objects from nova import objects
@ -440,7 +441,7 @@ class _BaseTaskTestCase(object):
@mock.patch.object(conductor_manager.ComputeTaskManager, @mock.patch.object(conductor_manager.ComputeTaskManager,
'_create_and_bind_arqs') '_create_and_bind_arqs')
@mock.patch.object(compute_rpcapi.ComputeAPI, 'build_and_run_instance') @mock.patch.object(compute_rpcapi.ComputeAPI, 'build_and_run_instance')
@mock.patch.object(db, 'block_device_mapping_get_all_by_instance', @mock.patch.object(main_db_api, 'block_device_mapping_get_all_by_instance',
return_value=[]) return_value=[])
@mock.patch.object(conductor_manager.ComputeTaskManager, @mock.patch.object(conductor_manager.ComputeTaskManager,
'_schedule_instances') '_schedule_instances')
@ -547,7 +548,7 @@ class _BaseTaskTestCase(object):
@mock.patch.object(conductor_manager.ComputeTaskManager, @mock.patch.object(conductor_manager.ComputeTaskManager,
'_create_and_bind_arqs') '_create_and_bind_arqs')
@mock.patch.object(compute_rpcapi.ComputeAPI, 'build_and_run_instance') @mock.patch.object(compute_rpcapi.ComputeAPI, 'build_and_run_instance')
@mock.patch.object(db, 'block_device_mapping_get_all_by_instance', @mock.patch.object(main_db_api, 'block_device_mapping_get_all_by_instance',
return_value=[]) return_value=[])
@mock.patch.object(conductor_manager.ComputeTaskManager, @mock.patch.object(conductor_manager.ComputeTaskManager,
'_schedule_instances') '_schedule_instances')
@ -2740,7 +2741,7 @@ class ConductorTaskTestCase(_BaseTaskTestCase, test_compute.BaseTestCase):
self.assertEqual(0, len(build_requests)) self.assertEqual(0, len(build_requests))
@db.api_context_manager.reader @api_db_api.context_manager.reader
def request_spec_get_all(context): def request_spec_get_all(context):
return context.session.query(api_models.RequestSpec).all() return context.session.query(api_models.RequestSpec).all()

View File

View File

@ -0,0 +1,24 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
from nova.db.api import api as db_api
from nova import test
class SqlAlchemyDbApiNoDbTestCase(test.NoDBTestCase):
@mock.patch.object(db_api, 'context_manager')
def test_get_engine(self, mock_ctxt_mgr):
db_api.get_engine()
mock_ctxt_mgr.writer.get_engine.assert_called_once_with()

View File

@ -1,5 +1,3 @@
# encoding=UTF8
# Copyright 2010 United States Government as represented by the # Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration. # Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved. # All Rights Reserved.
@ -52,6 +50,7 @@ from nova import context
from nova.db.main import api as db from nova.db.main import api as db
from nova.db.main import models from nova.db.main import models
from nova.db import types as col_types from nova.db import types as col_types
from nova.db import utils as db_utils
from nova import exception from nova import exception
from nova.objects import fields from nova.objects import fields
from nova import test from nova import test
@ -207,7 +206,7 @@ class DecoratorTestCase(test.TestCase):
self.assertEqual(test_func.__module__, decorated_func.__module__) self.assertEqual(test_func.__module__, decorated_func.__module__)
def test_require_context_decorator_wraps_functions_properly(self): def test_require_context_decorator_wraps_functions_properly(self):
self._test_decorator_wraps_helper(db.require_context) self._test_decorator_wraps_helper(db_utils.require_context)
def test_require_deadlock_retry_wraps_functions_properly(self): def test_require_deadlock_retry_wraps_functions_properly(self):
self._test_decorator_wraps_helper( self._test_decorator_wraps_helper(
@ -628,7 +627,7 @@ class ModelQueryTestCase(DbTestCase):
@mock.patch.object(sqlalchemyutils, 'model_query') @mock.patch.object(sqlalchemyutils, 'model_query')
def test_model_query_use_context_session(self, mock_model_query): def test_model_query_use_context_session(self, mock_model_query):
@db.main_context_manager.reader @db.context_manager.reader
def fake_method(context): def fake_method(context):
session = context.session session = context.session
db.model_query(context, models.Instance) db.model_query(context, models.Instance)
@ -642,15 +641,15 @@ class ModelQueryTestCase(DbTestCase):
class EngineFacadeTestCase(DbTestCase): class EngineFacadeTestCase(DbTestCase):
def test_use_single_context_session_writer(self): def test_use_single_context_session_writer(self):
# Checks that session in context would not be overwritten by # Checks that session in context would not be overwritten by
# annotation @db.main_context_manager.writer if annotation # annotation @db.context_manager.writer if annotation
# is used twice. # is used twice.
@db.main_context_manager.writer @db.context_manager.writer
def fake_parent_method(context): def fake_parent_method(context):
session = context.session session = context.session
return fake_child_method(context), session return fake_child_method(context), session
@db.main_context_manager.writer @db.context_manager.writer
def fake_child_method(context): def fake_child_method(context):
session = context.session session = context.session
db.model_query(context, models.Instance) db.model_query(context, models.Instance)
@ -661,15 +660,15 @@ class EngineFacadeTestCase(DbTestCase):
def test_use_single_context_session_reader(self): def test_use_single_context_session_reader(self):
# Checks that session in context would not be overwritten by # Checks that session in context would not be overwritten by
# annotation @db.main_context_manager.reader if annotation # annotation @db.context_manager.reader if annotation
# is used twice. # is used twice.
@db.main_context_manager.reader @db.context_manager.reader
def fake_parent_method(context): def fake_parent_method(context):
session = context.session session = context.session
return fake_child_method(context), session return fake_child_method(context), session
@db.main_context_manager.reader @db.context_manager.reader
def fake_child_method(context): def fake_child_method(context):
session = context.session session = context.session
db.model_query(context, models.Instance) db.model_query(context, models.Instance)
@ -757,12 +756,12 @@ class SqlAlchemyDbApiNoDbTestCase(test.NoDBTestCase):
self.assertEqual('|', filter('|')) self.assertEqual('|', filter('|'))
self.assertEqual('LIKE', op) self.assertEqual('LIKE', op)
@mock.patch.object(db, 'main_context_manager') @mock.patch.object(db, 'context_manager')
def test_get_engine(self, mock_ctxt_mgr): def test_get_engine(self, mock_ctxt_mgr):
db.get_engine() db.get_engine()
mock_ctxt_mgr.writer.get_engine.assert_called_once_with() mock_ctxt_mgr.writer.get_engine.assert_called_once_with()
@mock.patch.object(db, 'main_context_manager') @mock.patch.object(db, 'context_manager')
def test_get_engine_use_slave(self, mock_ctxt_mgr): def test_get_engine_use_slave(self, mock_ctxt_mgr):
db.get_engine(use_slave=True) db.get_engine(use_slave=True)
mock_ctxt_mgr.reader.get_engine.assert_called_once_with() mock_ctxt_mgr.reader.get_engine.assert_called_once_with()
@ -774,11 +773,6 @@ class SqlAlchemyDbApiNoDbTestCase(test.NoDBTestCase):
connection='fake://') connection='fake://')
self.assertEqual('fake://', db_conf['connection']) self.assertEqual('fake://', db_conf['connection'])
@mock.patch.object(db, 'api_context_manager')
def test_get_api_engine(self, mock_ctxt_mgr):
db.get_api_engine()
mock_ctxt_mgr.writer.get_engine.assert_called_once_with()
@mock.patch.object(db, '_instance_get_by_uuid') @mock.patch.object(db, '_instance_get_by_uuid')
@mock.patch.object(db, '_instances_fill_metadata') @mock.patch.object(db, '_instances_fill_metadata')
@mock.patch('oslo_db.sqlalchemy.utils.paginate_query') @mock.patch('oslo_db.sqlalchemy.utils.paginate_query')
@ -1013,114 +1007,6 @@ class SqlAlchemyDbApiTestCase(DbTestCase):
self.assertEqual(1, len(instances)) self.assertEqual(1, len(instances))
class ProcessSortParamTestCase(test.TestCase):
def test_process_sort_params_defaults(self):
'''Verifies default sort parameters.'''
sort_keys, sort_dirs = db.process_sort_params([], [])
self.assertEqual(['created_at', 'id'], sort_keys)
self.assertEqual(['asc', 'asc'], sort_dirs)
sort_keys, sort_dirs = db.process_sort_params(None, None)
self.assertEqual(['created_at', 'id'], sort_keys)
self.assertEqual(['asc', 'asc'], sort_dirs)
def test_process_sort_params_override_default_keys(self):
'''Verifies that the default keys can be overridden.'''
sort_keys, sort_dirs = db.process_sort_params(
[], [], default_keys=['key1', 'key2', 'key3'])
self.assertEqual(['key1', 'key2', 'key3'], sort_keys)
self.assertEqual(['asc', 'asc', 'asc'], sort_dirs)
def test_process_sort_params_override_default_dir(self):
'''Verifies that the default direction can be overridden.'''
sort_keys, sort_dirs = db.process_sort_params(
[], [], default_dir='dir1')
self.assertEqual(['created_at', 'id'], sort_keys)
self.assertEqual(['dir1', 'dir1'], sort_dirs)
def test_process_sort_params_override_default_key_and_dir(self):
'''Verifies that the default key and dir can be overridden.'''
sort_keys, sort_dirs = db.process_sort_params(
[], [], default_keys=['key1', 'key2', 'key3'],
default_dir='dir1')
self.assertEqual(['key1', 'key2', 'key3'], sort_keys)
self.assertEqual(['dir1', 'dir1', 'dir1'], sort_dirs)
sort_keys, sort_dirs = db.process_sort_params(
[], [], default_keys=[], default_dir='dir1')
self.assertEqual([], sort_keys)
self.assertEqual([], sort_dirs)
def test_process_sort_params_non_default(self):
'''Verifies that non-default keys are added correctly.'''
sort_keys, sort_dirs = db.process_sort_params(
['key1', 'key2'], ['asc', 'desc'])
self.assertEqual(['key1', 'key2', 'created_at', 'id'], sort_keys)
# First sort_dir in list is used when adding the default keys
self.assertEqual(['asc', 'desc', 'asc', 'asc'], sort_dirs)
def test_process_sort_params_default(self):
'''Verifies that default keys are added correctly.'''
sort_keys, sort_dirs = db.process_sort_params(
['id', 'key2'], ['asc', 'desc'])
self.assertEqual(['id', 'key2', 'created_at'], sort_keys)
self.assertEqual(['asc', 'desc', 'asc'], sort_dirs)
# Include default key value, rely on default direction
sort_keys, sort_dirs = db.process_sort_params(
['id', 'key2'], [])
self.assertEqual(['id', 'key2', 'created_at'], sort_keys)
self.assertEqual(['asc', 'asc', 'asc'], sort_dirs)
def test_process_sort_params_default_dir(self):
'''Verifies that the default dir is applied to all keys.'''
# Direction is set, ignore default dir
sort_keys, sort_dirs = db.process_sort_params(
['id', 'key2'], ['desc'], default_dir='dir')
self.assertEqual(['id', 'key2', 'created_at'], sort_keys)
self.assertEqual(['desc', 'desc', 'desc'], sort_dirs)
# But should be used if no direction is set
sort_keys, sort_dirs = db.process_sort_params(
['id', 'key2'], [], default_dir='dir')
self.assertEqual(['id', 'key2', 'created_at'], sort_keys)
self.assertEqual(['dir', 'dir', 'dir'], sort_dirs)
def test_process_sort_params_unequal_length(self):
'''Verifies that a sort direction list is applied correctly.'''
sort_keys, sort_dirs = db.process_sort_params(
['id', 'key2', 'key3'], ['desc'])
self.assertEqual(['id', 'key2', 'key3', 'created_at'], sort_keys)
self.assertEqual(['desc', 'desc', 'desc', 'desc'], sort_dirs)
# Default direction is the first key in the list
sort_keys, sort_dirs = db.process_sort_params(
['id', 'key2', 'key3'], ['desc', 'asc'])
self.assertEqual(['id', 'key2', 'key3', 'created_at'], sort_keys)
self.assertEqual(['desc', 'asc', 'desc', 'desc'], sort_dirs)
sort_keys, sort_dirs = db.process_sort_params(
['id', 'key2', 'key3'], ['desc', 'asc', 'asc'])
self.assertEqual(['id', 'key2', 'key3', 'created_at'], sort_keys)
self.assertEqual(['desc', 'asc', 'asc', 'desc'], sort_dirs)
def test_process_sort_params_extra_dirs_lengths(self):
'''InvalidInput raised if more directions are given.'''
self.assertRaises(exception.InvalidInput,
db.process_sort_params,
['key1', 'key2'],
['asc', 'desc', 'desc'])
def test_process_sort_params_invalid_sort_dir(self):
'''InvalidInput raised if invalid directions are given.'''
for dirs in [['foo'], ['asc', 'foo'], ['asc', 'desc', 'foo']]:
self.assertRaises(exception.InvalidInput,
db.process_sort_params,
['key'],
dirs)
class MigrationTestCase(test.TestCase): class MigrationTestCase(test.TestCase):
def setUp(self): def setUp(self):
@ -1877,7 +1763,7 @@ class InstanceTestCase(test.TestCase, ModelsObjectComparatorMixin):
@mock.patch.object(query.Query, 'filter') @mock.patch.object(query.Query, 'filter')
def test_instance_metadata_get_multi_no_uuids(self, mock_query_filter): def test_instance_metadata_get_multi_no_uuids(self, mock_query_filter):
with db.main_context_manager.reader.using(self.ctxt): with db.context_manager.reader.using(self.ctxt):
db._instance_metadata_get_multi(self.ctxt, []) db._instance_metadata_get_multi(self.ctxt, [])
self.assertFalse(mock_query_filter.called) self.assertFalse(mock_query_filter.called)

View File

@ -17,7 +17,8 @@ from migrate.versioning import api as versioning_api
import mock import mock
import sqlalchemy import sqlalchemy
from nova.db.main import api as db_api from nova.db.api import api as api_db_api
from nova.db.main import api as main_db_api
from nova.db import migration from nova.db import migration
from nova import test from nova import test
@ -144,15 +145,17 @@ class TestDbVersionControl(test.NoDBTestCase):
class TestGetEngine(test.NoDBTestCase): class TestGetEngine(test.NoDBTestCase):
def test_get_main_engine(self): def test_get_main_engine(self):
with mock.patch.object(db_api, 'get_engine', with mock.patch.object(
return_value='engine') as mock_get_engine: main_db_api, 'get_engine', return_value='engine',
) as mock_get_engine:
engine = migration.get_engine() engine = migration.get_engine()
self.assertEqual('engine', engine) self.assertEqual('engine', engine)
mock_get_engine.assert_called_once_with(context=None) mock_get_engine.assert_called_once_with(context=None)
def test_get_api_engine(self): def test_get_api_engine(self):
with mock.patch.object(db_api, 'get_api_engine', with mock.patch.object(
return_value='api_engine') as mock_get_engine: api_db_api, 'get_engine', return_value='engine',
) as mock_get_engine:
engine = migration.get_engine('api') engine = migration.get_engine('api')
self.assertEqual('api_engine', engine) self.assertEqual('engine', engine)
mock_get_engine.assert_called_once_with() mock_get_engine.assert_called_once_with()

View File

@ -0,0 +1,123 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from nova.db import utils
from nova import exception
from nova import test
class ProcessSortParamTestCase(test.TestCase):
def test_process_sort_params_defaults(self):
"""Verifies default sort parameters."""
sort_keys, sort_dirs = utils.process_sort_params([], [])
self.assertEqual(['created_at', 'id'], sort_keys)
self.assertEqual(['asc', 'asc'], sort_dirs)
sort_keys, sort_dirs = utils.process_sort_params(None, None)
self.assertEqual(['created_at', 'id'], sort_keys)
self.assertEqual(['asc', 'asc'], sort_dirs)
def test_process_sort_params_override_default_keys(self):
"""Verifies that the default keys can be overridden."""
sort_keys, sort_dirs = utils.process_sort_params(
[], [], default_keys=['key1', 'key2', 'key3'])
self.assertEqual(['key1', 'key2', 'key3'], sort_keys)
self.assertEqual(['asc', 'asc', 'asc'], sort_dirs)
def test_process_sort_params_override_default_dir(self):
"""Verifies that the default direction can be overridden."""
sort_keys, sort_dirs = utils.process_sort_params(
[], [], default_dir='dir1')
self.assertEqual(['created_at', 'id'], sort_keys)
self.assertEqual(['dir1', 'dir1'], sort_dirs)
def test_process_sort_params_override_default_key_and_dir(self):
"""Verifies that the default key and dir can be overridden."""
sort_keys, sort_dirs = utils.process_sort_params(
[], [], default_keys=['key1', 'key2', 'key3'],
default_dir='dir1')
self.assertEqual(['key1', 'key2', 'key3'], sort_keys)
self.assertEqual(['dir1', 'dir1', 'dir1'], sort_dirs)
sort_keys, sort_dirs = utils.process_sort_params(
[], [], default_keys=[], default_dir='dir1')
self.assertEqual([], sort_keys)
self.assertEqual([], sort_dirs)
def test_process_sort_params_non_default(self):
"""Verifies that non-default keys are added correctly."""
sort_keys, sort_dirs = utils.process_sort_params(
['key1', 'key2'], ['asc', 'desc'])
self.assertEqual(['key1', 'key2', 'created_at', 'id'], sort_keys)
# First sort_dir in list is used when adding the default keys
self.assertEqual(['asc', 'desc', 'asc', 'asc'], sort_dirs)
def test_process_sort_params_default(self):
"""Verifies that default keys are added correctly."""
sort_keys, sort_dirs = utils.process_sort_params(
['id', 'key2'], ['asc', 'desc'])
self.assertEqual(['id', 'key2', 'created_at'], sort_keys)
self.assertEqual(['asc', 'desc', 'asc'], sort_dirs)
# Include default key value, rely on default direction
sort_keys, sort_dirs = utils.process_sort_params(
['id', 'key2'], [])
self.assertEqual(['id', 'key2', 'created_at'], sort_keys)
self.assertEqual(['asc', 'asc', 'asc'], sort_dirs)
def test_process_sort_params_default_dir(self):
"""Verifies that the default dir is applied to all keys."""
# Direction is set, ignore default dir
sort_keys, sort_dirs = utils.process_sort_params(
['id', 'key2'], ['desc'], default_dir='dir')
self.assertEqual(['id', 'key2', 'created_at'], sort_keys)
self.assertEqual(['desc', 'desc', 'desc'], sort_dirs)
# But should be used if no direction is set
sort_keys, sort_dirs = utils.process_sort_params(
['id', 'key2'], [], default_dir='dir')
self.assertEqual(['id', 'key2', 'created_at'], sort_keys)
self.assertEqual(['dir', 'dir', 'dir'], sort_dirs)
def test_process_sort_params_unequal_length(self):
"""Verifies that a sort direction list is applied correctly."""
sort_keys, sort_dirs = utils.process_sort_params(
['id', 'key2', 'key3'], ['desc'])
self.assertEqual(['id', 'key2', 'key3', 'created_at'], sort_keys)
self.assertEqual(['desc', 'desc', 'desc', 'desc'], sort_dirs)
# Default direction is the first key in the list
sort_keys, sort_dirs = utils.process_sort_params(
['id', 'key2', 'key3'], ['desc', 'asc'])
self.assertEqual(['id', 'key2', 'key3', 'created_at'], sort_keys)
self.assertEqual(['desc', 'asc', 'desc', 'desc'], sort_dirs)
sort_keys, sort_dirs = utils.process_sort_params(
['id', 'key2', 'key3'], ['desc', 'asc', 'asc'])
self.assertEqual(['id', 'key2', 'key3', 'created_at'], sort_keys)
self.assertEqual(['desc', 'asc', 'asc', 'desc'], sort_dirs)
def test_process_sort_params_extra_dirs_lengths(self):
"""InvalidInput raised if more directions are given."""
self.assertRaises(
exception.InvalidInput,
utils.process_sort_params,
['key1', 'key2'],
['asc', 'desc', 'desc'])
def test_process_sort_params_invalid_sort_dir(self):
"""InvalidInput raised if invalid directions are given."""
for dirs in [['foo'], ['asc', 'foo'], ['asc', 'desc', 'foo']]:
self.assertRaises(
exception.InvalidInput,
utils.process_sort_params, ['key'], dirs)

View File

@ -19,8 +19,8 @@ from oslo_db import exception as db_exc
from oslo_utils import uuidutils from oslo_utils import uuidutils
from nova import context as nova_context from nova import context as nova_context
from nova.db.api import api as api_db_api
from nova.db.api import models as api_models from nova.db.api import models as api_models
from nova.db.main import api as db_api
from nova import exception from nova import exception
from nova import objects from nova import objects
from nova.objects import fields from nova.objects import fields
@ -89,7 +89,7 @@ class _TestFlavor(object):
mock_get.assert_called_once_with(self.context, 'm1.foo') mock_get.assert_called_once_with(self.context, 'm1.foo')
@staticmethod @staticmethod
@db_api.api_context_manager.writer @api_db_api.context_manager.writer
def _create_api_flavor(context, altid=None): def _create_api_flavor(context, altid=None):
fake_db_flavor = dict(fake_flavor) fake_db_flavor = dict(fake_flavor)
del fake_db_flavor['extra_specs'] del fake_db_flavor['extra_specs']

View File

@ -87,19 +87,24 @@ class TestParseArgs(test.NoDBTestCase):
def setUp(self): def setUp(self):
super(TestParseArgs, self).setUp() super(TestParseArgs, self).setUp()
m = mock.patch('nova.db.main.api.configure') m = mock.patch('nova.db.main.api.configure')
self.nova_db_config_mock = m.start() self.main_db_config_mock = m.start()
self.addCleanup(self.nova_db_config_mock.stop) self.addCleanup(self.main_db_config_mock.stop)
m = mock.patch('nova.db.api.api.configure')
self.api_db_config_mock = m.start()
self.addCleanup(self.api_db_config_mock.stop)
@mock.patch.object(config.log, 'register_options') @mock.patch.object(config.log, 'register_options')
def test_parse_args_glance_debug_false(self, register_options): def test_parse_args_glance_debug_false(self, register_options):
self.flags(debug=False, group='glance') self.flags(debug=False, group='glance')
config.parse_args([], configure_db=False, init_rpc=False) config.parse_args([], configure_db=False, init_rpc=False)
self.assertIn('glanceclient=WARN', config.CONF.default_log_levels) self.assertIn('glanceclient=WARN', config.CONF.default_log_levels)
self.nova_db_config_mock.assert_not_called() self.main_db_config_mock.assert_not_called()
self.api_db_config_mock.assert_not_called()
@mock.patch.object(config.log, 'register_options') @mock.patch.object(config.log, 'register_options')
def test_parse_args_glance_debug_true(self, register_options): def test_parse_args_glance_debug_true(self, register_options):
self.flags(debug=True, group='glance') self.flags(debug=True, group='glance')
config.parse_args([], configure_db=True, init_rpc=False) config.parse_args([], configure_db=True, init_rpc=False)
self.assertIn('glanceclient=DEBUG', config.CONF.default_log_levels) self.assertIn('glanceclient=DEBUG', config.CONF.default_log_levels)
self.nova_db_config_mock.assert_called_once_with(config.CONF) self.main_db_config_mock.assert_called_once_with(config.CONF)
self.api_db_config_mock.assert_called_once_with(config.CONF)

View File

@ -34,7 +34,8 @@ import testtools
from nova.compute import rpcapi as compute_rpcapi from nova.compute import rpcapi as compute_rpcapi
from nova import conductor from nova import conductor
from nova import context from nova import context
from nova.db.main import api as session from nova.db.api import api as api_db_api
from nova.db.main import api as main_db_api
from nova import exception from nova import exception
from nova.network import neutron as neutron_api from nova.network import neutron as neutron_api
from nova import objects from nova import objects
@ -121,7 +122,7 @@ class TestDatabaseFixture(testtools.TestCase):
# because this sets up reasonable db connection strings # because this sets up reasonable db connection strings
self.useFixture(fixtures.ConfFixture()) self.useFixture(fixtures.ConfFixture())
self.useFixture(fixtures.Database()) self.useFixture(fixtures.Database())
engine = session.get_engine() engine = main_db_api.get_engine()
conn = engine.connect() conn = engine.connect()
result = conn.execute("select * from instance_types") result = conn.execute("select * from instance_types")
rows = result.fetchall() rows = result.fetchall()
@ -152,7 +153,7 @@ class TestDatabaseFixture(testtools.TestCase):
# This sets up reasonable db connection strings # This sets up reasonable db connection strings
self.useFixture(fixtures.ConfFixture()) self.useFixture(fixtures.ConfFixture())
self.useFixture(fixtures.Database(database='api')) self.useFixture(fixtures.Database(database='api'))
engine = session.get_api_engine() engine = api_db_api.get_engine()
conn = engine.connect() conn = engine.connect()
result = conn.execute("select * from cell_mappings") result = conn.execute("select * from cell_mappings")
rows = result.fetchall() rows = result.fetchall()
@ -186,7 +187,7 @@ class TestDatabaseFixture(testtools.TestCase):
fix.cleanup() fix.cleanup()
# ensure the db contains nothing # ensure the db contains nothing
engine = session.get_engine() engine = main_db_api.get_engine()
conn = engine.connect() conn = engine.connect()
schema = "".join(line for line in conn.connection.iterdump()) schema = "".join(line for line in conn.connection.iterdump())
self.assertEqual(schema, "BEGIN TRANSACTION;COMMIT;") self.assertEqual(schema, "BEGIN TRANSACTION;COMMIT;")
@ -198,7 +199,7 @@ class TestDatabaseFixture(testtools.TestCase):
self.useFixture(fix) self.useFixture(fix)
# No data inserted by migrations so we need to add a row # No data inserted by migrations so we need to add a row
engine = session.get_api_engine() engine = api_db_api.get_engine()
conn = engine.connect() conn = engine.connect()
uuid = uuidutils.generate_uuid() uuid = uuidutils.generate_uuid()
conn.execute("insert into cell_mappings (uuid, name) VALUES " conn.execute("insert into cell_mappings (uuid, name) VALUES "
@ -211,7 +212,7 @@ class TestDatabaseFixture(testtools.TestCase):
fix.cleanup() fix.cleanup()
# Ensure the db contains nothing # Ensure the db contains nothing
engine = session.get_api_engine() engine = api_db_api.get_engine()
conn = engine.connect() conn = engine.connect()
schema = "".join(line for line in conn.connection.iterdump()) schema = "".join(line for line in conn.connection.iterdump())
self.assertEqual("BEGIN TRANSACTION;COMMIT;", schema) self.assertEqual("BEGIN TRANSACTION;COMMIT;", schema)
@ -224,7 +225,7 @@ class TestDefaultFlavorsFixture(testtools.TestCase):
self.useFixture(fixtures.Database()) self.useFixture(fixtures.Database())
self.useFixture(fixtures.Database(database='api')) self.useFixture(fixtures.Database(database='api'))
engine = session.get_api_engine() engine = api_db_api.get_engine()
conn = engine.connect() conn = engine.connect()
result = conn.execute("select * from flavors") result = conn.execute("select * from flavors")
rows = result.fetchall() rows = result.fetchall()