From b51eb4f35028d4861d4fc7df5eb9dc39b29c26e7 Mon Sep 17 00:00:00 2001 From: Andrey Pavlov Date: Tue, 21 Apr 2015 17:27:16 +0300 Subject: [PATCH] fix bugs with project id in DB for public objects - public ami images in DB had incorrect project id - public aki,ari images had incorrect project id Change-Id: I752d0f98fed150cfa7c5de792d50b091e00afae1 --- ec2api/api/ec2utils.py | 18 ++++-- ec2api/api/image.py | 62 ++++++++------------- ec2api/db/api.py | 6 +- ec2api/db/sqlalchemy/api.py | 10 ++-- ec2api/tests/contrib/post_test_hook.sh | 2 +- ec2api/tests/unit/test_address.py | 3 +- ec2api/tests/unit/test_ec2utils.py | 4 +- ec2api/tests/unit/test_image.py | 15 ++--- ec2api/tests/unit/test_instance.py | 15 +++-- ec2api/tests/unit/test_internet_gateway.py | 2 +- ec2api/tests/unit/test_network_interface.py | 6 +- ec2api/tests/unit/test_route_table.py | 3 +- ec2api/tests/unit/test_security_group.py | 6 +- ec2api/tests/unit/test_snapshot.py | 3 +- ec2api/tests/unit/test_subnet.py | 3 +- ec2api/tests/unit/test_volume.py | 6 +- ec2api/tests/unit/test_vpc.py | 6 +- ec2api/tests/unit/tools.py | 4 +- 18 files changed, 88 insertions(+), 86 deletions(-) diff --git a/ec2api/api/ec2utils.py b/ec2api/api/ec2utils.py index 8e563596..1a0a388f 100644 --- a/ec2api/api/ec2utils.py +++ b/ec2api/api/ec2utils.py @@ -225,16 +225,21 @@ def register_auto_create_db_item_extension(kind, extension): _auto_create_db_item_extensions[kind] = extension -def auto_create_db_item(context, kind, os_id, **extension_kwargs): +# TODO(Alex): The project_id passing mechanism can be potentially +# reconsidered in future. +def auto_create_db_item(context, kind, os_id, project_id=None, + **extension_kwargs): item = {'os_id': os_id} extension = _auto_create_db_item_extensions.get(kind) if extension: extension(context, item, **extension_kwargs) - return db_api.add_item(context, kind, item) + return db_api.add_item(context, kind, item, project_id=project_id) +# TODO(Alex): The project_id passing mechanism can be potentially +# reconsidered in future. def get_db_item_by_os_id(context, kind, os_id, items_by_os_id=None, - **extension_kwargs): + project_id=None, **extension_kwargs): """Get DB item by OS id (create if it doesn't exist). Args: @@ -264,14 +269,15 @@ def get_db_item_by_os_id(context, kind, os_id, items_by_os_id=None, item = next((i for i in db_api.get_items(context, kind) if i['os_id'] == os_id), None) if not item: - item = auto_create_db_item(context, kind, os_id, **extension_kwargs) + item = auto_create_db_item(context, kind, os_id, project_id=project_id, + **extension_kwargs) if items_by_os_id is not None: items_by_os_id[os_id] = item return item -# TODO(andrey-mp): project_id is a temporary workaround which should be -# reworked asap. (c) by ftersin. +# TODO(Alex): The project_id passing mechanism can be potentially +# reconsidered in future. def os_id_to_ec2_id(context, kind, os_id, items_by_os_id=None, ids_by_os_id=None, project_id=None): if os_id is None: diff --git a/ec2api/api/image.py b/ec2api/api/image.py index d0348a02..28c7798e 100644 --- a/ec2api/api/image.py +++ b/ec2api/api/image.py @@ -36,7 +36,6 @@ from ec2api.api import clients from ec2api.api import common from ec2api.api import ec2utils from ec2api.api import instance as instance_api -from ec2api import context as ec2_context from ec2api.db import api as db_api from ec2api import exception from ec2api.i18n import _, _LE, _LI @@ -219,15 +218,15 @@ def register_image(context, name=None, image_location=None, def deregister_image(context, image_id): - # TODO(ft): AWS returns AuthFailure for public images, - # but we return NotFound due searching for local images only - image = ec2utils.get_db_item(context, image_id) + os_image = ec2utils.get_os_image(context, image_id) + _check_owner(context, os_image) + glance = clients.glance(context) try: - glance.images.delete(image['os_id']) + glance.images.delete(os_image.id) except glance_exception.HTTPNotFound: pass - db_api.delete_item(context, image['id']) + db_api.delete_item(context, image_id) return True @@ -291,13 +290,12 @@ class ImageDescriber(common.TaggableItemsDescriber): def get_os_items(self): return clients.glance(self.context).images.list() - # TODO(andrey-mp): project_id will be invalid for new public images def auto_update_db(self, image, os_image): if not image: kind = _get_os_image_kind(os_image) image = ec2utils.get_db_item_by_os_id( self.context, kind, os_image.id, self.items_dict, - os_image=os_image) + os_image=os_image, project_id=os_image.owner) elif (image['os_id'] in self.local_images_os_ids and image['is_public'] != os_image.is_public): image['is_public'] = os_image.is_public @@ -323,14 +321,10 @@ def describe_images(context, executable_by=None, image_id=None, def describe_image_attribute(context, image_id, attribute): - image = ec2utils.get_db_item(context, image_id) - try: - glance = clients.glance(ec2_context.get_os_admin_context()) - glance.images.get(image['os_id']) - except glance_exception.HTTPNotFound: - raise exception.InvalidAMIIDNotFound(id=image_id) - os_image = _get_owned_os_image(context, image_id, image['os_id']) + os_image = ec2utils.get_os_image(context, image_id) + _check_owner(context, os_image) _prepare_mappings(os_image) + image = ec2utils.get_db_item(context, image_id) def _block_device_mapping_attribute(result): _cloud_format_mappings(context, os_image.properties, result) @@ -386,13 +380,7 @@ def modify_image_attribute(context, image_id, attribute=None, user_group=None, operation_type=None, description=None, launch_permission=None, product_code=None, user_id=None, value=None): - image = ec2utils.get_db_item(context, image_id) - - try: - glance = clients.glance(ec2_context.get_os_admin_context()) - os_image = glance.images.get(image['os_id']) - except glance_exception.HTTPNotFound: - raise exception.InvalidAMIIDNotFound(id=image_id) + os_image = ec2utils.get_os_image(context, image_id) attributes = set() @@ -447,7 +435,7 @@ def modify_image_attribute(context, image_id, attribute=None, value='operation_type', reason=msg) - os_image = _get_owned_os_image(context, image_id, image['os_id']) + _check_owner(context, os_image) os_image.update(is_public=(operation_type == 'add')) return True @@ -456,9 +444,8 @@ def modify_image_attribute(context, image_id, attribute=None, raise exception.MissingParameter( 'The request must contain the parameter description') - # Just check image accessibility - _get_owned_os_image(context, image_id, image['os_id']) - + _check_owner(context, os_image) + image = ec2utils.get_db_item(context, image_id) image['description'] = value db_api.update_item(context, image) return True @@ -468,22 +455,17 @@ def reset_image_attribute(context, image_id, attribute): if attribute != 'launchPermission': raise exception.InvalidRequest() - image = ec2utils.get_db_item(context, image_id) - os_image = _get_owned_os_image(context, image_id, image['os_id']) + os_image = ec2utils.get_os_image(context, image_id) + _check_owner(context, os_image) + os_image.update(is_public=False) return True -def _get_owned_os_image(context, image_id, os_image_id): - glance = clients.glance(context) - try: - os_image = glance.images.get(os_image_id) - except glance_exception.HTTPNotFound: - os_image = None - if os_image is None or os_image.owner != context.project_id: +def _check_owner(context, os_image): + if os_image.owner != context.project_id: raise exception.AuthFailure(_('Not authorized for image:%s') - % image_id) - return os_image + % os_image.id) def _format_image(context, image, os_image, images_dict, ids_dict, @@ -507,12 +489,14 @@ def _format_image(context, image, os_image, images_dict, ids_dict, if kernel_id: ec2_image['kernelId'] = ec2utils.os_id_to_ec2_id( context, 'aki', kernel_id, - items_by_os_id=images_dict, ids_by_os_id=ids_dict) + items_by_os_id=images_dict, ids_by_os_id=ids_dict, + project_id=os_image.owner) ramdisk_id = os_image.properties.get('ramdisk_id') if ramdisk_id: ec2_image['ramdiskId'] = ec2utils.os_id_to_ec2_id( context, 'ari', ramdisk_id, - items_by_os_id=images_dict, ids_by_os_id=ids_dict) + items_by_os_id=images_dict, ids_by_os_id=ids_dict, + project_id=os_image.owner) name = os_image.name img_loc = os_image.properties.get('image_location') diff --git a/ec2api/db/api.py b/ec2api/db/api.py index 5349527c..7b1762b3 100644 --- a/ec2api/db/api.py +++ b/ec2api/db/api.py @@ -79,12 +79,12 @@ IMPL = EC2DBAPI() LOG = logging.getLogger(__name__) -def add_item(context, kind, data): - return IMPL.add_item(context, kind, data) +def add_item(context, kind, data, project_id=None): + return IMPL.add_item(context, kind, data, project_id=project_id) def add_item_id(context, kind, os_id, project_id=None): - return IMPL.add_item_id(context, kind, os_id, project_id) + return IMPL.add_item_id(context, kind, os_id, project_id=project_id) def update_item(context, item): diff --git a/ec2api/db/sqlalchemy/api.py b/ec2api/db/sqlalchemy/api.py index 0be7e054..369e0df3 100644 --- a/ec2api/db/sqlalchemy/api.py +++ b/ec2api/db/sqlalchemy/api.py @@ -90,10 +90,12 @@ def _new_id(kind, os_id): @require_context -def add_item(context, kind, data): +def add_item(context, kind, data, project_id=None): + if not project_id: + project_id = context.project_id item_ref = models.Item() item_ref.update({ - "project_id": context.project_id, + "project_id": project_id, "id": _new_id(kind, data.get("os_id")), }) item_ref.update(_pack_item_data(data)) @@ -105,14 +107,14 @@ def add_item(context, kind, data): raise item_ref = (model_query(context, models.Item). filter_by(os_id=data["os_id"]). - filter(or_(models.Item.project_id == context.project_id, + filter(or_(models.Item.project_id == project_id, models.Item.project_id.is_(None))). filter(models.Item.id.like('%s-%%' % kind)). one()) item_data = _unpack_item_data(item_ref) item_data.update(data) item_ref.update(_pack_item_data(item_data)) - item_ref.project_id = context.project_id + item_ref.project_id = project_id item_ref.save() return _unpack_item_data(item_ref) diff --git a/ec2api/tests/contrib/post_test_hook.sh b/ec2api/tests/contrib/post_test_hook.sh index c993b3cd..7bab0d61 100755 --- a/ec2api/tests/contrib/post_test_hook.sh +++ b/ec2api/tests/contrib/post_test_hook.sh @@ -112,7 +112,7 @@ image_id = $image_id ebs_image_id = $ebs_image_id EOF" - # local workaround for LP#1439819. its doesn't work in gating because glance check isatty property. + # local workaround for LP#1439819. it doesn't work in gating because glance check isatty property. #glance image-update $image_name --container-format ami --disk-format ami fi diff --git a/ec2api/tests/unit/test_address.py b/ec2api/tests/unit/test_address.py index 5d2e5561..fe9d8339 100644 --- a/ec2api/tests/unit/test_address.py +++ b/ec2api/tests/unit/test_address.py @@ -57,7 +57,8 @@ class AddressTestCase(base.ApiTestCase): self.db_api.add_item.assert_called_once_with( mock.ANY, 'eipalloc', tools.purge_dict(fakes.DB_ADDRESS_1, - ('id', 'vpc_id'))) + ('id', 'vpc_id')), + project_id=None) self.neutron.create_floatingip.assert_called_once_with( {'floatingip': { 'floating_network_id': diff --git a/ec2api/tests/unit/test_ec2utils.py b/ec2api/tests/unit/test_ec2utils.py index 004403fd..74ea3c0c 100644 --- a/ec2api/tests/unit/test_ec2utils.py +++ b/ec2api/tests/unit/test_ec2utils.py @@ -169,7 +169,7 @@ class EC2UtilsTestCase(testtools.TestCase): item_id = ec2utils.os_id_to_ec2_id(fake_context, 'fake', fake_os_id) self.assertEqual(fake_id, item_id) db_api.add_item_id.assert_called_once_with( - fake_context, 'fake', fake_os_id, None) + fake_context, 'fake', fake_os_id, project_id=None) # no item in cache, item isn't found db_api.reset_mock() @@ -180,7 +180,7 @@ class EC2UtilsTestCase(testtools.TestCase): self.assertIn(fake_os_id, ids_cache) self.assertEqual(fake_id, ids_cache[fake_os_id]) db_api.add_item_id.assert_called_once_with( - fake_context, 'fake', fake_os_id, None) + fake_context, 'fake', fake_os_id, project_id=None) # no item in cache, item is found db_api.reset_mock() diff --git a/ec2api/tests/unit/test_image.py b/ec2api/tests/unit/test_image.py index 63018430..e66bd6ee 100644 --- a/ec2api/tests/unit/test_image.py +++ b/ec2api/tests/unit/test_image.py @@ -92,15 +92,6 @@ FILE_MANIFEST_XML = """ class ImageTestCase(base.ApiTestCase): - def setUp(self): - super(ImageTestCase, self).setUp() - get_os_admin_context_patcher = ( - mock.patch('ec2api.context.get_os_admin_context')) - self.get_os_admin_context = get_os_admin_context_patcher.start() - self.addCleanup(get_os_admin_context_patcher.stop) - self.get_os_admin_context.return_value = ( - self._create_context(auth_token='admin_token')) - @mock.patch('ec2api.api.instance._is_ebs_instance') def _test_create_image(self, instance_status, no_reboot, is_ebs_instance): self.set_mock_db_items(fakes.DB_INSTANCE_2) @@ -132,7 +123,8 @@ class ImageTestCase(base.ApiTestCase): self.db_api.add_item.assert_called_once_with( mock.ANY, 'ami', {'os_id': image_id, 'is_public': False, - 'description': 'fake desc'}) + 'description': 'fake desc'}, + project_id=None) if not no_reboot: os_instance.stop.assert_called_once_with() os_instance.get.assert_called_once_with() @@ -208,7 +200,8 @@ class ImageTestCase(base.ApiTestCase): self.db_api.add_item.assert_called_once_with( mock.ANY, 'ami', {'os_id': fakes.ID_OS_IMAGE_2, 'is_public': False, - 'description': None}) + 'description': None}, + project_id=None) self.assertEqual(1, self.glance.images.create.call_count) self.assertEqual((), self.glance.images.create.call_args[0]) self.assertIn('properties', self.glance.images.create.call_args[1]) diff --git a/ec2api/tests/unit/test_instance.py b/ec2api/tests/unit/test_instance.py index f1acbdcb..50380137 100644 --- a/ec2api/tests/unit/test_instance.py +++ b/ec2api/tests/unit/test_instance.py @@ -155,7 +155,8 @@ class InstanceTestCase(base.ApiTestCase): nics=[{'port-id': fakes.ID_OS_PORT_1}], key_name=None, userdata=None) self.db_api.add_item.assert_called_once_with( - mock.ANY, 'i', tools.purge_dict(fakes.DB_INSTANCE_1, ('id',))) + mock.ANY, 'i', tools.purge_dict(fakes.DB_INSTANCE_1, ('id',)), + project_id=None) (self.network_interface_api. _attach_network_interface_item.assert_called_once_with( mock.ANY, fakes.DB_NETWORK_INTERFACE_1, @@ -296,7 +297,8 @@ class InstanceTestCase(base.ApiTestCase): [0, 1] * 2, [True, False, True, False])])) self.db_api.add_item.assert_has_calls([ - mock.call(mock.ANY, 'i', tools.purge_dict(db_instance, ['id'])) + mock.call(mock.ANY, 'i', tools.purge_dict(db_instance, ['id']), + project_id=None) for db_instance in self.DB_INSTANCES]) @mock.patch('ec2api.api.instance._parse_block_device_mapping') @@ -355,7 +357,7 @@ class InstanceTestCase(base.ApiTestCase): 'client_token': 'fake_client_token'} db_instance.update(extra_db_instance) self.db_api.add_item.assert_called_once_with( - mock.ANY, 'i', db_instance) + mock.ANY, 'i', db_instance, project_id=None) self.db_api.reset_mock() parse_block_device_mapping.assert_called_once_with( mock.ANY, @@ -1460,9 +1462,10 @@ class InstancePrivateTestCase(test_base.BaseTestCase): fake_context, instance, os_instance, [], {}, os_flavors=fake_flavors) db_api.add_item_id.assert_has_calls( - [mock.call(mock.ANY, 'ami', os_instance.image['id'], None), - mock.call(mock.ANY, 'aki', kernel_id, None), - mock.call(mock.ANY, 'ari', ramdisk_id, None)], + [mock.call(mock.ANY, 'ami', os_instance.image['id'], + project_id=None), + mock.call(mock.ANY, 'aki', kernel_id, project_id=None), + mock.call(mock.ANY, 'ari', ramdisk_id, project_id=None)], any_order=True) @mock.patch('cinderclient.client.Client') diff --git a/ec2api/tests/unit/test_internet_gateway.py b/ec2api/tests/unit/test_internet_gateway.py index cf136f7c..535f68e9 100644 --- a/ec2api/tests/unit/test_internet_gateway.py +++ b/ec2api/tests/unit/test_internet_gateway.py @@ -39,7 +39,7 @@ class IgwTestCase(base.ApiTestCase): igw = resp['internetGateway'] self.assertThat(fakes.EC2_IGW_2, matchers.DictMatches(igw)) self.db_api.add_item.assert_called_with( - mock.ANY, 'igw', {}) + mock.ANY, 'igw', {}, project_id=None) def test_attach_igw(self): self.configure(external_network=fakes.NAME_OS_PUBLIC_NETWORK) diff --git a/ec2api/tests/unit/test_network_interface.py b/ec2api/tests/unit/test_network_interface.py index 70a4a335..a0147b40 100644 --- a/ec2api/tests/unit/test_network_interface.py +++ b/ec2api/tests/unit/test_network_interface.py @@ -41,7 +41,8 @@ class NetworkInterfaceTestCase(base.ApiTestCase): matchers.DictMatches(resp['networkInterface'])) self.db_api.add_item.assert_called_once_with( mock.ANY, 'eni', - tools.purge_dict(fakes.DB_NETWORK_INTERFACE_1, ('id',))) + tools.purge_dict(fakes.DB_NETWORK_INTERFACE_1, ('id',)), + project_id=None) if auto_ips: self.neutron.create_port.assert_called_once_with( {'port': @@ -128,7 +129,8 @@ class NetworkInterfaceTestCase(base.ApiTestCase): 'device_index', 'instance_id', 'delete_on_termination', - 'attach_time'))) + 'attach_time')), + project_id=None) self.neutron.update_port.assert_called_once_with( fakes.ID_OS_PORT_2, {'port': {'name': diff --git a/ec2api/tests/unit/test_route_table.py b/ec2api/tests/unit/test_route_table.py index 40d6c371..3a69da62 100644 --- a/ec2api/tests/unit/test_route_table.py +++ b/ec2api/tests/unit/test_route_table.py @@ -42,7 +42,8 @@ class RouteTableTestCase(base.ApiTestCase): 'rtb', {'vpc_id': fakes.ID_EC2_VPC_1, 'routes': [{'destination_cidr_block': fakes.CIDR_VPC_1, - 'gateway_id': None}]}) + 'gateway_id': None}]}, + project_id=None) self.db_api.get_item_by_id.assert_called_once_with( mock.ANY, fakes.ID_EC2_VPC_1) diff --git a/ec2api/tests/unit/test_security_group.py b/ec2api/tests/unit/test_security_group.py index e97de097..464dcc40 100644 --- a/ec2api/tests/unit/test_security_group.py +++ b/ec2api/tests/unit/test_security_group.py @@ -63,7 +63,8 @@ class SecurityGroupTestCase(base.ApiTestCase): self.assertEqual(fakes.ID_EC2_SECURITY_GROUP_2, resp['groupId']) self.db_api.add_item.assert_called_once_with( mock.ANY, 'sg', - tools.purge_dict(fakes.DB_SECURITY_GROUP_2, ('id',))) + tools.purge_dict(fakes.DB_SECURITY_GROUP_2, ('id',)), + project_id=None) self.nova.security_groups.create.assert_called_once_with( 'groupname', 'Group description') @@ -319,7 +320,8 @@ class SecurityGroupTestCase(base.ApiTestCase): resp = self.execute('DescribeSecurityGroups', {}) self.db_api.add_item.assert_called_once_with( mock.ANY, 'sg', - tools.purge_dict(fakes.DB_SECURITY_GROUP_1, ('id',))) + tools.purge_dict(fakes.DB_SECURITY_GROUP_1, ('id',)), + project_id=None) self.nova.security_groups.create.assert_called_once_with( fakes.ID_EC2_VPC_1, 'Default VPC security group') diff --git a/ec2api/tests/unit/test_snapshot.py b/ec2api/tests/unit/test_snapshot.py index 26ae6acd..fdb4200e 100644 --- a/ec2api/tests/unit/test_snapshot.py +++ b/ec2api/tests/unit/test_snapshot.py @@ -113,7 +113,8 @@ class SnapshotTestCase(base.ApiTestCase): self.assertThat(fakes.EC2_SNAPSHOT_1, matchers.DictMatches(resp)) self.db_api.add_item.assert_called_once_with( mock.ANY, 'snap', - tools.purge_dict(fakes.DB_SNAPSHOT_1, ('id',))) + tools.purge_dict(fakes.DB_SNAPSHOT_1, ('id',)), + project_id=None) self.cinder.volume_snapshots.create.assert_called_once_with( fakes.ID_OS_VOLUME_2, force=True, display_description=None) diff --git a/ec2api/tests/unit/test_subnet.py b/ec2api/tests/unit/test_subnet.py index 7b6523eb..b70e1ec7 100644 --- a/ec2api/tests/unit/test_subnet.py +++ b/ec2api/tests/unit/test_subnet.py @@ -38,7 +38,8 @@ class SubnetTestCase(base.ApiTestCase): resp['subnet'])) self.db_api.add_item.assert_called_once_with( mock.ANY, 'subnet', - tools.purge_dict(fakes.DB_SUBNET_1, ('id',))) + tools.purge_dict(fakes.DB_SUBNET_1, ('id',)), + project_id=None) self.neutron.create_network.assert_called_once_with( {'network': {}}) self.neutron.update_network.assert_called_once_with( diff --git a/ec2api/tests/unit/test_volume.py b/ec2api/tests/unit/test_volume.py index 00c858dc..5de5e9c9 100644 --- a/ec2api/tests/unit/test_volume.py +++ b/ec2api/tests/unit/test_volume.py @@ -110,7 +110,8 @@ class VolumeTestCase(base.ApiTestCase): self.assertThat(fakes.EC2_VOLUME_1, matchers.DictMatches(resp)) self.db_api.add_item.assert_called_once_with( mock.ANY, 'vol', - tools.purge_dict(fakes.DB_VOLUME_1, ('id',))) + tools.purge_dict(fakes.DB_VOLUME_1, ('id',)), + project_id=None) self.cinder.volumes.create.assert_called_once_with( None, snapshot_id=None, volume_type=None, @@ -130,7 +131,8 @@ class VolumeTestCase(base.ApiTestCase): self.assertThat(fakes.EC2_VOLUME_3, matchers.DictMatches(resp)) self.db_api.add_item.assert_called_once_with( mock.ANY, 'vol', - tools.purge_dict(fakes.DB_VOLUME_3, ('id',))) + tools.purge_dict(fakes.DB_VOLUME_3, ('id',)), + project_id=None) self.cinder.volumes.create.assert_called_once_with( None, snapshot_id=fakes.ID_OS_SNAPSHOT_1, volume_type=None, diff --git a/ec2api/tests/unit/test_vpc.py b/ec2api/tests/unit/test_vpc.py index c5224bea..4e862613 100644 --- a/ec2api/tests/unit/test_vpc.py +++ b/ec2api/tests/unit/test_vpc.py @@ -49,11 +49,13 @@ class VpcTestCase(base.ApiTestCase): self.db_api.add_item.assert_any_call( mock.ANY, 'vpc', tools.purge_dict(fakes.DB_VPC_1, - ('id', 'vpc_id', 'route_table_id'))) + ('id', 'vpc_id', 'route_table_id')), + project_id=None) self.db_api.add_item.assert_any_call( mock.ANY, 'rtb', tools.purge_dict(fakes.DB_ROUTE_TABLE_1, - ('id',))) + ('id',)), + project_id=None) self.db_api.update_item.assert_called_once_with( mock.ANY, fakes.DB_VPC_1) diff --git a/ec2api/tests/unit/tools.py b/ec2api/tests/unit/tools.py index 0a899404..36f89ceb 100644 --- a/ec2api/tests/unit/tools.py +++ b/ec2api/tests/unit/tools.py @@ -49,12 +49,14 @@ def patch_dict(dict1, dict2, trash_iter): def get_db_api_add_item(item_id_dict): """Generate db_api.add_item mock function.""" - def db_api_add_item(context, kind, data): + def db_api_add_item(context, kind, data, project_id=None): if isinstance(item_id_dict, dict): item_id = item_id_dict[kind] else: item_id = item_id_dict data = update_dict(data, {'id': item_id}) + if project_id: + data = update_dict(data, {'project_id': project_id}) data.setdefault('os_id') data.setdefault('vpc_id') return data