From b5dd59f8f9029a845ac352e87ab110d6e84511d0 Mon Sep 17 00:00:00 2001 From: Stephen Finucane Date: Thu, 1 Apr 2021 10:48:54 +0100 Subject: [PATCH] db: Synchronize function signatures A number of the abstract APIs in 'nova.db.api' had a different function signature to their concrete implementations in 'nova.db.sqlalchemy.api'. Correct this. Functions changed include: - action_get_by_request_id - create_context_manager - instance_add_security_group - instance_extra_update_by_uuid - instance_get_all_by_filters - instance_remove_security_group - migration_get - migration_get_by_id_and_instance - pci_device_update - service_get_minimum_version - virtual_interface_delete_by_instance - virtual_interface_get_by_instance - virtual_interface_get_by_instance_and_network To do this, the following script was used: >>> import nova.db.api as base_api >>> import nova.db.sqlalchemy.api as sqla_api >>> import collections >>> import inspect >>> for name in dir(base_api): ... fn_base = getattr(base_api, name) ... if not isinstance(fn_base, collections.Callable): ... continue ... fn_sqla = getattr(sqla_api, name, None) ... if not fn_sqla or not isinstance(fn_sqla, collections.Callable): ... print(f'missing function in nova.api.sqlalchemy.db: {name}') ... spec_base = inspect.getfullargspec(fn_base) ... spec_sqla = inspect.getfullargspec(fn_sqla) ... if spec_base != spec_sqla: ... print('mismatched function specs:') ... print(f'base: {spec_base}') ... print(f'sqla: {spec_sqla}') ... break In order for *this* to work, it was necessary to update the many decorators in 'nova.db.sqlalchemy.api' so that function signatures were preserved. This is possible by setting the signature of the wrapper to that of the wrapped function. Change-Id: Icb97a8b4e17fdbb2146ddf2729c906757c664f66 Signed-off-by: Stephen Finucane --- nova/db/api.py | 48 ++++++++++++++++++++------------------- nova/db/sqlalchemy/api.py | 47 ++++++++++++++++++++++---------------- 2 files changed, 52 insertions(+), 43 deletions(-) diff --git a/nova/db/api.py b/nova/db/api.py index 76af136ff573..09ab9936f756 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -74,7 +74,7 @@ def not_equal(*values): return IMPL.not_equal(*values) -def create_context_manager(connection): +def create_context_manager(connection=None): """Create a database context manager object for a cell database connection. :param connection: The database connection string @@ -117,9 +117,9 @@ def service_get_by_uuid(context, service_uuid): return IMPL.service_get_by_uuid(context, service_uuid) -def service_get_minimum_version(context, binary): +def service_get_minimum_version(context, binaries): """Get the minimum service version in the database.""" - return IMPL.service_get_minimum_version(context, binary) + return IMPL.service_get_minimum_version(context, binaries) def service_get_by_host_and_topic(context, host, topic): @@ -395,9 +395,9 @@ def certificate_get_all_by_user_and_project(context, user_id, project_id): #################### -def migration_update(context, id, values): +def migration_update(context, migration_id, values): """Update a migration instance.""" - return IMPL.migration_update(context, id, values) + return IMPL.migration_update(context, migration_id, values) def migration_create(context, values): @@ -522,31 +522,31 @@ def virtual_interface_get_by_uuid(context, vif_uuid): return IMPL.virtual_interface_get_by_uuid(context, vif_uuid) -def virtual_interface_get_by_instance(context, instance_id): +def virtual_interface_get_by_instance(context, instance_uuid): """Gets all virtual interfaces for instance. :param instance_uuid: UUID of the instance to filter on. """ - return IMPL.virtual_interface_get_by_instance(context, instance_id) + return IMPL.virtual_interface_get_by_instance(context, instance_uuid) def virtual_interface_get_by_instance_and_network( - context, instance_id, network_id, + context, instance_uuid, network_id, ): """Get all virtual interface for instance that's associated with network. """ - return IMPL.virtual_interface_get_by_instance_and_network(context, - instance_id, - network_id) + return IMPL.virtual_interface_get_by_instance_and_network( + context, instance_uuid, network_id, + ) -def virtual_interface_delete_by_instance(context, instance_id): +def virtual_interface_delete_by_instance(context, instance_uuid): """Delete virtual interface records associated with instance. :param instance_uuid: UUID of the instance to filter on. """ - return IMPL.virtual_interface_delete_by_instance(context, instance_id) + return IMPL.virtual_interface_delete_by_instance(context, instance_uuid) def virtual_interface_delete(context, id): @@ -786,16 +786,18 @@ def instance_update_and_get_original(context, instance_uuid, values, return rv -def instance_add_security_group(context, instance_id, security_group_id): +def instance_add_security_group(context, instance_uuid, security_group_id): """Associate the given security group with the given instance.""" - return IMPL.instance_add_security_group(context, instance_id, - security_group_id) + return IMPL.instance_add_security_group( + context, instance_uuid, security_group_id, + ) -def instance_remove_security_group(context, instance_id, security_group_id): +def instance_remove_security_group(context, instance_uuid, security_group_id): """Disassociate the given security group from the given instance.""" - return IMPL.instance_remove_security_group(context, instance_id, - security_group_id) + return IMPL.instance_remove_security_group( + context, instance_uuid, security_group_id, + ) #################### @@ -1126,9 +1128,9 @@ def pci_device_destroy(context, node_id, address): return IMPL.pci_device_destroy(context, node_id, address) -def pci_device_update(context, node_id, address, value): +def pci_device_update(context, node_id, address, values): """Update a pci device.""" - return IMPL.pci_device_update(context, node_id, address, value) + return IMPL.pci_device_update(context, node_id, address, values) #################### @@ -1265,9 +1267,9 @@ def actions_get(context, instance_uuid, limit=None, marker=None, return IMPL.actions_get(context, instance_uuid, limit, marker, filters) -def action_get_by_request_id(context, uuid, request_id): +def action_get_by_request_id(context, instance_uuid, request_id): """Get the action by request_id and given instance.""" - return IMPL.action_get_by_request_id(context, uuid, request_id) + return IMPL.action_get_by_request_id(context, instance_uuid, request_id) def action_event_start(context, values): diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 1353293aa654..67a4cba33081 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -174,6 +174,7 @@ def require_context(f): def wrapper(*args, **kwargs): nova.context.require_context(args[0]) return f(*args, **kwargs) + wrapper.__signature__ = inspect.signature(f) return wrapper @@ -202,6 +203,7 @@ def select_db_reader_mode(f): with reader_mode.using(context): return f(*args, **kwargs) + wrapper.__signature__ = inspect.signature(f) return wrapper @@ -213,11 +215,12 @@ def pick_context_manager_writer(f): Wrapped function must have a RequestContext in the arguments. """ @functools.wraps(f) - def wrapped(context, *args, **kwargs): + def wrapper(context, *args, **kwargs): ctxt_mgr = get_context_manager(context) with ctxt_mgr.writer.using(context): return f(context, *args, **kwargs) - return wrapped + wrapper.__signature__ = inspect.signature(f) + return wrapper def pick_context_manager_reader(f): @@ -228,11 +231,12 @@ def pick_context_manager_reader(f): Wrapped function must have a RequestContext in the arguments. """ @functools.wraps(f) - def wrapped(context, *args, **kwargs): + def wrapper(context, *args, **kwargs): ctxt_mgr = get_context_manager(context) with ctxt_mgr.reader.using(context): return f(context, *args, **kwargs) - return wrapped + wrapper.__signature__ = inspect.signature(f) + return wrapper def pick_context_manager_reader_allow_async(f): @@ -243,11 +247,12 @@ def pick_context_manager_reader_allow_async(f): Wrapped function must have a RequestContext in the arguments. """ @functools.wraps(f) - def wrapped(context, *args, **kwargs): + def wrapper(context, *args, **kwargs): ctxt_mgr = get_context_manager(context) with ctxt_mgr.reader.allow_async.using(context): return f(context, *args, **kwargs) - return wrapped + wrapper.__signature__ = inspect.signature(f) + return wrapper def model_query( @@ -1539,8 +1544,10 @@ def instance_get_all(context, columns_to_join=None): @require_context @pick_context_manager_reader_allow_async -def instance_get_all_by_filters(context, filters, sort_key, sort_dir, - limit=None, marker=None, columns_to_join=None): +def instance_get_all_by_filters( + context, filters, sort_key='created_at', sort_dir='desc', limit=None, + marker=None, columns_to_join=None, +): """Get all instances matching all filters sorted by the primary key. See instance_get_all_by_filters_sort for more information. @@ -2541,7 +2548,7 @@ def _instance_extra_create(context, values): @pick_context_manager_writer -def instance_extra_update_by_uuid(context, instance_uuid, values): +def instance_extra_update_by_uuid(context, instance_uuid, updates): """Update the instance extra record by instance uuid :param instance_uuid: UUID of the instance tied to the record @@ -2549,10 +2556,10 @@ def instance_extra_update_by_uuid(context, instance_uuid, values): """ rows_updated = model_query(context, models.InstanceExtra).\ filter_by(instance_uuid=instance_uuid).\ - update(values) + update(updates) if not rows_updated: LOG.debug("Created instance_extra for %s", instance_uuid) - create_values = copy.copy(values) + create_values = copy.copy(updates) create_values["instance_uuid"] = instance_uuid _instance_extra_create(context, create_values) rows_updated = 1 @@ -3305,23 +3312,23 @@ def migration_create(context, values): @oslo_db_api.wrap_db_retry(max_retries=5, retry_on_deadlock=True) @pick_context_manager_writer -def migration_update(context, id, values): +def migration_update(context, migration_id, values): """Update a migration instance.""" - migration = migration_get(context, id) + migration = migration_get(context, migration_id) migration.update(values) return migration @pick_context_manager_reader -def migration_get(context, id): +def migration_get(context, migration_id): """Finds a migration by the ID.""" result = model_query(context, models.Migration, read_deleted="yes").\ - filter_by(id=id).\ + filter_by(id=migration_id).\ first() if not result: - raise exception.MigrationNotFound(migration_id=id) + raise exception.MigrationNotFound(migration_id=migration_id) return result @@ -3340,16 +3347,16 @@ def migration_get_by_uuid(context, migration_uuid): @pick_context_manager_reader -def migration_get_by_id_and_instance(context, id, instance_uuid): +def migration_get_by_id_and_instance(context, migration_id, instance_uuid): """Finds a migration by the migration ID and the instance UUID.""" result = model_query(context, models.Migration).\ - filter_by(id=id).\ + filter_by(id=migration_id).\ filter_by(instance_uuid=instance_uuid).\ first() if not result: - raise exception.MigrationNotFoundForInstance(migration_id=id, - instance_id=instance_uuid) + raise exception.MigrationNotFoundForInstance( + migration_id=migration_id, instance_id=instance_uuid) return result