Merge "Fix performance regression adding rules to security groups" into stable/rocky

This commit is contained in:
Zuul 2019-02-09 10:38:17 +00:00 committed by Gerrit Code Review
commit 4dd9b347ef
4 changed files with 142 additions and 114 deletions

@ -192,6 +192,18 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
raise ext_sg.SecurityGroupNotFound(id=id)
return sg
def _check_security_group(self, context, id, tenant_id=None):
if tenant_id:
tmp_context_tenant_id = context.tenant_id
context.tenant_id = tenant_id
try:
if not sg_obj.SecurityGroup.objects_exist(context, id=id):
raise ext_sg.SecurityGroupNotFound(id=id)
finally:
if tenant_id:
context.tenant_id = tmp_context_tenant_id
@db_api.retry_if_session_inactive()
def delete_security_group(self, context, id):
filters = {'security_group_id': [id]}
@ -325,10 +337,7 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
security_group_id = self._validate_security_group_rules(
context, security_group_rules)
with db_api.context_manager.writer.using(context):
if not self.get_security_group(context, security_group_id):
raise ext_sg.SecurityGroupNotFound(id=security_group_id)
self._check_for_duplicate_rules(context, rules)
self._check_for_duplicate_rules(context, security_group_id, rules)
ret = []
for rule_dict in rules:
res_rule_dict = self._create_security_group_rule(
@ -351,7 +360,8 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
def _create_security_group_rule(self, context, security_group_rule,
validate=True):
if validate:
self._validate_security_group_rule(context, security_group_rule)
sg_id = self._validate_security_group_rule(context,
security_group_rule)
rule_dict = security_group_rule['security_group_rule']
remote_ip_prefix = rule_dict.get('remote_ip_prefix')
if remote_ip_prefix:
@ -391,8 +401,8 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
exc_cls=ext_sg.SecurityGroupConflict, **kwargs)
with db_api.context_manager.writer.using(context):
if validate:
self._check_for_duplicate_rules_in_db(context,
security_group_rule)
self._check_for_duplicate_rules(context, sg_id,
[security_group_rule])
sg_rule = sg_obj.SecurityGroupRule(context, **args)
sg_rule.create()
@ -525,15 +535,15 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
remote_group_id = rule['remote_group_id']
# Check that remote_group_id exists for tenant
if remote_group_id:
self.get_security_group(context, remote_group_id,
tenant_id=rule['tenant_id'])
self._check_security_group(context, remote_group_id,
tenant_id=rule['tenant_id'])
security_group_id = rule['security_group_id']
# Confirm that the tenant has permission
# to add rules to this security group.
self.get_security_group(context, security_group_id,
tenant_id=rule['tenant_id'])
self._check_security_group(context, security_group_id,
tenant_id=rule['tenant_id'])
return security_group_id
def _validate_security_group_rules(self, context, security_group_rules):
@ -576,73 +586,54 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
res['protocol'] = self._get_ip_proto_name_and_num(value)
return res
def _rules_equal(self, rule1, rule2):
"""Determines if two rules are equal ignoring id field."""
rule1_copy = rule1.copy()
rule2_copy = rule2.copy()
rule1_copy.pop('id', None)
rule2_copy.pop('id', None)
return rule1_copy == rule2_copy
def _rule_to_key(self, rule):
def _normalize_rule_value(key, value):
# This string is used as a placeholder for str(None), but shorter.
none_char = '+'
def _check_for_duplicate_rules(self, context, security_group_rules):
for i in security_group_rules:
found_self = False
for j in security_group_rules:
if self._rules_equal(i['security_group_rule'],
j['security_group_rule']):
if found_self:
raise ext_sg.DuplicateSecurityGroupRuleInPost(rule=i)
found_self = True
if key == 'remote_ip_prefix':
all_address = ['0.0.0.0/0', '::/0', None]
if value in all_address:
return none_char
elif value is None:
return none_char
elif key == 'protocol':
return str(self._get_ip_proto_name_and_num(value))
return str(value)
self._check_for_duplicate_rules_in_db(context, i)
comparison_keys = [
'direction',
'ethertype',
'port_range_max',
'port_range_min',
'protocol',
'remote_group_id',
'remote_ip_prefix',
'security_group_id'
]
return '_'.join([_normalize_rule_value(x, rule.get(x))
for x in comparison_keys])
def _check_for_duplicate_rules_in_db(self, context, security_group_rule):
# Check in database if rule exists
filters = self._make_security_group_rule_filter_dict(
security_group_rule)
rule_dict = security_group_rule['security_group_rule'].copy()
rule_dict.pop('description', None)
keys = rule_dict.keys()
fields = list(keys) + ['id']
if 'remote_ip_prefix' not in fields:
fields += ['remote_ip_prefix']
db_rules = self.get_security_group_rules(context, filters,
fields=fields)
# Note(arosen): the call to get_security_group_rules wildcards
# values in the filter that have a value of [None]. For
# example, filters = {'remote_group_id': [None]} will return
# all security group rules regardless of their value of
# remote_group_id. Therefore it is not possible to do this
# query unless the behavior of _get_collection()
# is changed which cannot be because other methods are already
# relying on this behavior. Therefore, we do the filtering
# below to check for these corner cases.
rule_dict.pop('id', None)
sg_protocol = rule_dict.pop('protocol', None)
remote_ip_prefix = rule_dict.pop('remote_ip_prefix', None)
for db_rule in db_rules:
rule_id = db_rule.pop('id', None)
# remove protocol and match separately for number and type
db_protocol = db_rule.pop('protocol', None)
is_protocol_matching = (
self._get_ip_proto_name_and_num(db_protocol) ==
self._get_ip_proto_name_and_num(sg_protocol))
db_remote_ip_prefix = db_rule.pop('remote_ip_prefix', None)
duplicate_ip_prefix = self._validate_duplicate_ip_prefix(
remote_ip_prefix, db_remote_ip_prefix)
if (is_protocol_matching and duplicate_ip_prefix and
rule_dict == db_rule):
raise ext_sg.SecurityGroupRuleExists(rule_id=rule_id)
def _check_for_duplicate_rules(self, context, security_group_id,
new_security_group_rules):
# First up, check for any duplicates in the new rules.
new_rules_set = set()
for i in new_security_group_rules:
rule_key = self._rule_to_key(i['security_group_rule'])
if rule_key in new_rules_set:
raise ext_sg.DuplicateSecurityGroupRuleInPost(rule=i)
new_rules_set.add(rule_key)
def _validate_duplicate_ip_prefix(self, ip_prefix, other_ip_prefix):
if other_ip_prefix is not None:
other_ip_prefix = str(other_ip_prefix)
all_address = ['0.0.0.0/0', '::/0', None]
if ip_prefix == other_ip_prefix:
return True
elif ip_prefix in all_address and other_ip_prefix in all_address:
return True
return False
# Now, let's make sure none of the new rules conflict with
# existing rules; note that we do *not* store the db rules
# in the set, as we assume they were already checked,
# when added.
sg = self.get_security_group(context, security_group_id)
if sg:
for i in sg['security_group_rules']:
rule_key = self._rule_to_key(i)
if rule_key in new_rules_set:
raise ext_sg.SecurityGroupRuleExists(rule_id=i.get('id'))
def _validate_ip_prefix(self, rule):
"""Check that a valid cidr was specified as remote_ip_prefix

@ -844,6 +844,6 @@ class NeutronDbObject(NeutronObject):
if validate_filters:
cls.validate_filters(**kwargs)
# Succeed if at least a single object matches; no need to fetch more
return bool(obj_db_api.get_object(
return bool(obj_db_api.count(
cls, context, **cls.modify_fields_to_db(kwargs))
)

@ -103,7 +103,7 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
def test_create_security_group_rule_conflict(self):
with mock.patch.object(self.mixin, '_validate_security_group_rule'),\
mock.patch.object(self.mixin,
'_check_for_duplicate_rules_in_db'),\
'_check_for_duplicate_rules'),\
mock.patch.object(registry, "notify") as mock_notify:
mock_notify.side_effect = exceptions.CallbackFailure(Exception())
with testtools.ExpectedException(
@ -111,9 +111,9 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
self.mixin.create_security_group_rule(
self.ctx, mock.MagicMock())
def test__check_for_duplicate_rules_in_db_does_not_drop_protocol(self):
with mock.patch.object(self.mixin, 'get_security_group_rules',
return_value=[mock.Mock()]):
def test__check_for_duplicate_rules_does_not_drop_protocol(self):
with mock.patch.object(self.mixin, 'get_security_group',
return_value=None):
context = mock.Mock()
rule_dict = {
'security_group_rule': {'protocol': None,
@ -121,7 +121,7 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
'security_group_id': 'fake',
'direction': 'fake'}
}
self.mixin._check_for_duplicate_rules_in_db(context, rule_dict)
self.mixin._check_for_duplicate_rules(context, 'fake', [rule_dict])
self.assertIn('protocol', rule_dict['security_group_rule'])
def test__check_for_duplicate_rules_ignores_rule_id(self):
@ -132,33 +132,20 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
# in this case as this test, tests that the id fields are dropped
# while being compared. This is in the case if a plugin specifies
# the rule ids themselves.
self.assertRaises(securitygroup.DuplicateSecurityGroupRuleInPost,
self.mixin._check_for_duplicate_rules,
context, rules)
def test__check_for_duplicate_rules_in_db_ignores_rule_id(self):
db_rules = {'protocol': 'tcp', 'id': 'fake', 'tenant_id': 'fake',
'direction': 'ingress', 'security_group_id': 'fake'}
with mock.patch.object(self.mixin, 'get_security_group_rules',
return_value=[db_rules]):
context = mock.Mock()
rule_dict = {
'security_group_rule': {'protocol': 'tcp',
'id': 'fake2',
'tenant_id': 'fake',
'security_group_id': 'fake',
'direction': 'ingress'}
}
self.assertRaises(securitygroup.SecurityGroupRuleExists,
self.mixin._check_for_duplicate_rules_in_db,
context, rule_dict)
with mock.patch.object(self.mixin, 'get_security_group',
return_value=None):
self.assertRaises(securitygroup.DuplicateSecurityGroupRuleInPost,
self.mixin._check_for_duplicate_rules,
context, 'fake', rules)
def test_check_for_duplicate_diff_rules_remote_ip_prefix_ipv4(self):
db_rules = {'id': 'fake', 'tenant_id': 'fake', 'ethertype': 'IPv4',
'direction': 'ingress', 'security_group_id': 'fake',
'remote_ip_prefix': None}
with mock.patch.object(self.mixin, 'get_security_group_rules',
return_value=[db_rules]):
fake_secgroup = copy.deepcopy(FAKE_SECGROUP)
fake_secgroup['security_group_rules'] = \
[{'id': 'fake', 'tenant_id': 'fake', 'ethertype': 'IPv4',
'direction': 'ingress', 'security_group_id': 'fake',
'remote_ip_prefix': None}]
with mock.patch.object(self.mixin, 'get_security_group',
return_value=fake_secgroup):
context = mock.Mock()
rule_dict = {
'security_group_rule': {'id': 'fake2',
@ -169,15 +156,17 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
'remote_ip_prefix': '0.0.0.0/0'}
}
self.assertRaises(securitygroup.SecurityGroupRuleExists,
self.mixin._check_for_duplicate_rules_in_db,
context, rule_dict)
self.mixin._check_for_duplicate_rules,
context, 'fake', [rule_dict])
def test_check_for_duplicate_diff_rules_remote_ip_prefix_ipv6(self):
db_rules = {'id': 'fake', 'tenant_id': 'fake', 'ethertype': 'IPv6',
'direction': 'ingress', 'security_group_id': 'fake',
'remote_ip_prefix': None}
with mock.patch.object(self.mixin, 'get_security_group_rules',
return_value=[db_rules]):
fake_secgroup = copy.deepcopy(FAKE_SECGROUP)
fake_secgroup['security_group_rules'] = \
[{'id': 'fake', 'tenant_id': 'fake', 'ethertype': 'IPv6',
'direction': 'ingress', 'security_group_id': 'fake',
'remote_ip_prefix': None}]
with mock.patch.object(self.mixin, 'get_security_group',
return_value=fake_secgroup):
context = mock.Mock()
rule_dict = {
'security_group_rule': {'id': 'fake2',
@ -188,8 +177,8 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
'remote_ip_prefix': '::/0'}
}
self.assertRaises(securitygroup.SecurityGroupRuleExists,
self.mixin._check_for_duplicate_rules_in_db,
context, rule_dict)
self.mixin._check_for_duplicate_rules,
context, 'fake', [rule_dict])
def test_delete_security_group_rule_in_use(self):
with mock.patch.object(registry, "notify") as mock_notify:

@ -128,6 +128,10 @@ class SecurityGroupsTestCase(test_db_base_plugin_v2.NeutronDbPluginV2TestCase):
# create a specific auth context for this request
security_group_rule_req.environ['neutron.context'] = (
context.Context('', kwargs['tenant_id']))
elif kwargs.get('admin_context'):
security_group_rule_req.environ['neutron.context'] = (
context.Context(user_id='admin', tenant_id='admin-tenant',
is_admin=True))
return security_group_rule_req.get_response(self.ext_api)
def _make_security_group(self, fmt, name, description, **kwargs):
@ -699,6 +703,50 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
for k, v, in keys:
self.assertEqual(sg_rule[0][k], v)
# This test case checks that admins from a different tenant can add rules
# as themselves. This is an odd behavior, with some weird GET semantics,
# but this test is checking that we don't break that old behavior, at least
# until we make a conscious choice to do so.
def test_create_security_group_rules_admin_tenant(self):
name = 'webservers'
description = 'my webservers'
with self.security_group(name, description) as sg:
# Add a couple normal rules
rule = self._build_security_group_rule(
sg['security_group']['id'], "ingress", const.PROTO_NAME_TCP,
port_range_min=22, port_range_max=22,
remote_ip_prefix="10.0.0.0/24",
ethertype=const.IPv4)
self._make_security_group_rule(self.fmt, rule)
rule = self._build_security_group_rule(
sg['security_group']['id'], "ingress", const.PROTO_NAME_TCP,
port_range_min=22, port_range_max=22,
remote_ip_prefix="10.0.1.0/24",
ethertype=const.IPv4)
self._make_security_group_rule(self.fmt, rule)
# Let's add a rule as admin, with a different tenant_id. The
# results of this call are arguably a bug, but it is past behavior.
rule = self._build_security_group_rule(
sg['security_group']['id'], "ingress", const.PROTO_NAME_TCP,
port_range_min=22, port_range_max=22,
remote_ip_prefix="10.0.2.0/24",
ethertype=const.IPv4,
tenant_id='admin-tenant')
self._make_security_group_rule(self.fmt, rule, admin_context=True)
# Now, let's make sure all the rules are there, with their odd
# tenant_id behavior.
res = self.new_list_request('security-groups')
sgs = self.deserialize(self.fmt, res.get_response(self.ext_api))
for sg in sgs['security_groups']:
if sg['name'] == "webservers":
rules = sg['security_group_rules']
self.assertEqual(len(rules), 5)
self.assertNotEqual(rules[3]['tenant_id'], 'admin-tenant')
self.assertEqual(rules[4]['tenant_id'], 'admin-tenant')
def test_get_security_group_on_port_from_wrong_tenant(self):
plugin = directory.get_plugin()
if not hasattr(plugin, '_get_security_groups_on_port'):