db: Replace use of strings in join, defer operations

Resolve the following RemovedIn20Warning warnings:

  Using strings to indicate column or relationship paths in loader
  options is deprecated and will be removed in SQLAlchemy 2.0. Please
  use the class-bound attribute directly.

  Using strings to indicate relationship names in Query.join() is
  deprecated and will be removed in SQLAlchemy 2.0. Please use the
  class-bound attribute directly.

This is rather tricky to resolve. In most cases, we can simply make use
of getattr to fetch the class-bound attribute, however, there are a
number of places were we were doing "nested" joins, e.g.
'instances.info_cache' on the 'SecurityGroup' model. These need a little
more thought.

Change-Id: I1355ac92202cb504a7814afaa1338a4a511f9b54
Signed-off-by: Stephen Finucane <sfinucan@redhat.com>
This commit is contained in:
Stephen Finucane 2022-04-07 17:34:52 +01:00
parent 523297bdfa
commit 0939b3c4d1
8 changed files with 319 additions and 200 deletions

View File

@ -79,11 +79,28 @@ def _context_manager_from_context(context):
pass
def _joinedload_all(column):
def _joinedload_all(lead_entity, column):
"""Do a nested load.
For example, resolve the following::
_joinedload_all(models.SecurityGroup, 'instances.info_cache')
to:
orm.joinedload(
models.SecurityGroup.instances
).joinedload(
Instance.info_cache
)
"""
elements = column.split('.')
joined = orm.joinedload(elements.pop(0))
relationship_attr = getattr(lead_entity, elements.pop(0))
joined = orm.joinedload(relationship_attr)
for element in elements:
joined = joined.joinedload(element)
relationship_entity = relationship_attr.entity.class_
relationship_attr = getattr(relationship_entity, element)
joined = joined.joinedload(relationship_attr)
return joined
@ -1381,9 +1398,9 @@ def instance_get_by_uuid(context, uuid, columns_to_join=None):
def _instance_get_by_uuid(context, uuid, columns_to_join=None):
result = _build_instance_get(context, columns_to_join=columns_to_join).\
filter_by(uuid=uuid).\
first()
result = _build_instance_get(
context, columns_to_join=columns_to_join
).filter_by(uuid=uuid).first()
if not result:
raise exception.InstanceNotFound(instance_id=uuid)
@ -1411,9 +1428,13 @@ def instance_get(context, instance_id, columns_to_join=None):
def _build_instance_get(context, columns_to_join=None):
query = model_query(context, models.Instance, project_only=True).\
options(_joinedload_all('security_groups.rules')).\
options(orm.joinedload('info_cache'))
query = model_query(
context, models.Instance, project_only=True,
).options(
orm.joinedload(
models.Instance.security_groups
).joinedload(models.SecurityGroup.rules)
).options(orm.joinedload(models.Instance.info_cache))
if columns_to_join is None:
columns_to_join = ['metadata', 'system_metadata']
for column in columns_to_join:
@ -1421,7 +1442,10 @@ def _build_instance_get(context, columns_to_join=None):
# Already always joined above
continue
if 'extra.' in column:
query = query.options(orm.undefer(column))
column_ref = getattr(models.InstanceExtra, column.split('.')[1])
query = query.options(
orm.joinedload(models.Instance.extra).undefer(column_ref)
)
elif column in ['metadata', 'system_metadata']:
# NOTE(melwitt): We use subqueryload() instead of joinedload() for
# metadata and system_metadata because of the one-to-many
@ -1431,13 +1455,16 @@ def _build_instance_get(context, columns_to_join=None):
# in a large data transfer. Instead, the subqueryload() will
# perform additional queries to obtain metadata and system_metadata
# for the instance.
query = query.options(orm.subqueryload(column))
column_ref = getattr(models.Instance, column)
query = query.options(orm.subqueryload(column_ref))
else:
query = query.options(orm.joinedload(column))
column_ref = getattr(models.Instance, column)
query = query.options(orm.joinedload(column_ref))
# NOTE(alaski) Stop lazy loading of columns not needed.
for col in ['metadata', 'system_metadata']:
if col not in columns_to_join:
query = query.options(orm.noload(col))
for column in ['metadata', 'system_metadata']:
if column not in columns_to_join:
column_ref = getattr(models.Instance, column)
query = query.options(orm.noload(column_ref))
# NOTE(melwitt): We need to use order_by(<unique column>) so that the
# additional queries emitted by subqueryload() include the same ordering as
# used by the parent query.
@ -1530,7 +1557,8 @@ def instance_get_all(context, columns_to_join=None):
_manual_join_columns(columns_to_join))
query = model_query(context, models.Instance)
for column in columns_to_join_new:
query = query.options(orm.joinedload(column))
column_ref = getattr(models.Instance, column)
query = query.options(orm.joinedload(column_ref))
if not context.is_admin:
# If we're not admin context, add appropriate filter..
if context.project_id:
@ -1671,9 +1699,13 @@ def instance_get_all_by_filters_sort(context, filters, limit=None, marker=None,
query_prefix = context.session.query(models.Instance)
for column in columns_to_join_new:
if 'extra.' in column:
query_prefix = query_prefix.options(orm.undefer(column))
column_ref = getattr(models.InstanceExtra, column.split('.')[1])
query_prefix = query_prefix.options(
orm.joinedload(models.Instance.extra).undefer(column_ref)
)
else:
query_prefix = query_prefix.options(orm.joinedload(column))
column_ref = getattr(models.Instance, column)
query_prefix = query_prefix.options(orm.joinedload(column_ref))
# Note: order_by is done in the sqlalchemy.utils.py paginate_query(),
# no need to do it here as well
@ -1683,9 +1715,9 @@ def instance_get_all_by_filters_sort(context, filters, limit=None, marker=None,
filters = copy.deepcopy(filters)
model_object = models.Instance
query_prefix = _get_query_nova_resource_by_changes_time(query_prefix,
filters,
model_object)
query_prefix = _get_query_nova_resource_by_changes_time(
query_prefix, filters, model_object,
)
if 'deleted' in filters:
# Instances can be soft or hard deleted and the query needs to
@ -1697,14 +1729,12 @@ def instance_get_all_by_filters_sort(context, filters, limit=None, marker=None,
models.Instance.deleted == models.Instance.id,
models.Instance.vm_state == vm_states.SOFT_DELETED
)
query_prefix = query_prefix.\
filter(delete)
query_prefix = query_prefix.filter(delete)
else:
query_prefix = query_prefix.\
filter(models.Instance.deleted == models.Instance.id)
else:
query_prefix = query_prefix.\
filter_by(deleted=0)
query_prefix = query_prefix.filter_by(deleted=0)
if not filters.pop('soft_deleted', False):
# It would be better to have vm_state not be nullable
# but until then we test it explicitly as a workaround.
@ -1794,19 +1824,25 @@ def instance_get_all_by_filters_sort(context, filters, limit=None, marker=None,
if marker is not None:
try:
marker = _instance_get_by_uuid(
context.elevated(read_deleted='yes'), marker)
context.elevated(read_deleted='yes'), marker,
)
except exception.InstanceNotFound:
raise exception.MarkerNotFound(marker=marker)
try:
query_prefix = sqlalchemyutils.paginate_query(query_prefix,
models.Instance, limit,
sort_keys,
marker=marker,
sort_dirs=sort_dirs)
query_prefix = sqlalchemyutils.paginate_query(
query_prefix,
models.Instance,
limit,
sort_keys,
marker=marker,
sort_dirs=sort_dirs,
)
except db_exc.InvalidSortKey:
raise exception.InvalidSortKey()
return _instances_fill_metadata(context, query_prefix.all(), manual_joins)
instances = query_prefix.all()
return _instances_fill_metadata(context, instances, manual_joins)
@require_context
@ -2059,9 +2095,13 @@ def instance_get_active_by_window_joined(context, begin, end=None,
for column in columns_to_join_new:
if 'extra.' in column:
query = query.options(orm.undefer(column))
column_ref = getattr(models.InstanceExtra, column.split('.')[1])
query = query.options(
orm.joinedload(models.Instance.extra).undefer(column_ref)
)
else:
query = query.options(orm.joinedload(column))
column_ref = getattr(models.Instance, column)
query = query.options(orm.joinedload(column_ref))
query = query.filter(sql.or_(
models.Instance.terminated_at == sql.null(),
@ -2081,23 +2121,31 @@ def instance_get_active_by_window_joined(context, begin, end=None,
raise exception.MarkerNotFound(marker=marker)
query = sqlalchemyutils.paginate_query(
query, models.Instance, limit, ['project_id', 'uuid'], marker=marker)
query, models.Instance, limit, ['project_id', 'uuid'], marker=marker,
)
instances = query.all()
return _instances_fill_metadata(context, query.all(), manual_joins)
return _instances_fill_metadata(context, instances, manual_joins)
def _instance_get_all_query(context, project_only=False, joins=None):
if joins is None:
joins = ['info_cache', 'security_groups']
query = model_query(context,
models.Instance,
project_only=project_only)
query = model_query(
context,
models.Instance,
project_only=project_only,
)
for column in joins:
if 'extra.' in column:
query = query.options(orm.undefer(column))
column_ref = getattr(models.InstanceExtra, column.split('.')[1])
query = query.options(
orm.joinedload(models.Instance.extra).undefer(column_ref)
)
else:
query = query.options(orm.joinedload(column))
column_ref = getattr(models.Instance, column)
query = query.options(orm.joinedload(column_ref))
return query
@ -2105,9 +2153,12 @@ def _instance_get_all_query(context, project_only=False, joins=None):
def instance_get_all_by_host(context, host, columns_to_join=None):
"""Get all instances belonging to a host."""
query = _instance_get_all_query(context, joins=columns_to_join)
return _instances_fill_metadata(context,
query.filter_by(host=host).all(),
manual_joins=columns_to_join)
instances = query.filter_by(host=host).all()
return _instances_fill_metadata(
context,
instances,
manual_joins=columns_to_join,
)
def _instance_get_all_uuids_by_hosts(context, hosts):
@ -2147,19 +2198,26 @@ def instance_get_all_by_host_and_node(
candidates = ['system_metadata', 'metadata']
manual_joins = [x for x in columns_to_join if x in candidates]
columns_to_join = list(set(columns_to_join) - set(candidates))
return _instances_fill_metadata(context,
_instance_get_all_query(
context,
joins=columns_to_join).filter_by(host=host).
filter_by(node=node).all(), manual_joins=manual_joins)
instances = _instance_get_all_query(
context,
joins=columns_to_join,
).filter_by(host=host).filter_by(node=node).all()
return _instances_fill_metadata(
context,
instances,
manual_joins=manual_joins,
)
@pick_context_manager_reader
def instance_get_all_by_host_and_not_type(context, host, type_id=None):
"""Get all instances belonging to a host with a different type_id."""
return _instances_fill_metadata(context,
_instance_get_all_query(context).filter_by(host=host).
filter(models.Instance.instance_type_id != type_id).all())
instances = _instance_get_all_query(context).filter_by(
host=host,
).filter(
models.Instance.instance_type_id != type_id
).all()
return _instances_fill_metadata(context, instances)
# NOTE(hanlind): This method can be removed as conductor RPC API moves to v2.0.
@ -2172,11 +2230,14 @@ def instance_get_all_hung_in_rebooting(context, reboot_window):
# NOTE(danms): this is only used in the _poll_rebooting_instances()
# call in compute/manager, so we can avoid the metadata lookups
# explicitly
return _instances_fill_metadata(context,
model_query(context, models.Instance).
filter(models.Instance.updated_at <= reboot_window).
filter_by(task_state=task_states.REBOOTING).all(),
manual_joins=[])
instances = model_query(context, models.Instance).filter(
models.Instance.updated_at <= reboot_window
).filter_by(task_state=task_states.REBOOTING).all()
return _instances_fill_metadata(
context,
instances,
manual_joins=[],
)
def _retry_instance_update():
@ -2505,13 +2566,15 @@ def instance_extra_get_by_instance_uuid(
:param instance_uuid: UUID of the instance tied to the topology record
:param columns: A list of the columns to load, or None for 'all of them'
"""
query = model_query(context, models.InstanceExtra).\
filter_by(instance_uuid=instance_uuid)
query = model_query(context, models.InstanceExtra).filter_by(
instance_uuid=instance_uuid,
)
if columns is None:
columns = ['numa_topology', 'pci_requests', 'flavor', 'vcpu_model',
'trusted_certs', 'resources', 'migration_context']
for column in columns:
query = query.options(orm.undefer(column))
column_ref = getattr(models.InstanceExtra, column)
query = query.options(orm.undefer(column_ref))
instance_extra = query.first()
return instance_extra
@ -2733,7 +2796,8 @@ def _block_device_mapping_get_query(context, columns_to_join=None):
query = model_query(context, models.BlockDeviceMapping)
for column in columns_to_join:
query = query.options(orm.joinedload(column))
column_ref = getattr(models.BlockDeviceMapping, column)
query = query.options(orm.joinedload(column_ref))
return query
@ -2950,10 +3014,18 @@ def security_group_create(context, values):
def _security_group_get_query(context, read_deleted=None,
project_only=False, join_rules=True):
query = model_query(context, models.SecurityGroup,
read_deleted=read_deleted, project_only=project_only)
query = model_query(
context,
models.SecurityGroup,
read_deleted=read_deleted,
project_only=project_only,
)
if join_rules:
query = query.options(_joinedload_all('rules.grantee_group'))
query = query.options(
orm.joinedload(
models.SecurityGroup.rules
).joinedload(models.SecurityGroupIngressRule.grantee_group)
)
return query
@ -2998,8 +3070,7 @@ def security_group_get(context, security_group_id, columns_to_join=None):
if columns_to_join is None:
columns_to_join = []
for column in columns_to_join:
if column.startswith('instances'):
query = query.options(_joinedload_all(column))
query = query.options(_joinedload_all(models.SecurityGroup, column))
result = query.first()
if not result:
@ -3011,25 +3082,27 @@ def security_group_get(context, security_group_id, columns_to_join=None):
@require_context
@pick_context_manager_reader
def security_group_get_by_name(
context, project_id, group_name, columns_to_join=None,
):
def security_group_get_by_name(context, project_id, group_name):
"""Returns a security group with the specified name from a project."""
query = _security_group_get_query(context,
read_deleted="no", join_rules=False).\
filter_by(project_id=project_id).\
filter_by(name=group_name)
if columns_to_join is None:
columns_to_join = ['instances', 'rules.grantee_group']
for column in columns_to_join:
query = query.options(_joinedload_all(column))
query = _security_group_get_query(
context, read_deleted="no", join_rules=False,
).filter_by(
project_id=project_id,
).filter_by(
name=group_name,
).options(
orm.joinedload(models.SecurityGroup.instances)
).options(
orm.joinedload(
models.SecurityGroup.rules
).joinedload(models.SecurityGroupIngressRule.grantee_group)
)
result = query.first()
if not result:
raise exception.SecurityGroupNotFoundForProject(
project_id=project_id, security_group_id=group_name)
project_id=project_id, security_group_id=group_name,
)
return result
@ -3077,14 +3150,11 @@ def security_group_in_use(context, group_id):
@require_context
@pick_context_manager_writer
def security_group_update(context, security_group_id, values,
columns_to_join=None):
def security_group_update(context, security_group_id, values):
"""Update a security group."""
query = model_query(context, models.SecurityGroup).filter_by(
id=security_group_id)
if columns_to_join:
for column in columns_to_join:
query = query.options(_joinedload_all(column))
id=security_group_id,
)
security_group_ref = query.first()
if not security_group_ref:
@ -3265,20 +3335,36 @@ def migration_get_in_progress_by_host_and_node(context, host, node):
# 'finished' means a resize is finished on the destination host
# and the instance is in VERIFY_RESIZE state, so the end state
# for a resize is actually 'confirmed' or 'reverted'.
return model_query(context, models.Migration).\
filter(sql.or_(
sql.and_(
models.Migration.source_compute == host,
models.Migration.source_node == node),
sql.and_(
models.Migration.dest_compute == host,
models.Migration.dest_node == node))).\
filter(~models.Migration.status.in_(['confirmed', 'reverted',
'error', 'failed',
'completed', 'cancelled',
'done'])).\
options(_joinedload_all('instance.system_metadata')).\
all()
return model_query(
context, models.Migration,
).filter(
sql.or_(
sql.and_(
models.Migration.source_compute == host,
models.Migration.source_node == node,
),
sql.and_(
models.Migration.dest_compute == host,
models.Migration.dest_node == node,
),
)
).filter(
~models.Migration.status.in_(
[
'confirmed',
'reverted',
'error',
'failed',
'completed',
'cancelled',
'done',
]
)
).options(
orm.joinedload(
models.Migration.instance
).joinedload(models.Instance.system_metadata)
).all()
@pick_context_manager_reader
@ -3413,19 +3499,32 @@ def migration_get_in_progress_and_error_by_host_and_node(context, host, node):
"""Finds all in progress migrations and error migrations for the given
host and node.
"""
return model_query(context, models.Migration).\
filter(sql.or_(
sql.and_(
models.Migration.source_compute == host,
models.Migration.source_node == node),
sql.and_(
models.Migration.dest_compute == host,
models.Migration.dest_node == node))).\
filter(~models.Migration.status.in_(['confirmed', 'reverted',
'failed', 'completed',
'cancelled', 'done'])).\
options(_joinedload_all('instance.system_metadata')).\
all()
return model_query(
context, models.Migration,
).filter(
sql.or_(
sql.and_(
models.Migration.source_compute == host,
models.Migration.source_node == node),
sql.and_(
models.Migration.dest_compute == host,
models.Migration.dest_node == node,
),
)
).filter(
~models.Migration.status.in_([
'confirmed',
'reverted',
'failed',
'completed',
'cancelled',
'done',
])
).options(
orm.joinedload(
models.Migration.instance
).joinedload(models.Instance.system_metadata)
).all()
########################

View File

@ -35,8 +35,8 @@ DEPRECATED_FIELDS = ['deleted', 'deleted_at']
@api_db_api.context_manager.reader
def _aggregate_get_from_db(context, aggregate_id):
query = context.session.query(api_models.Aggregate).\
options(orm.joinedload('_hosts')).\
options(orm.joinedload('_metadata'))
options(orm.joinedload(api_models.Aggregate._hosts)).\
options(orm.joinedload(api_models.Aggregate._metadata))
query = query.filter(api_models.Aggregate.id == aggregate_id)
aggregate = query.first()
@ -50,8 +50,8 @@ def _aggregate_get_from_db(context, aggregate_id):
@api_db_api.context_manager.reader
def _aggregate_get_from_db_by_uuid(context, aggregate_uuid):
query = context.session.query(api_models.Aggregate).\
options(orm.joinedload('_hosts')).\
options(orm.joinedload('_metadata'))
options(orm.joinedload(api_models.Aggregate._hosts)).\
options(orm.joinedload(api_models.Aggregate._metadata))
query = query.filter(api_models.Aggregate.uuid == aggregate_uuid)
aggregate = query.first()
@ -414,8 +414,8 @@ class Aggregate(base.NovaPersistentObject, base.NovaObject):
@api_db_api.context_manager.reader
def _get_all_from_db(context):
query = context.session.query(api_models.Aggregate).\
options(orm.joinedload('_hosts')).\
options(orm.joinedload('_metadata'))
options(orm.joinedload(api_models.Aggregate._hosts)).\
options(orm.joinedload(api_models.Aggregate._metadata))
return query.all()
@ -423,13 +423,13 @@ def _get_all_from_db(context):
@api_db_api.context_manager.reader
def _get_by_host_from_db(context, host, key=None):
query = context.session.query(api_models.Aggregate).\
options(orm.joinedload('_hosts')).\
options(orm.joinedload('_metadata'))
query = query.join('_hosts')
options(orm.joinedload(api_models.Aggregate._hosts)).\
options(orm.joinedload(api_models.Aggregate._metadata))
query = query.join(api_models.Aggregate._hosts)
query = query.filter(api_models.AggregateHost.host == host)
if key:
query = query.join("_metadata").filter(
query = query.join(api_models.Aggregate._metadata).filter(
api_models.AggregateMetadata.key == key)
return query.all()
@ -439,13 +439,15 @@ def _get_by_host_from_db(context, host, key=None):
def _get_by_metadata_from_db(context, key=None, value=None):
assert(key is not None or value is not None)
query = context.session.query(api_models.Aggregate)
query = query.join("_metadata")
query = query.join(api_models.Aggregate._metadata)
if key is not None:
query = query.filter(api_models.AggregateMetadata.key == key)
if value is not None:
query = query.filter(api_models.AggregateMetadata.value == value)
query = query.options(orm.contains_eager("_metadata"))
query = query.options(orm.joinedload("_hosts"))
query = query.options(
orm.contains_eager(api_models.Aggregate._metadata)
)
query = query.options(orm.joinedload(api_models.Aggregate._hosts))
return query.all()
@ -468,16 +470,19 @@ def _get_non_matching_by_metadata_keys_from_db(context, ignored_keys,
raise ValueError(_('key_prefix mandatory field.'))
query = context.session.query(api_models.Aggregate)
query = query.join("_metadata")
query = query.join(api_models.Aggregate._metadata)
query = query.filter(api_models.AggregateMetadata.value == value)
query = query.filter(api_models.AggregateMetadata.key.like(
key_prefix + '%'))
if len(ignored_keys) > 0:
query = query.filter(~api_models.AggregateMetadata.key.in_(
ignored_keys))
query = query.filter(
~api_models.AggregateMetadata.key.in_(ignored_keys)
)
query = query.options(orm.contains_eager("_metadata"))
query = query.options(orm.joinedload("_hosts"))
query = query.options(
orm.contains_eager(api_models.Aggregate._metadata)
)
query = query.options(orm.joinedload(api_models.Aggregate._hosts))
return query.all()

View File

@ -54,10 +54,11 @@ def _dict_with_extra_specs(flavor_model):
# issues are resolved.
@api_db_api.context_manager.reader
def _get_projects_from_db(context, flavorid):
db_flavor = context.session.query(api_models.Flavors).\
filter_by(flavorid=flavorid).\
options(orm.joinedload('projects')).\
first()
db_flavor = context.session.query(api_models.Flavors).filter_by(
flavorid=flavorid
).options(
orm.joinedload(api_models.Flavors.projects)
).first()
if not db_flavor:
raise exception.FlavorNotFound(flavor_id=flavorid)
return [x['project_id'] for x in db_flavor['projects']]
@ -271,8 +272,9 @@ class Flavor(base.NovaPersistentObject, base.NovaObject,
@staticmethod
@api_db_api.context_manager.reader
def _flavor_get_query_from_db(context):
query = context.session.query(api_models.Flavors).\
options(orm.joinedload('extra_specs'))
query = context.session.query(api_models.Flavors).options(
orm.joinedload(api_models.Flavors.extra_specs)
)
if not context.is_admin:
the_filter = [api_models.Flavors.is_public == sql.true()]
the_filter.extend([

View File

@ -89,9 +89,13 @@ class HostMapping(base.NovaTimestampObject, base.NovaObject):
@staticmethod
@api_db_api.context_manager.reader
def _get_by_host_from_db(context, host):
db_mapping = context.session.query(api_models.HostMapping)\
.options(orm.joinedload('cell_mapping'))\
.filter(api_models.HostMapping.host == host).first()
db_mapping = context.session.query(
api_models.HostMapping
).options(
orm.joinedload(api_models.HostMapping.cell_mapping)
).filter(
api_models.HostMapping.host == host
).first()
if not db_mapping:
raise exception.HostMappingNotFound(name=host)
return db_mapping
@ -159,18 +163,19 @@ class HostMappingList(base.ObjectListBase, base.NovaObject):
@staticmethod
@api_db_api.context_manager.reader
def _get_from_db(context, cell_id=None):
query = (context.session.query(api_models.HostMapping)
.options(orm.joinedload('cell_mapping')))
query = context.session.query(api_models.HostMapping).options(
orm.joinedload(api_models.HostMapping.cell_mapping)
)
if cell_id:
query = query.filter(api_models.HostMapping.cell_id == cell_id)
return query.all()
@base.remotable_classmethod
@ base.remotable_classmethod
def get_by_cell_id(cls, context, cell_id):
db_mappings = cls._get_from_db(context, cell_id)
return base.obj_make_list(context, cls(), HostMapping, db_mappings)
@base.remotable_classmethod
@ base.remotable_classmethod
def get_all(cls, context):
db_mappings = cls._get_from_db(context)
return base.obj_make_list(context, cls(), HostMapping, db_mappings)

View File

@ -36,8 +36,8 @@ LOG = logging.getLogger(__name__)
def _instance_group_get_query(context, id_field=None, id=None):
query = context.session.query(api_models.InstanceGroup).\
options(orm.joinedload('_policies')).\
options(orm.joinedload('_members'))
options(orm.joinedload(api_models.InstanceGroup._policies)).\
options(orm.joinedload(api_models.InstanceGroup._members))
if not context.is_admin:
query = query.filter_by(project_id=context.project_id)
if id and id_field:
@ -84,16 +84,22 @@ def _instance_group_members_add(context, group, members):
def _instance_group_members_add_by_uuid(context, group_uuid, members):
# NOTE(melwitt): The condition on the join limits the number of members
# returned to only those we wish to check as already existing.
group = context.session.query(api_models.InstanceGroup).\
outerjoin(api_models.InstanceGroupMember,
api_models.InstanceGroupMember.instance_uuid.in_(set(members))).\
filter(api_models.InstanceGroup.uuid == group_uuid).\
options(orm.contains_eager('_members')).first()
group = context.session.query(api_models.InstanceGroup).outerjoin(
api_models.InstanceGroupMember,
api_models.InstanceGroupMember.instance_uuid.in_(set(members))
).filter(
api_models.InstanceGroup.uuid == group_uuid
).options(orm.contains_eager(api_models.InstanceGroup._members)).first()
if not group:
raise exception.InstanceGroupNotFound(group_uuid=group_uuid)
return _instance_group_model_add(context, api_models.InstanceGroupMember,
members, group._members, 'instance_uuid',
group.id)
return _instance_group_model_add(
context,
api_models.InstanceGroupMember,
members,
group._members,
'instance_uuid',
group.id,
)
# TODO(berrange): Remove NovaObjectDictCompat

View File

@ -99,7 +99,7 @@ class InstanceMapping(base.NovaTimestampObject, base.NovaObject):
@api_db_api.context_manager.reader
def _get_by_instance_uuid_from_db(context, instance_uuid):
db_mapping = context.session.query(api_models.InstanceMapping)\
.options(orm.joinedload('cell_mapping'))\
.options(orm.joinedload(api_models.InstanceMapping.cell_mapping))\
.filter(api_models.InstanceMapping.instance_uuid == instance_uuid)\
.first()
if not db_mapping:
@ -312,7 +312,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject):
@api_db_api.context_manager.reader
def _get_by_project_id_from_db(context, project_id):
return context.session.query(api_models.InstanceMapping)\
.options(orm.joinedload('cell_mapping'))\
.options(orm.joinedload(api_models.InstanceMapping.cell_mapping))\
.filter(api_models.InstanceMapping.project_id == project_id).all()
@base.remotable_classmethod
@ -326,7 +326,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject):
@api_db_api.context_manager.reader
def _get_by_cell_id_from_db(context, cell_id):
return context.session.query(api_models.InstanceMapping)\
.options(orm.joinedload('cell_mapping'))\
.options(orm.joinedload(api_models.InstanceMapping.cell_mapping))\
.filter(api_models.InstanceMapping.cell_id == cell_id).all()
@base.remotable_classmethod
@ -339,7 +339,7 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject):
@api_db_api.context_manager.reader
def _get_by_instance_uuids_from_db(context, uuids):
return context.session.query(api_models.InstanceMapping)\
.options(orm.joinedload('cell_mapping'))\
.options(orm.joinedload(api_models.InstanceMapping.cell_mapping))\
.filter(api_models.InstanceMapping.instance_uuid.in_(uuids))\
.all()
@ -373,12 +373,16 @@ class InstanceMappingList(base.ObjectListBase, base.NovaObject):
# queued_for_delete was not run) and False (cases when the online
# data migration for queued_for_delete was run) are assumed to mean
# that the instance is not queued for deletion.
query = (query.filter(sql.or_(
api_models.InstanceMapping.queued_for_delete == sql.false(),
api_models.InstanceMapping.queued_for_delete.is_(None)))
.join('cell_mapping')
.options(orm.joinedload('cell_mapping'))
.filter(api_models.CellMapping.uuid == cell_uuid))
query = query.filter(
sql.or_(
api_models.InstanceMapping.queued_for_delete == sql.false(),
api_models.InstanceMapping.queued_for_delete.is_(None)
)
).join(
api_models.InstanceMapping.cell_mapping
).options(
orm.joinedload(api_models.InstanceMapping.cell_mapping)
).filter(api_models.CellMapping.uuid == cell_uuid)
if limit is not None:
query = query.limit(limit)
return query.all()

View File

@ -891,18 +891,6 @@ class WarningsFixture(fixtures.Fixture):
message=r'The Connection.connect\(\) method is considered .*',
category=sqla_exc.SADeprecationWarning)
warnings.filterwarnings(
'ignore',
module='nova',
message=r'Using strings to indicate column or relationship .*',
category=sqla_exc.SADeprecationWarning)
warnings.filterwarnings(
'ignore',
module='nova',
message=r'Using strings to indicate relationship names .*',
category=sqla_exc.SADeprecationWarning)
warnings.filterwarnings(
'ignore',
module='nova',

View File

@ -167,27 +167,25 @@ class DbTestCase(test.TestCase):
class HelperTestCase(test.TestCase):
@mock.patch('sqlalchemy.orm.joinedload')
def test_joinedload_helper(self, mock_jl):
query = db._joinedload_all('foo.bar.baz')
query = db._joinedload_all(
models.SecurityGroup, 'instances.info_cache'
)
# We call sqlalchemy.orm.joinedload() on the first element
mock_jl.assert_called_once_with('foo')
mock_jl.assert_called_once_with(models.SecurityGroup.instances)
# Then first.joinedload(second)
column2 = mock_jl.return_value
column2.joinedload.assert_called_once_with('bar')
column2.joinedload.assert_called_once_with(models.Instance.info_cache)
# Then second.joinedload(third)
column3 = column2.joinedload.return_value
column3.joinedload.assert_called_once_with('baz')
self.assertEqual(column3.joinedload.return_value, query)
self.assertEqual(column2.joinedload.return_value, query)
@mock.patch('sqlalchemy.orm.joinedload')
def test_joinedload_helper_single(self, mock_jl):
query = db._joinedload_all('foo')
query = db._joinedload_all(models.SecurityGroup, 'instances')
# We call sqlalchemy.orm.joinedload() on the first element
mock_jl.assert_called_once_with('foo')
mock_jl.assert_called_once_with(models.SecurityGroup.instances)
# We should have gotten back just the result of the joinedload()
# call if there were no other elements
@ -1683,28 +1681,40 @@ class InstanceTestCase(test.TestCase, ModelsObjectComparatorMixin):
instances = db.instance_get_all_by_filters_sort(self.ctxt, filters)
self.assertEqual([], instances)
@mock.patch('sqlalchemy.orm.undefer')
@mock.patch('sqlalchemy.orm.joinedload')
def test_instance_get_all_by_filters_extra_columns(self,
mock_joinedload,
mock_undefer):
def test_instance_get_all_by_filters_extra_columns(self, mock_joinedload):
db.instance_get_all_by_filters_sort(
self.ctxt, {},
columns_to_join=['info_cache', 'extra.pci_requests'])
mock_joinedload.assert_called_once_with('info_cache')
mock_undefer.assert_called_once_with('extra.pci_requests')
columns_to_join=['info_cache', 'extra.pci_requests'],
)
mock_joinedload.assert_has_calls(
[
mock.call(models.Instance.info_cache),
mock.ANY,
mock.call(models.Instance.extra),
mock.ANY,
mock.ANY,
]
)
@mock.patch('sqlalchemy.orm.undefer')
@mock.patch('sqlalchemy.orm.joinedload')
def test_instance_get_active_by_window_extra_columns(self,
mock_joinedload,
mock_undefer):
def test_instance_get_active_by_window_extra_columns(
self, mock_joinedload,
):
now = datetime.datetime(2013, 10, 10, 17, 16, 37, 156701)
db.instance_get_active_by_window_joined(
self.ctxt, now,
columns_to_join=['info_cache', 'extra.pci_requests'])
mock_joinedload.assert_called_once_with('info_cache')
mock_undefer.assert_called_once_with('extra.pci_requests')
columns_to_join=['info_cache', 'extra.pci_requests'],
)
mock_joinedload.assert_has_calls(
[
mock.call(models.Instance.info_cache),
mock.ANY,
mock.call(models.Instance.extra),
mock.ANY,
mock.ANY,
]
)
def test_instance_get_all_by_filters_with_meta(self):
self.create_instance_with_args()