db: Synchronize function signatures

A number of the abstract APIs in 'nova.db.api' had a different function
signature to their concrete implementations in 'nova.db.sqlalchemy.api'.
Correct this. Functions changed include:

- action_get_by_request_id
- create_context_manager
- instance_add_security_group
- instance_extra_update_by_uuid
- instance_get_all_by_filters
- instance_remove_security_group
- migration_get
- migration_get_by_id_and_instance
- pci_device_update
- service_get_minimum_version
- virtual_interface_delete_by_instance
- virtual_interface_get_by_instance
- virtual_interface_get_by_instance_and_network

To do this, the following script was used:

  >>> import nova.db.api as base_api
  >>> import nova.db.sqlalchemy.api as sqla_api
  >>> import collections
  >>> import inspect
  >>> for name in dir(base_api):
  ...     fn_base = getattr(base_api, name)
  ...     if not isinstance(fn_base, collections.Callable):
  ...         continue
  ...     fn_sqla = getattr(sqla_api, name, None)
  ...     if not fn_sqla or not isinstance(fn_sqla, collections.Callable):
  ...         print(f'missing function in nova.api.sqlalchemy.db: {name}')
  ...     spec_base = inspect.getfullargspec(fn_base)
  ...     spec_sqla = inspect.getfullargspec(fn_sqla)
  ...     if spec_base != spec_sqla:
  ...         print('mismatched function specs:')
  ...         print(f'base: {spec_base}')
  ...         print(f'sqla: {spec_sqla}')
  ...     break

In order for *this* to work, it was necessary to update the many
decorators in 'nova.db.sqlalchemy.api' so that function signatures were
preserved. This is possible by setting the signature of the wrapper to
that of the wrapped function.

Change-Id: Icb97a8b4e17fdbb2146ddf2729c906757c664f66
Signed-off-by: Stephen Finucane <stephenfin@redhat.com>
This commit is contained in:
Stephen Finucane
2021-04-01 10:48:54 +01:00
parent 1d60cd7e05
commit b5dd59f8f9
2 changed files with 52 additions and 43 deletions

View File

@@ -74,7 +74,7 @@ def not_equal(*values):
return IMPL.not_equal(*values)
def create_context_manager(connection):
def create_context_manager(connection=None):
"""Create a database context manager object for a cell database connection.
:param connection: The database connection string
@@ -117,9 +117,9 @@ def service_get_by_uuid(context, service_uuid):
return IMPL.service_get_by_uuid(context, service_uuid)
def service_get_minimum_version(context, binary):
def service_get_minimum_version(context, binaries):
"""Get the minimum service version in the database."""
return IMPL.service_get_minimum_version(context, binary)
return IMPL.service_get_minimum_version(context, binaries)
def service_get_by_host_and_topic(context, host, topic):
@@ -395,9 +395,9 @@ def certificate_get_all_by_user_and_project(context, user_id, project_id):
####################
def migration_update(context, id, values):
def migration_update(context, migration_id, values):
"""Update a migration instance."""
return IMPL.migration_update(context, id, values)
return IMPL.migration_update(context, migration_id, values)
def migration_create(context, values):
@@ -522,31 +522,31 @@ def virtual_interface_get_by_uuid(context, vif_uuid):
return IMPL.virtual_interface_get_by_uuid(context, vif_uuid)
def virtual_interface_get_by_instance(context, instance_id):
def virtual_interface_get_by_instance(context, instance_uuid):
"""Gets all virtual interfaces for instance.
:param instance_uuid: UUID of the instance to filter on.
"""
return IMPL.virtual_interface_get_by_instance(context, instance_id)
return IMPL.virtual_interface_get_by_instance(context, instance_uuid)
def virtual_interface_get_by_instance_and_network(
context, instance_id, network_id,
context, instance_uuid, network_id,
):
"""Get all virtual interface for instance that's associated with
network.
"""
return IMPL.virtual_interface_get_by_instance_and_network(context,
instance_id,
network_id)
return IMPL.virtual_interface_get_by_instance_and_network(
context, instance_uuid, network_id,
)
def virtual_interface_delete_by_instance(context, instance_id):
def virtual_interface_delete_by_instance(context, instance_uuid):
"""Delete virtual interface records associated with instance.
:param instance_uuid: UUID of the instance to filter on.
"""
return IMPL.virtual_interface_delete_by_instance(context, instance_id)
return IMPL.virtual_interface_delete_by_instance(context, instance_uuid)
def virtual_interface_delete(context, id):
@@ -786,16 +786,18 @@ def instance_update_and_get_original(context, instance_uuid, values,
return rv
def instance_add_security_group(context, instance_id, security_group_id):
def instance_add_security_group(context, instance_uuid, security_group_id):
"""Associate the given security group with the given instance."""
return IMPL.instance_add_security_group(context, instance_id,
security_group_id)
return IMPL.instance_add_security_group(
context, instance_uuid, security_group_id,
)
def instance_remove_security_group(context, instance_id, security_group_id):
def instance_remove_security_group(context, instance_uuid, security_group_id):
"""Disassociate the given security group from the given instance."""
return IMPL.instance_remove_security_group(context, instance_id,
security_group_id)
return IMPL.instance_remove_security_group(
context, instance_uuid, security_group_id,
)
####################
@@ -1126,9 +1128,9 @@ def pci_device_destroy(context, node_id, address):
return IMPL.pci_device_destroy(context, node_id, address)
def pci_device_update(context, node_id, address, value):
def pci_device_update(context, node_id, address, values):
"""Update a pci device."""
return IMPL.pci_device_update(context, node_id, address, value)
return IMPL.pci_device_update(context, node_id, address, values)
####################
@@ -1265,9 +1267,9 @@ def actions_get(context, instance_uuid, limit=None, marker=None,
return IMPL.actions_get(context, instance_uuid, limit, marker, filters)
def action_get_by_request_id(context, uuid, request_id):
def action_get_by_request_id(context, instance_uuid, request_id):
"""Get the action by request_id and given instance."""
return IMPL.action_get_by_request_id(context, uuid, request_id)
return IMPL.action_get_by_request_id(context, instance_uuid, request_id)
def action_event_start(context, values):

View File

@@ -174,6 +174,7 @@ def require_context(f):
def wrapper(*args, **kwargs):
nova.context.require_context(args[0])
return f(*args, **kwargs)
wrapper.__signature__ = inspect.signature(f)
return wrapper
@@ -202,6 +203,7 @@ def select_db_reader_mode(f):
with reader_mode.using(context):
return f(*args, **kwargs)
wrapper.__signature__ = inspect.signature(f)
return wrapper
@@ -213,11 +215,12 @@ def pick_context_manager_writer(f):
Wrapped function must have a RequestContext in the arguments.
"""
@functools.wraps(f)
def wrapped(context, *args, **kwargs):
def wrapper(context, *args, **kwargs):
ctxt_mgr = get_context_manager(context)
with ctxt_mgr.writer.using(context):
return f(context, *args, **kwargs)
return wrapped
wrapper.__signature__ = inspect.signature(f)
return wrapper
def pick_context_manager_reader(f):
@@ -228,11 +231,12 @@ def pick_context_manager_reader(f):
Wrapped function must have a RequestContext in the arguments.
"""
@functools.wraps(f)
def wrapped(context, *args, **kwargs):
def wrapper(context, *args, **kwargs):
ctxt_mgr = get_context_manager(context)
with ctxt_mgr.reader.using(context):
return f(context, *args, **kwargs)
return wrapped
wrapper.__signature__ = inspect.signature(f)
return wrapper
def pick_context_manager_reader_allow_async(f):
@@ -243,11 +247,12 @@ def pick_context_manager_reader_allow_async(f):
Wrapped function must have a RequestContext in the arguments.
"""
@functools.wraps(f)
def wrapped(context, *args, **kwargs):
def wrapper(context, *args, **kwargs):
ctxt_mgr = get_context_manager(context)
with ctxt_mgr.reader.allow_async.using(context):
return f(context, *args, **kwargs)
return wrapped
wrapper.__signature__ = inspect.signature(f)
return wrapper
def model_query(
@@ -1539,8 +1544,10 @@ def instance_get_all(context, columns_to_join=None):
@require_context
@pick_context_manager_reader_allow_async
def instance_get_all_by_filters(context, filters, sort_key, sort_dir,
limit=None, marker=None, columns_to_join=None):
def instance_get_all_by_filters(
context, filters, sort_key='created_at', sort_dir='desc', limit=None,
marker=None, columns_to_join=None,
):
"""Get all instances matching all filters sorted by the primary key.
See instance_get_all_by_filters_sort for more information.
@@ -2541,7 +2548,7 @@ def _instance_extra_create(context, values):
@pick_context_manager_writer
def instance_extra_update_by_uuid(context, instance_uuid, values):
def instance_extra_update_by_uuid(context, instance_uuid, updates):
"""Update the instance extra record by instance uuid
:param instance_uuid: UUID of the instance tied to the record
@@ -2549,10 +2556,10 @@ def instance_extra_update_by_uuid(context, instance_uuid, values):
"""
rows_updated = model_query(context, models.InstanceExtra).\
filter_by(instance_uuid=instance_uuid).\
update(values)
update(updates)
if not rows_updated:
LOG.debug("Created instance_extra for %s", instance_uuid)
create_values = copy.copy(values)
create_values = copy.copy(updates)
create_values["instance_uuid"] = instance_uuid
_instance_extra_create(context, create_values)
rows_updated = 1
@@ -3305,23 +3312,23 @@ def migration_create(context, values):
@oslo_db_api.wrap_db_retry(max_retries=5, retry_on_deadlock=True)
@pick_context_manager_writer
def migration_update(context, id, values):
def migration_update(context, migration_id, values):
"""Update a migration instance."""
migration = migration_get(context, id)
migration = migration_get(context, migration_id)
migration.update(values)
return migration
@pick_context_manager_reader
def migration_get(context, id):
def migration_get(context, migration_id):
"""Finds a migration by the ID."""
result = model_query(context, models.Migration, read_deleted="yes").\
filter_by(id=id).\
filter_by(id=migration_id).\
first()
if not result:
raise exception.MigrationNotFound(migration_id=id)
raise exception.MigrationNotFound(migration_id=migration_id)
return result
@@ -3340,16 +3347,16 @@ def migration_get_by_uuid(context, migration_uuid):
@pick_context_manager_reader
def migration_get_by_id_and_instance(context, id, instance_uuid):
def migration_get_by_id_and_instance(context, migration_id, instance_uuid):
"""Finds a migration by the migration ID and the instance UUID."""
result = model_query(context, models.Migration).\
filter_by(id=id).\
filter_by(id=migration_id).\
filter_by(instance_uuid=instance_uuid).\
first()
if not result:
raise exception.MigrationNotFoundForInstance(migration_id=id,
instance_id=instance_uuid)
raise exception.MigrationNotFoundForInstance(
migration_id=migration_id, instance_id=instance_uuid)
return result