Simplify rollback of route changes

Change-Id: I6118d7d57074750687bd084b15cc1810280dfa88
This commit is contained in:
Feodor Tersin 2015-05-12 17:19:59 +03:00
parent 8b13458b19
commit 76b5766d0b
2 changed files with 48 additions and 93 deletions

View File

@ -77,8 +77,7 @@ def delete_route(context, route_table_id, destination_cidr_block):
cleaner.addCleanup(db_api.update_item, context, cleaner.addCleanup(db_api.update_item, context,
rollback_route_table_state) rollback_route_table_state)
_update_routes_in_associated_subnets(context, route_table, cleaner, _update_routes_in_associated_subnets(context, route_table, cleaner)
rollback_route_table_state)
return True return True
@ -98,15 +97,11 @@ def associate_route_table(context, route_table_id, subnet_id):
msg = msg % {'rtb_id': route_table_id} msg = msg % {'rtb_id': route_table_id}
raise exception.ResourceAlreadyAssociated(msg) raise exception.ResourceAlreadyAssociated(msg)
vpc = db_api.get_item_by_id(context, subnet['vpc_id'])
main_route_table = db_api.get_item_by_id(context, vpc['route_table_id'])
with common.OnCrashCleaner() as cleaner: with common.OnCrashCleaner() as cleaner:
_associate_subnet_item(context, subnet, route_table['id']) _associate_subnet_item(context, subnet, route_table['id'])
cleaner.addCleanup(_disassociate_subnet_item, context, subnet) cleaner.addCleanup(_disassociate_subnet_item, context, subnet)
_update_subnet_host_routes( _update_subnet_host_routes(context, subnet, route_table, cleaner)
context, subnet, route_table,
cleaner=cleaner, rollback_route_table_object=main_route_table)
return {'associationId': ec2utils.change_ec2_id_kind(subnet['id'], return {'associationId': ec2utils.change_ec2_id_kind(subnet['id'],
'rtbassoc')} 'rtbassoc')}
@ -119,27 +114,23 @@ def replace_route_table_association(context, association_id, route_table_id):
vpc = db_api.get_item_by_id( vpc = db_api.get_item_by_id(
context, ec2utils.change_ec2_id_kind(association_id, 'vpc')) context, ec2utils.change_ec2_id_kind(association_id, 'vpc'))
if vpc is None: if vpc is None:
raise exception.InvalidAssociationIDNotFound( raise exception.InvalidAssociationIDNotFound(id=association_id)
id=association_id)
rollabck_route_table_object = db_api.get_item_by_id( rollback_route_table_id = vpc['route_table_id']
context, vpc['route_table_id'])
with common.OnCrashCleaner() as cleaner: with common.OnCrashCleaner() as cleaner:
_associate_vpc_item(context, vpc, route_table['id']) _associate_vpc_item(context, vpc, route_table['id'])
cleaner.addCleanup(_associate_vpc_item, context, vpc, cleaner.addCleanup(_associate_vpc_item, context, vpc,
rollabck_route_table_object['id']) rollback_route_table_id)
# NOTE(ft): this can cause unnecessary update of subnets, which are # NOTE(ft): this can cause unnecessary update of subnets, which are
# associated with the route table # associated with the route table
_update_routes_in_associated_subnets( _update_routes_in_associated_subnets(
context, route_table, cleaner, context, route_table, cleaner, is_main=True)
rollabck_route_table_object, is_main=True)
else: else:
subnet = db_api.get_item_by_id( subnet = db_api.get_item_by_id(
context, ec2utils.change_ec2_id_kind(association_id, 'subnet')) context, ec2utils.change_ec2_id_kind(association_id, 'subnet'))
if subnet is None or 'route_table_id' not in subnet: if subnet is None or 'route_table_id' not in subnet:
raise exception.InvalidAssociationIDNotFound( raise exception.InvalidAssociationIDNotFound(id=association_id)
id=association_id)
if subnet['vpc_id'] != route_table['vpc_id']: if subnet['vpc_id'] != route_table['vpc_id']:
msg = _('Route table association %(rtbassoc_id)s and route table ' msg = _('Route table association %(rtbassoc_id)s and route table '
'%(rtb_id)s belong to different networks') '%(rtb_id)s belong to different networks')
@ -147,16 +138,13 @@ def replace_route_table_association(context, association_id, route_table_id):
'rtb_id': route_table_id} 'rtb_id': route_table_id}
raise exception.InvalidParameterValue(msg) raise exception.InvalidParameterValue(msg)
rollabck_route_table_object = db_api.get_item_by_id( rollback_route_table_id = subnet['route_table_id']
context, subnet['route_table_id'])
with common.OnCrashCleaner() as cleaner: with common.OnCrashCleaner() as cleaner:
_associate_subnet_item(context, subnet, route_table['id']) _associate_subnet_item(context, subnet, route_table['id'])
cleaner.addCleanup(_associate_subnet_item, context, subnet, cleaner.addCleanup(_associate_subnet_item, context, subnet,
rollabck_route_table_object['id']) rollback_route_table_id)
_update_subnet_host_routes( _update_subnet_host_routes(context, subnet, route_table, cleaner)
context, subnet, route_table, cleaner=cleaner,
rollback_route_table_object=rollabck_route_table_object)
return {'newAssociationId': association_id} return {'newAssociationId': association_id}
@ -168,27 +156,22 @@ def disassociate_route_table(context, association_id):
vpc = db_api.get_item_by_id( vpc = db_api.get_item_by_id(
context, ec2utils.change_ec2_id_kind(association_id, 'vpc')) context, ec2utils.change_ec2_id_kind(association_id, 'vpc'))
if vpc is None: if vpc is None:
raise exception.InvalidAssociationIDNotFound( raise exception.InvalidAssociationIDNotFound(id=association_id)
id=association_id)
msg = _('Cannot disassociate the main route table association ' msg = _('Cannot disassociate the main route table association '
'%(rtbassoc_id)s') % {'rtbassoc_id': association_id} '%(rtbassoc_id)s') % {'rtbassoc_id': association_id}
raise exception.InvalidParameterValue(msg) raise exception.InvalidParameterValue(msg)
if 'route_table_id' not in subnet: if 'route_table_id' not in subnet:
raise exception.InvalidAssociationIDNotFound( raise exception.InvalidAssociationIDNotFound(id=association_id)
id=association_id)
rollback_route_table_object = db_api.get_item_by_id( rollback_route_table_id = subnet['route_table_id']
context, subnet['route_table_id'])
vpc = db_api.get_item_by_id(context, subnet['vpc_id']) vpc = db_api.get_item_by_id(context, subnet['vpc_id'])
main_route_table = db_api.get_item_by_id(context, vpc['route_table_id']) main_route_table = db_api.get_item_by_id(context, vpc['route_table_id'])
with common.OnCrashCleaner() as cleaner: with common.OnCrashCleaner() as cleaner:
_disassociate_subnet_item(context, subnet) _disassociate_subnet_item(context, subnet)
cleaner.addCleanup(_associate_subnet_item, context, subnet, cleaner.addCleanup(_associate_subnet_item, context, subnet,
rollback_route_table_object['id']) rollback_route_table_id)
_update_subnet_host_routes( _update_subnet_host_routes(context, subnet, main_route_table, cleaner)
context, subnet, main_route_table, cleaner=cleaner,
rollback_route_table_object=rollback_route_table_object)
return True return True
@ -377,8 +360,7 @@ def _set_route(context, route_table_id, destination_cidr_block,
db_api.update_item(context, route_table) db_api.update_item(context, route_table)
cleaner.addCleanup(db_api.update_item, context, cleaner.addCleanup(db_api.update_item, context,
rollabck_route_table_state) rollabck_route_table_state)
_update_routes_in_associated_subnets(context, route_table, cleaner, _update_routes_in_associated_subnets(context, route_table, cleaner)
rollabck_route_table_state)
return True return True
@ -458,7 +440,6 @@ def _format_route_table(context, route_table, is_main=False,
def _update_routes_in_associated_subnets(context, route_table, cleaner, def _update_routes_in_associated_subnets(context, route_table, cleaner,
rollabck_route_table_object,
is_main=None): is_main=None):
if is_main is None: if is_main is None:
vpc = db_api.get_item_by_id(context, route_table['vpc_id']) vpc = db_api.get_item_by_id(context, route_table['vpc_id'])
@ -473,13 +454,11 @@ def _update_routes_in_associated_subnets(context, route_table, cleaner,
if (subnet['vpc_id'] == route_table['vpc_id'] and if (subnet['vpc_id'] == route_table['vpc_id'] and
subnet.get('route_table_id') in appropriate_rtb_ids): subnet.get('route_table_id') in appropriate_rtb_ids):
_update_subnet_host_routes( _update_subnet_host_routes(
context, subnet, route_table, cleaner=cleaner, context, subnet, route_table, cleaner,
rollback_route_table_object=rollabck_route_table_object,
router_objects=router_objects, neutron=neutron) router_objects=router_objects, neutron=neutron)
def _update_subnet_host_routes(context, subnet, route_table, cleaner=None, def _update_subnet_host_routes(context, subnet, route_table, cleaner,
rollback_route_table_object=None,
router_objects=None, neutron=None): router_objects=None, neutron=None):
neutron = neutron or clients.neutron(context) neutron = neutron or clients.neutron(context)
os_subnet = neutron.show_subnet(subnet['os_id'])['subnet'] os_subnet = neutron.show_subnet(subnet['os_id'])['subnet']
@ -489,38 +468,35 @@ def _update_subnet_host_routes(context, subnet, route_table, cleaner=None,
router_objects) router_objects)
neutron.update_subnet(subnet['os_id'], neutron.update_subnet(subnet['os_id'],
{'subnet': {'host_routes': host_routes}}) {'subnet': {'host_routes': host_routes}})
if cleaner and rollback_route_table_object: cleaner.addCleanup(
cleaner.addCleanup(_update_subnet_host_routes, context, subnet, neutron.update_subnet, subnet['os_id'],
rollback_route_table_object) {'subnet': {'host_routes': os_subnet['host_routes']}})
def _get_router_objects(context, route_table): def _get_router_objects(context, route_table):
return dict((route['gateway_id'], object_ids = [route[id_key]
db_api.get_item_by_id(context, route['gateway_id'])) for route in route_table['routes']
if route.get('gateway_id') else for id_key in ('gateway_id', 'network_interface_id')
(route['network_interface_id'], if id_key in route and route[id_key]]
db_api.get_item_by_id(context, route['network_interface_id'])) return dict((item['id'], item)
for route in route_table['routes'] for item in db_api.get_items_by_ids(context, object_ids))
if route.get('gateway_id') or 'network_interface_id' in route)
def _get_subnet_host_routes(context, route_table, gateway_ip, def _get_subnet_host_routes(context, route_table, gateway_ip,
router_objects=None): router_objects=None):
if router_objects is None:
router_objects = _get_router_objects(context, route_table)
def get_nexthop(route): def get_nexthop(route):
if 'gateway_id' in route: if 'gateway_id' in route:
gateway_id = route['gateway_id'] gateway_id = route['gateway_id']
if gateway_id: if gateway_id:
gateway = (router_objects[route['gateway_id']] gateway = router_objects.get(route['gateway_id'])
if router_objects else
db_api.get_item_by_id(context, gateway_id))
if (not gateway or if (not gateway or
gateway.get('vpc_id') != route_table['vpc_id']): gateway['vpc_id'] != route_table['vpc_id']):
return '127.0.0.1' return '127.0.0.1'
return gateway_ip return gateway_ip
network_interface = ( network_interface = router_objects.get(route['network_interface_id'])
router_objects[route['network_interface_id']]
if router_objects else
db_api.get_item_by_id(context, route['network_interface_id']))
if not network_interface: if not network_interface:
return '127.0.0.1' return '127.0.0.1'
return network_interface['private_ip_address'] return network_interface['private_ip_address']

View File

@ -67,8 +67,7 @@ class RouteTableTestCase(base.ApiTestCase):
self.db_api.update_item.assert_called_once_with( self.db_api.update_item.assert_called_once_with(
mock.ANY, route_table) mock.ANY, route_table)
routes_updater.assert_called_once_with( routes_updater.assert_called_once_with(
mock.ANY, route_table, mock.ANY, mock.ANY, route_table, mock.ANY)
rollback_route_table_state)
self.db_api.update_item.reset_mock() self.db_api.update_item.reset_mock()
routes_updater.reset_mock() routes_updater.reset_mock()
@ -251,8 +250,7 @@ class RouteTableTestCase(base.ApiTestCase):
'network_interface_id': fakes.ID_EC2_NETWORK_INTERFACE_1, 'network_interface_id': fakes.ID_EC2_NETWORK_INTERFACE_1,
'destination_cidr_block': '0.0.0.0/0'}) 'destination_cidr_block': '0.0.0.0/0'})
self.db_api.update_item.assert_called_once_with(mock.ANY, route_table) self.db_api.update_item.assert_called_once_with(mock.ANY, route_table)
routes_updater.assert_called_once_with(mock.ANY, route_table, mock.ANY, routes_updater.assert_called_once_with(mock.ANY, route_table, mock.ANY)
rollback_route_table_state)
def test_replace_route_invalid_parameters(self): def test_replace_route_invalid_parameters(self):
self.set_mock_db_items(fakes.DB_ROUTE_TABLE_1, self.set_mock_db_items(fakes.DB_ROUTE_TABLE_1,
@ -278,7 +276,7 @@ class RouteTableTestCase(base.ApiTestCase):
if r['destination_cidr_block'] != fakes.CIDR_EXTERNAL_NETWORK] if r['destination_cidr_block'] != fakes.CIDR_EXTERNAL_NETWORK]
self.db_api.update_item.assert_called_once_with(mock.ANY, route_table) self.db_api.update_item.assert_called_once_with(mock.ANY, route_table)
routes_updater.assert_called_once_with( routes_updater.assert_called_once_with(
mock.ANY, route_table, mock.ANY, fakes.DB_ROUTE_TABLE_2) mock.ANY, route_table, mock.ANY)
def test_delete_route_invalid_parameters(self): def test_delete_route_invalid_parameters(self):
self.set_mock_db_items() self.set_mock_db_items()
@ -327,8 +325,7 @@ class RouteTableTestCase(base.ApiTestCase):
self.db_api.update_item.assert_called_once_with( self.db_api.update_item.assert_called_once_with(
mock.ANY, subnet) mock.ANY, subnet)
routes_updater.assert_called_once_with( routes_updater.assert_called_once_with(
mock.ANY, subnet, fakes.DB_ROUTE_TABLE_1, cleaner=mock.ANY, mock.ANY, subnet, fakes.DB_ROUTE_TABLE_1, mock.ANY)
rollback_route_table_object=fakes.DB_ROUTE_TABLE_1)
def test_associate_route_table_invalid_parameters(self): def test_associate_route_table_invalid_parameters(self):
def do_check(params, error_code): def do_check(params, error_code):
@ -392,8 +389,7 @@ class RouteTableTestCase(base.ApiTestCase):
self.db_api.update_item.assert_called_once_with( self.db_api.update_item.assert_called_once_with(
mock.ANY, subnet) mock.ANY, subnet)
routes_updater.assert_called_once_with( routes_updater.assert_called_once_with(
mock.ANY, subnet, fakes.DB_ROUTE_TABLE_2, cleaner=mock.ANY, mock.ANY, subnet, fakes.DB_ROUTE_TABLE_2, mock.ANY)
rollback_route_table_object=fakes.DB_ROUTE_TABLE_3)
@mock.patch('ec2api.api.route_table._update_routes_in_associated_subnets') @mock.patch('ec2api.api.route_table._update_routes_in_associated_subnets')
def test_replace_route_table_association_main(self, routes_updater): def test_replace_route_table_association_main(self, routes_updater):
@ -411,8 +407,7 @@ class RouteTableTestCase(base.ApiTestCase):
self.db_api.update_item.assert_called_once_with( self.db_api.update_item.assert_called_once_with(
mock.ANY, vpc) mock.ANY, vpc)
routes_updater.assert_called_once_with( routes_updater.assert_called_once_with(
mock.ANY, fakes.DB_ROUTE_TABLE_2, mock.ANY, mock.ANY, fakes.DB_ROUTE_TABLE_2, mock.ANY, is_main=True)
fakes.DB_ROUTE_TABLE_1, is_main=True)
def test_replace_route_table_association_invalid_parameters(self): def test_replace_route_table_association_invalid_parameters(self):
def do_check(params, error_code): def do_check(params, error_code):
@ -499,9 +494,7 @@ class RouteTableTestCase(base.ApiTestCase):
self.db_api.update_item.assert_called_once_with( self.db_api.update_item.assert_called_once_with(
mock.ANY, subnet) mock.ANY, subnet)
routes_updater.assert_called_once_with( routes_updater.assert_called_once_with(
mock.ANY, subnet, fakes.DB_ROUTE_TABLE_1, mock.ANY, subnet, fakes.DB_ROUTE_TABLE_1, mock.ANY)
cleaner=mock.ANY,
rollback_route_table_object=fakes.DB_ROUTE_TABLE_3)
def test_disassociate_route_table_invalid_parameter(self): def test_disassociate_route_table_invalid_parameter(self):
def do_check(params, error_code): def do_check(params, error_code):
@ -713,7 +706,8 @@ class RouteTableTestCase(base.ApiTestCase):
route_table._update_subnet_host_routes( route_table._update_subnet_host_routes(
self._create_context(), fakes.DB_SUBNET_1, self._create_context(), fakes.DB_SUBNET_1,
fakes.DB_ROUTE_TABLE_1, router_objects={'fake': 'objects'}) fakes.DB_ROUTE_TABLE_1, common.OnCrashCleaner(),
router_objects={'fake': 'objects'})
self.neutron.show_subnet.assert_called_once_with(fakes.ID_OS_SUBNET_1) self.neutron.show_subnet.assert_called_once_with(fakes.ID_OS_SUBNET_1)
routes_getter.assert_called_once_with( routes_getter.assert_called_once_with(
@ -724,32 +718,21 @@ class RouteTableTestCase(base.ApiTestCase):
{'subnet': {'host_routes': 'fake_routes'}}) {'subnet': {'host_routes': 'fake_routes'}})
self.neutron.reset_mock() self.neutron.reset_mock()
routes_getter.reset_mock()
routes_getter.side_effect = ['fake_routes', 'fake_previous_routes']
try: try:
with common.OnCrashCleaner() as cleaner: with common.OnCrashCleaner() as cleaner:
route_table._update_subnet_host_routes( route_table._update_subnet_host_routes(
self._create_context(), fakes.DB_SUBNET_1, self._create_context(), fakes.DB_SUBNET_1,
fakes.DB_ROUTE_TABLE_1, cleaner, fakes.DB_ROUTE_TABLE_1, cleaner,
fakes.DB_ROUTE_TABLE_2,
router_objects={'fake': 'objects'}) router_objects={'fake': 'objects'})
raise Exception('fake_exception') raise Exception('fake_exception')
except Exception as ex: except Exception as ex:
if ex.message != 'fake_exception': if ex.message != 'fake_exception':
raise raise
self.neutron.show_subnet.assert_any_call(fakes.ID_OS_SUBNET_1)
routes_getter.assert_any_call(
mock.ANY, fakes.DB_ROUTE_TABLE_1, fakes.IP_GATEWAY_SUBNET_1,
{'fake': 'objects'})
routes_getter.assert_any_call(
mock.ANY, fakes.DB_ROUTE_TABLE_2, fakes.IP_GATEWAY_SUBNET_1,
None)
self.neutron.update_subnet.assert_any_call( self.neutron.update_subnet.assert_any_call(
fakes.ID_OS_SUBNET_1, fakes.ID_OS_SUBNET_1,
{'subnet': {'host_routes': 'fake_previous_routes'}}) {'subnet': {'host_routes': fakes.OS_SUBNET_1['host_routes']}})
@mock.patch('ec2api.api.route_table._get_router_objects') @mock.patch('ec2api.api.route_table._get_router_objects')
@mock.patch('ec2api.api.route_table._update_subnet_host_routes') @mock.patch('ec2api.api.route_table._update_subnet_host_routes')
@ -768,15 +751,12 @@ class RouteTableTestCase(base.ApiTestCase):
get_router_objects.return_value = {'fake': 'objects'} get_router_objects.return_value = {'fake': 'objects'}
route_table._update_routes_in_associated_subnets( route_table._update_routes_in_associated_subnets(
mock.MagicMock(), fakes.DB_ROUTE_TABLE_2, 'fake_cleaner', mock.MagicMock(), fakes.DB_ROUTE_TABLE_2, 'fake_cleaner')
{'fake': 'table'})
self.db_api.get_item_by_id.assert_called_once_with( self.db_api.get_item_by_id.assert_called_once_with(
mock.ANY, fakes.ID_EC2_VPC_1) mock.ANY, fakes.ID_EC2_VPC_1)
routes_updater.assert_called_once_with( routes_updater.assert_called_once_with(
mock.ANY, subnet_rtb_2, fakes.DB_ROUTE_TABLE_2, mock.ANY, subnet_rtb_2, fakes.DB_ROUTE_TABLE_2, 'fake_cleaner',
cleaner='fake_cleaner',
rollback_route_table_object={'fake': 'table'},
router_objects={'fake': 'objects'}, neutron=mock.ANY) router_objects={'fake': 'objects'}, neutron=mock.ANY)
get_router_objects.assert_called_once_with(mock.ANY, get_router_objects.assert_called_once_with(mock.ANY,
fakes.DB_ROUTE_TABLE_2) fakes.DB_ROUTE_TABLE_2)
@ -787,14 +767,13 @@ class RouteTableTestCase(base.ApiTestCase):
route_table._update_routes_in_associated_subnets( route_table._update_routes_in_associated_subnets(
mock.MagicMock(), fakes.DB_ROUTE_TABLE_1, 'fake_cleaner', mock.MagicMock(), fakes.DB_ROUTE_TABLE_1, 'fake_cleaner',
{'fake': 'table'}, is_main=True) is_main=True)
self.assertEqual(0, self.db_api.get_item_by_id.call_count) self.assertEqual(0, self.db_api.get_item_by_id.call_count)
routes_updater.assert_called_once_with( routes_updater.assert_called_once_with(
mock.ANY, subnet_default_rtb, fakes.DB_ROUTE_TABLE_1, mock.ANY, subnet_default_rtb, fakes.DB_ROUTE_TABLE_1,
cleaner='fake_cleaner', 'fake_cleaner', router_objects={'fake': 'objects'},
rollback_route_table_object={'fake': 'table'}, neutron=mock.ANY)
router_objects={'fake': 'objects'}, neutron=mock.ANY)
get_router_objects.assert_called_once_with(mock.ANY, get_router_objects.assert_called_once_with(mock.ANY,
fakes.DB_ROUTE_TABLE_1) fakes.DB_ROUTE_TABLE_1)