Move security group refresh logic into ComputeAPI.

Add a trigger_security_group_members_refresh to ComputeAPI which
finds the hosts that have instances that have security groups that
reference a security group in which a new instance has just been placed,
and sends a refresh_security_group_members to each of them.
This commit is contained in:
Soren Hansen
2010-12-13 16:42:35 +01:00
parent 65c0443c4a
commit be9a3cd7e1
6 changed files with 127 additions and 23 deletions

View File

@@ -130,15 +130,6 @@ class CloudController(object):
result[key] = [line]
return result
def _trigger_refresh_security_group(self, context, security_group):
nodes = set([instance['host'] for instance in security_group.instances
if instance['host'] is not None])
for node in nodes:
rpc.cast(context,
'%s.%s' % (FLAGS.compute_topic, node),
{"method": "refresh_security_group",
"args": {"security_group_id": security_group.id}})
def get_metadata(self, address):
ctxt = context.get_admin_context()
instance_ref = db.fixed_ip_get_instance(ctxt, address)
@@ -369,7 +360,8 @@ class CloudController(object):
match = False
if match:
db.security_group_rule_destroy(context, rule['id'])
self._trigger_refresh_security_group(context, security_group)
self.compute_api.trigger_security_group_rules_refresh(context,
security_group['id'])
return True
raise exception.ApiError("No rule for the specified parameters.")
@@ -392,7 +384,8 @@ class CloudController(object):
security_group_rule = db.security_group_rule_create(context, values)
self._trigger_refresh_security_group(context, security_group)
self.compute_api.trigger_security_group_rules_refresh(context,
security_group['id'])
return True

View File

@@ -24,6 +24,7 @@ import datetime
import logging
import time
from nova import context
from nova import db
from nova import exception
from nova import flags
@@ -165,6 +166,10 @@ class ComputeAPI(base.Base):
"args": {"topic": FLAGS.compute_topic,
"instance_id": instance_id}})
for group_id in security_groups:
self.trigger_security_group_members_refresh(elevated, group_id)
return instances
def ensure_default_security_group(self, context):
@@ -184,6 +189,62 @@ class ComputeAPI(base.Base):
'project_id': context.project_id}
db.security_group_create(context, values)
def trigger_security_group_rules_refresh(self, context, security_group_id):
"""Called when a rule is added to or removed from a security_group"""
security_group = db.security_group_get(context, security_group_id)
hosts = set()
for instance in security_group['instances']:
if instance['host'] is not None:
hosts.add(instance['host'])
for host in hosts:
rpc.cast(context,
self.db.queue_get_for(context, FLAGS.compute_topic, host),
{"method": "refresh_security_group",
"args": {"security_group_id": security_group.id}})
def trigger_security_group_members_refresh(self, context, group_id):
"""Called when a security group gains a new or loses a member
Sends an update request to each compute node for whom this is
relevant."""
# First, we get the security group rules that reference this group as
# the grantee..
security_group_rules = \
db.security_group_rule_get_by_security_group_grantee(context,
group_id)
# ..then we distill the security groups to which they belong..
security_groups = set()
for rule in security_group_rules:
security_groups.add(rule['parent_group_id'])
# ..then we find the instances that are members of these groups..
instances = set()
for security_group in security_groups:
for instance in security_group['instances']:
instances.add(instance['id'])
# ...then we find the hosts where they live...
hosts = set()
for instance in instances:
if instance['host']:
hosts.add(instance['host'])
# ...and finally we tell these nodes to refresh their view of this
# particular security group.
for host in hosts:
rpc.cast(context,
self.db.queue_get_for(context, FLAGS.compute_topic, host),
{"method": "refresh_security_group_members",
"args": {"security_group_id": group_id}})
def update_instance(self, context, instance_id, **kwargs):
"""Updates the instance in the datastore.

View File

@@ -80,9 +80,19 @@ class ComputeManager(manager.Manager):
@defer.inlineCallbacks
@exception.wrap_exception
def refresh_security_group(self, context, security_group_id, **_kwargs):
"""This call passes stright through to the virtualization driver."""
yield self.driver.refresh_security_group(security_group_id)
def refresh_security_group_rules(self, context,
security_group_id, **_kwargs):
"""This call passes straight through to the virtualization driver."""
yield self.driver.refresh_security_group_rules(security_group_id)
@defer.inlineCallbacks
@exception.wrap_exception
def refresh_security_group_members(self, context,
security_group_id, **_kwargs):
"""This call passes straight through to the virtualization driver."""
yield self.driver.refresh_security_group_members(security_group_id)
@defer.inlineCallbacks
@exception.wrap_exception

View File

@@ -711,6 +711,13 @@ def security_group_rule_get_by_security_group(context, security_group_id):
security_group_id)
def security_group_rule_get_by_security_group_grantee(context,
security_group_id):
"""Get all rules that grant access to the given security group."""
return IMPL.security_group_rule_get_by_security_group_grantee(context,
security_group_id)
def security_group_rule_destroy(context, security_group_rule_id):
"""Deletes a security group rule."""
return IMPL.security_group_rule_destroy(context, security_group_rule_id)

View File

@@ -1532,6 +1532,25 @@ def security_group_rule_get_by_security_group(context, security_group_id, sessio
return result
@require_context
def security_group_rule_get_by_security_group_grantee(context,
security_group_id,
session=None):
if not session:
session = get_session()
if is_admin_context(context):
result = session.query(models.SecurityGroupIngressRule).\
filter_by(deleted=can_read_deleted(context)).\
filter_by(group_id=security_group_id).\
all()
else:
result = session.query(models.SecurityGroupIngressRule).\
filter_by(deleted=False).\
filter_by(group_id=security_group_id).\
all()
return result
@require_context
def security_group_rule_create(context, values):
security_group_rule_ref = models.SecurityGroupIngressRule()

View File

@@ -656,8 +656,11 @@ class LibvirtConnection(object):
domain = self._conn.lookupByName(instance_name)
return domain.interfaceStats(interface)
def refresh_security_group(self, security_group_id):
self.firewall_driver.refresh_security_group(security_group_id)
def refresh_security_group_rules(self, security_group_id):
self.firewall_driver.refresh_security_group_rules(security_group_id)
def refresh_security_group_members(self, security_group_id):
self.firewall_driver.refresh_security_group_members(security_group_id)
class FirewallDriver(object):
@@ -677,11 +680,19 @@ class FirewallDriver(object):
"""
raise NotImplementedError()
def refresh_security_group(self, security_group_id):
"""Refresh security group from data store
def refresh_security_group_rules(self, security_group_id):
"""Refresh security group rules from data store
Gets called when changes have been made to the security
group."""
Gets called when a rule has been added to or removed from
the security group."""
raise NotImplementedError()
def refresh_security_group_members(self, security_group_id):
"""Refresh security group members from data store
Gets called when an instance gets added to or removed from
the security group."""
raise NotImplementedError()
@@ -876,7 +887,7 @@ class NWFilterFirewall(FirewallDriver):
for security_group in db.security_group_get_by_instance(ctxt,
instance['id']):
yield self.refresh_security_group(security_group['id'])
yield self.refresh_security_group_rules(security_group['id'])
instance_secgroup_filter_children += [('nova-secgroup-%s' %
security_group['id'])]
@@ -891,7 +902,7 @@ class NWFilterFirewall(FirewallDriver):
return
def refresh_security_group(self, security_group_id):
def refresh_security_group_rules(self, security_group_id):
return self._define_filter(
self.security_group_to_nwfilter_xml(security_group_id))
@@ -1062,7 +1073,10 @@ class IptablesFirewallDriver(FirewallDriver):
logging.info('new_filter: %s', '\n'.join(new_filter))
return new_filter
def refresh_security_group(self, security_group):
def refresh_security_group_members(self, security_group):
pass
def refresh_security_group_rules(self, security_group):
self.apply_ruleset()
def _security_group_chain_name(self, security_group):