Move security group refresh logic into ComputeAPI.

Add a trigger_security_group_members_refresh to ComputeAPI which
finds the hosts that have instances that have security groups that
reference a security group in which a new instance has just been placed,
and sends a refresh_security_group_members to each of them.
This commit is contained in:
Soren Hansen
2010-12-13 16:42:35 +01:00
parent 65c0443c4a
commit be9a3cd7e1
6 changed files with 127 additions and 23 deletions

View File

@@ -130,15 +130,6 @@ class CloudController(object):
result[key] = [line] result[key] = [line]
return result return result
def _trigger_refresh_security_group(self, context, security_group):
nodes = set([instance['host'] for instance in security_group.instances
if instance['host'] is not None])
for node in nodes:
rpc.cast(context,
'%s.%s' % (FLAGS.compute_topic, node),
{"method": "refresh_security_group",
"args": {"security_group_id": security_group.id}})
def get_metadata(self, address): def get_metadata(self, address):
ctxt = context.get_admin_context() ctxt = context.get_admin_context()
instance_ref = db.fixed_ip_get_instance(ctxt, address) instance_ref = db.fixed_ip_get_instance(ctxt, address)
@@ -369,7 +360,8 @@ class CloudController(object):
match = False match = False
if match: if match:
db.security_group_rule_destroy(context, rule['id']) db.security_group_rule_destroy(context, rule['id'])
self._trigger_refresh_security_group(context, security_group) self.compute_api.trigger_security_group_rules_refresh(context,
security_group['id'])
return True return True
raise exception.ApiError("No rule for the specified parameters.") raise exception.ApiError("No rule for the specified parameters.")
@@ -392,7 +384,8 @@ class CloudController(object):
security_group_rule = db.security_group_rule_create(context, values) security_group_rule = db.security_group_rule_create(context, values)
self._trigger_refresh_security_group(context, security_group) self.compute_api.trigger_security_group_rules_refresh(context,
security_group['id'])
return True return True

View File

@@ -24,6 +24,7 @@ import datetime
import logging import logging
import time import time
from nova import context
from nova import db from nova import db
from nova import exception from nova import exception
from nova import flags from nova import flags
@@ -165,6 +166,10 @@ class ComputeAPI(base.Base):
"args": {"topic": FLAGS.compute_topic, "args": {"topic": FLAGS.compute_topic,
"instance_id": instance_id}}) "instance_id": instance_id}})
for group_id in security_groups:
self.trigger_security_group_members_refresh(elevated, group_id)
return instances return instances
def ensure_default_security_group(self, context): def ensure_default_security_group(self, context):
@@ -184,6 +189,62 @@ class ComputeAPI(base.Base):
'project_id': context.project_id} 'project_id': context.project_id}
db.security_group_create(context, values) db.security_group_create(context, values)
def trigger_security_group_rules_refresh(self, context, security_group_id):
"""Called when a rule is added to or removed from a security_group"""
security_group = db.security_group_get(context, security_group_id)
hosts = set()
for instance in security_group['instances']:
if instance['host'] is not None:
hosts.add(instance['host'])
for host in hosts:
rpc.cast(context,
self.db.queue_get_for(context, FLAGS.compute_topic, host),
{"method": "refresh_security_group",
"args": {"security_group_id": security_group.id}})
def trigger_security_group_members_refresh(self, context, group_id):
"""Called when a security group gains a new or loses a member
Sends an update request to each compute node for whom this is
relevant."""
# First, we get the security group rules that reference this group as
# the grantee..
security_group_rules = \
db.security_group_rule_get_by_security_group_grantee(context,
group_id)
# ..then we distill the security groups to which they belong..
security_groups = set()
for rule in security_group_rules:
security_groups.add(rule['parent_group_id'])
# ..then we find the instances that are members of these groups..
instances = set()
for security_group in security_groups:
for instance in security_group['instances']:
instances.add(instance['id'])
# ...then we find the hosts where they live...
hosts = set()
for instance in instances:
if instance['host']:
hosts.add(instance['host'])
# ...and finally we tell these nodes to refresh their view of this
# particular security group.
for host in hosts:
rpc.cast(context,
self.db.queue_get_for(context, FLAGS.compute_topic, host),
{"method": "refresh_security_group_members",
"args": {"security_group_id": group_id}})
def update_instance(self, context, instance_id, **kwargs): def update_instance(self, context, instance_id, **kwargs):
"""Updates the instance in the datastore. """Updates the instance in the datastore.

View File

@@ -80,9 +80,19 @@ class ComputeManager(manager.Manager):
@defer.inlineCallbacks @defer.inlineCallbacks
@exception.wrap_exception @exception.wrap_exception
def refresh_security_group(self, context, security_group_id, **_kwargs): def refresh_security_group_rules(self, context,
"""This call passes stright through to the virtualization driver.""" security_group_id, **_kwargs):
yield self.driver.refresh_security_group(security_group_id) """This call passes straight through to the virtualization driver."""
yield self.driver.refresh_security_group_rules(security_group_id)
@defer.inlineCallbacks
@exception.wrap_exception
def refresh_security_group_members(self, context,
security_group_id, **_kwargs):
"""This call passes straight through to the virtualization driver."""
yield self.driver.refresh_security_group_members(security_group_id)
@defer.inlineCallbacks @defer.inlineCallbacks
@exception.wrap_exception @exception.wrap_exception

View File

@@ -711,6 +711,13 @@ def security_group_rule_get_by_security_group(context, security_group_id):
security_group_id) security_group_id)
def security_group_rule_get_by_security_group_grantee(context,
security_group_id):
"""Get all rules that grant access to the given security group."""
return IMPL.security_group_rule_get_by_security_group_grantee(context,
security_group_id)
def security_group_rule_destroy(context, security_group_rule_id): def security_group_rule_destroy(context, security_group_rule_id):
"""Deletes a security group rule.""" """Deletes a security group rule."""
return IMPL.security_group_rule_destroy(context, security_group_rule_id) return IMPL.security_group_rule_destroy(context, security_group_rule_id)

View File

@@ -1532,6 +1532,25 @@ def security_group_rule_get_by_security_group(context, security_group_id, sessio
return result return result
@require_context
def security_group_rule_get_by_security_group_grantee(context,
security_group_id,
session=None):
if not session:
session = get_session()
if is_admin_context(context):
result = session.query(models.SecurityGroupIngressRule).\
filter_by(deleted=can_read_deleted(context)).\
filter_by(group_id=security_group_id).\
all()
else:
result = session.query(models.SecurityGroupIngressRule).\
filter_by(deleted=False).\
filter_by(group_id=security_group_id).\
all()
return result
@require_context @require_context
def security_group_rule_create(context, values): def security_group_rule_create(context, values):
security_group_rule_ref = models.SecurityGroupIngressRule() security_group_rule_ref = models.SecurityGroupIngressRule()

View File

@@ -656,8 +656,11 @@ class LibvirtConnection(object):
domain = self._conn.lookupByName(instance_name) domain = self._conn.lookupByName(instance_name)
return domain.interfaceStats(interface) return domain.interfaceStats(interface)
def refresh_security_group(self, security_group_id): def refresh_security_group_rules(self, security_group_id):
self.firewall_driver.refresh_security_group(security_group_id) self.firewall_driver.refresh_security_group_rules(security_group_id)
def refresh_security_group_members(self, security_group_id):
self.firewall_driver.refresh_security_group_members(security_group_id)
class FirewallDriver(object): class FirewallDriver(object):
@@ -677,11 +680,19 @@ class FirewallDriver(object):
""" """
raise NotImplementedError() raise NotImplementedError()
def refresh_security_group(self, security_group_id): def refresh_security_group_rules(self, security_group_id):
"""Refresh security group from data store """Refresh security group rules from data store
Gets called when changes have been made to the security Gets called when a rule has been added to or removed from
group.""" the security group."""
raise NotImplementedError()
def refresh_security_group_members(self, security_group_id):
"""Refresh security group members from data store
Gets called when an instance gets added to or removed from
the security group."""
raise NotImplementedError() raise NotImplementedError()
@@ -876,7 +887,7 @@ class NWFilterFirewall(FirewallDriver):
for security_group in db.security_group_get_by_instance(ctxt, for security_group in db.security_group_get_by_instance(ctxt,
instance['id']): instance['id']):
yield self.refresh_security_group(security_group['id']) yield self.refresh_security_group_rules(security_group['id'])
instance_secgroup_filter_children += [('nova-secgroup-%s' % instance_secgroup_filter_children += [('nova-secgroup-%s' %
security_group['id'])] security_group['id'])]
@@ -891,7 +902,7 @@ class NWFilterFirewall(FirewallDriver):
return return
def refresh_security_group(self, security_group_id): def refresh_security_group_rules(self, security_group_id):
return self._define_filter( return self._define_filter(
self.security_group_to_nwfilter_xml(security_group_id)) self.security_group_to_nwfilter_xml(security_group_id))
@@ -1062,7 +1073,10 @@ class IptablesFirewallDriver(FirewallDriver):
logging.info('new_filter: %s', '\n'.join(new_filter)) logging.info('new_filter: %s', '\n'.join(new_filter))
return new_filter return new_filter
def refresh_security_group(self, security_group): def refresh_security_group_members(self, security_group):
pass
def refresh_security_group_rules(self, security_group):
self.apply_ruleset() self.apply_ruleset()
def _security_group_chain_name(self, security_group): def _security_group_chain_name(self, security_group):