This adds a way to create global firewall blocks that apply to all instances in your nova installation.

The mechanism for managing these rules is very similar to how security group rules are managed except there is only ever one instance of the provider rule table, as opposed to multiple security group tables.  Each instance will simply jump into the provider firewall table as one of its first actions (before security groups, so these rules cannot be overridden on a per-user basis).

Most of the changes are straightforward if you understand how security groups work.  There are a few small logging and variable name changes as well.

Right now this only exposes the creation of provider firewall rules.  If we agree this is the best path forward I will quickly be adding a list and destroy method and updating nova-adminclient.
This commit is contained in:
Todd Willey 2011-06-23 21:44:29 +00:00 committed by Tarmac
commit 654350a1cf
15 changed files with 556 additions and 15 deletions

View File

@ -21,7 +21,11 @@ Admin API controller, exposed through http via the api worker.
""" """
import base64 import base64
import datetime
import netaddr
import urllib
from nova import compute
from nova import db from nova import db
from nova import exception from nova import exception
from nova import flags from nova import flags
@ -117,6 +121,9 @@ class AdminController(object):
def __str__(self): def __str__(self):
return 'AdminController' return 'AdminController'
def __init__(self):
self.compute_api = compute.API()
def describe_instance_types(self, context, **_kwargs): def describe_instance_types(self, context, **_kwargs):
"""Returns all active instance types data (vcpus, memory, etc.)""" """Returns all active instance types data (vcpus, memory, etc.)"""
return {'instanceTypeSet': [instance_dict(v) for v in return {'instanceTypeSet': [instance_dict(v) for v in
@ -324,3 +331,41 @@ class AdminController(object):
rv.append(host_dict(host, compute, instances, volume, volumes, rv.append(host_dict(host, compute, instances, volume, volumes,
now)) now))
return {'hosts': rv} return {'hosts': rv}
def _provider_fw_rule_exists(self, context, rule):
# TODO(todd): we call this repeatedly, can we filter by protocol?
for old_rule in db.provider_fw_rule_get_all(context):
if all([rule[k] == old_rule[k] for k in ('cidr', 'from_port',
'to_port', 'protocol')]):
return True
return False
def block_external_addresses(self, context, cidr):
"""Add provider-level firewall rules to block incoming traffic."""
LOG.audit(_('Blocking traffic to all projects incoming from %s'),
cidr, context=context)
cidr = urllib.unquote(cidr).decode()
# raise if invalid
netaddr.IPNetwork(cidr)
rule = {'cidr': cidr}
tcp_rule = rule.copy()
tcp_rule.update({'protocol': 'tcp', 'from_port': 1, 'to_port': 65535})
udp_rule = rule.copy()
udp_rule.update({'protocol': 'udp', 'from_port': 1, 'to_port': 65535})
icmp_rule = rule.copy()
icmp_rule.update({'protocol': 'icmp', 'from_port': -1,
'to_port': None})
rules_added = 0
if not self._provider_fw_rule_exists(context, tcp_rule):
db.provider_fw_rule_create(context, tcp_rule)
rules_added += 1
if not self._provider_fw_rule_exists(context, udp_rule):
db.provider_fw_rule_create(context, udp_rule)
rules_added += 1
if not self._provider_fw_rule_exists(context, icmp_rule):
db.provider_fw_rule_create(context, icmp_rule)
rules_added += 1
if not rules_added:
raise exception.ApiError(_('Duplicate rule'))
self.compute_api.trigger_provider_fw_rules_refresh(context)
return {'status': 'OK', 'message': 'Added %s rules' % rules_added}

View File

@ -496,6 +496,16 @@ class API(base.Base):
{"method": "refresh_security_group_members", {"method": "refresh_security_group_members",
"args": {"security_group_id": group_id}}) "args": {"security_group_id": group_id}})
def trigger_provider_fw_rules_refresh(self, context):
"""Called when a rule is added to or removed from a security_group"""
hosts = [x['host'] for (x, idx)
in db.service_get_all_compute_sorted(context)]
for host in hosts:
rpc.cast(context,
self.db.queue_get_for(context, FLAGS.compute_topic, host),
{'method': 'refresh_provider_fw_rules', 'args': {}})
def update(self, context, instance_id, **kwargs): def update(self, context, instance_id, **kwargs):
"""Updates the instance in the datastore. """Updates the instance in the datastore.

View File

@ -215,6 +215,11 @@ class ComputeManager(manager.SchedulerDependentManager):
""" """
return self.driver.refresh_security_group_members(security_group_id) return self.driver.refresh_security_group_members(security_group_id)
@exception.wrap_exception
def refresh_provider_fw_rules(self, context, **_kwargs):
"""This call passes straight through to the virtualization driver."""
return self.driver.refresh_provider_fw_rules()
def _setup_block_device_mapping(self, context, instance_id): def _setup_block_device_mapping(self, context, instance_id):
"""setup volumes for block device mapping""" """setup volumes for block device mapping"""
self.db.instance_set_state(context, self.db.instance_set_state(context,

View File

@ -1034,6 +1034,19 @@ def security_group_rule_destroy(context, security_group_rule_id):
################### ###################
def provider_fw_rule_create(context, rule):
"""Add a firewall rule at the provider level (all hosts & instances)."""
return IMPL.provider_fw_rule_create(context, rule)
def provider_fw_rule_get_all(context):
"""Get all provider-level firewall rules."""
return IMPL.provider_fw_rule_get_all(context)
###################
def user_get(context, id): def user_get(context, id):
"""Get user by id.""" """Get user by id."""
return IMPL.user_get(context, id) return IMPL.user_get(context, id)

View File

@ -2180,6 +2180,25 @@ def security_group_rule_destroy(context, security_group_rule_id):
################### ###################
@require_admin_context
def provider_fw_rule_create(context, rule):
fw_rule_ref = models.ProviderFirewallRule()
fw_rule_ref.update(rule)
fw_rule_ref.save()
return fw_rule_ref
def provider_fw_rule_get_all(context):
session = get_session()
return session.query(models.ProviderFirewallRule).\
filter_by(deleted=can_read_deleted(context)).\
all()
###################
@require_admin_context @require_admin_context
def user_get(context, id, session=None): def user_get(context, id, session=None):
if not session: if not session:

View File

@ -0,0 +1,75 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from sqlalchemy import *
from migrate import *
from nova import log as logging
meta = MetaData()
# Just for the ForeignKey and column creation to succeed, these are not the
# actual definitions of instances or services.
instances = Table('instances', meta,
Column('id', Integer(), primary_key=True, nullable=False),
)
services = Table('services', meta,
Column('id', Integer(), primary_key=True, nullable=False),
)
networks = Table('networks', meta,
Column('id', Integer(), primary_key=True, nullable=False),
)
#
# New Tables
#
provider_fw_rules = Table('provider_fw_rules', meta,
Column('created_at', DateTime(timezone=False)),
Column('updated_at', DateTime(timezone=False)),
Column('deleted_at', DateTime(timezone=False)),
Column('deleted', Boolean(create_constraint=True, name=None)),
Column('id', Integer(), primary_key=True, nullable=False),
Column('protocol',
String(length=5, convert_unicode=False, assert_unicode=None,
unicode_error=None, _warn_on_bytestring=False)),
Column('from_port', Integer()),
Column('to_port', Integer()),
Column('cidr',
String(length=255, convert_unicode=False, assert_unicode=None,
unicode_error=None, _warn_on_bytestring=False))
)
def upgrade(migrate_engine):
# Upgrade operations go here. Don't create your own engine;
# bind migrate_engine to your metadata
meta.bind = migrate_engine
for table in (provider_fw_rules,):
try:
table.create()
except Exception:
logging.info(repr(table))
logging.exception('Exception while creating table')
raise

View File

@ -493,6 +493,17 @@ class SecurityGroupIngressRule(BASE, NovaBase):
group_id = Column(Integer, ForeignKey('security_groups.id')) group_id = Column(Integer, ForeignKey('security_groups.id'))
class ProviderFirewallRule(BASE, NovaBase):
"""Represents a rule in a security group."""
__tablename__ = 'provider_fw_rules'
id = Column(Integer, primary_key=True)
protocol = Column(String(5)) # "tcp", "udp", or "icmp"
from_port = Column(Integer)
to_port = Column(Integer)
cidr = Column(String(255))
class KeyPair(BASE, NovaBase): class KeyPair(BASE, NovaBase):
"""Represents a public key pair for ssh.""" """Represents a public key pair for ssh."""
__tablename__ = 'key_pairs' __tablename__ = 'key_pairs'

View File

@ -191,6 +191,13 @@ class IptablesTable(object):
{'chain': chain, 'rule': rule, {'chain': chain, 'rule': rule,
'top': top, 'wrap': wrap}) 'top': top, 'wrap': wrap})
def empty_chain(self, chain, wrap=True):
"""Remove all rules from a chain."""
chained_rules = [rule for rule in self.rules
if rule.chain == chain and rule.wrap == wrap]
for rule in chained_rules:
self.rules.remove(rule)
class IptablesManager(object): class IptablesManager(object):
"""Wrapper for iptables. """Wrapper for iptables.

View File

@ -0,0 +1,89 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from eventlet import greenthread
from nova import context
from nova import db
from nova import flags
from nova import log as logging
from nova import rpc
from nova import test
from nova import utils
from nova.auth import manager
from nova.api.ec2 import admin
from nova.image import fake
FLAGS = flags.FLAGS
LOG = logging.getLogger('nova.tests.adminapi')
class AdminApiTestCase(test.TestCase):
def setUp(self):
super(AdminApiTestCase, self).setUp()
self.flags(connection_type='fake')
self.conn = rpc.Connection.instance()
# set up our cloud
self.api = admin.AdminController()
# set up services
self.compute = self.start_service('compute')
self.scheduter = self.start_service('scheduler')
self.network = self.start_service('network')
self.volume = self.start_service('volume')
self.image_service = utils.import_object(FLAGS.image_service)
self.manager = manager.AuthManager()
self.user = self.manager.create_user('admin', 'admin', 'admin', True)
self.project = self.manager.create_project('proj', 'admin', 'proj')
self.context = context.RequestContext(user=self.user,
project=self.project)
host = self.network.get_network_host(self.context.elevated())
def fake_show(meh, context, id):
return {'id': 1, 'properties': {'kernel_id': 1, 'ramdisk_id': 1,
'type': 'machine', 'image_state': 'available'}}
self.stubs.Set(fake._FakeImageService, 'show', fake_show)
self.stubs.Set(fake._FakeImageService, 'show_by_name', fake_show)
# NOTE(vish): set up a manual wait so rpc.cast has a chance to finish
rpc_cast = rpc.cast
def finish_cast(*args, **kwargs):
rpc_cast(*args, **kwargs)
greenthread.sleep(0.2)
self.stubs.Set(rpc, 'cast', finish_cast)
def tearDown(self):
network_ref = db.project_get_network(self.context,
self.project.id)
db.network_disassociate(self.context, network_ref['id'])
self.manager.delete_project(self.project)
self.manager.delete_user(self.user)
super(AdminApiTestCase, self).tearDown()
def test_block_external_ips(self):
"""Make sure provider firewall rules are created."""
result = self.api.block_external_addresses(self.context, '1.1.1.1/32')
self.assertEqual('OK', result['status'])
self.assertEqual('Added 3 rules', result['message'])

View File

@ -799,6 +799,8 @@ class IptablesFirewallTestCase(test.TestCase):
self.network = utils.import_object(FLAGS.network_manager) self.network = utils.import_object(FLAGS.network_manager)
class FakeLibvirtConnection(object): class FakeLibvirtConnection(object):
def nwfilterDefineXML(*args, **kwargs):
"""setup_basic_rules in nwfilter calls this."""
pass pass
self.fake_libvirt_connection = FakeLibvirtConnection() self.fake_libvirt_connection = FakeLibvirtConnection()
self.fw = firewall.IptablesFirewallDriver( self.fw = firewall.IptablesFirewallDriver(
@ -1035,7 +1037,6 @@ class IptablesFirewallTestCase(test.TestCase):
fakefilter.filterDefineXMLMock fakefilter.filterDefineXMLMock
self.fw.nwfilter._conn.nwfilterLookupByName =\ self.fw.nwfilter._conn.nwfilterLookupByName =\
fakefilter.nwfilterLookupByName fakefilter.nwfilterLookupByName
instance_ref = self._create_instance_ref() instance_ref = self._create_instance_ref()
inst_id = instance_ref['id'] inst_id = instance_ref['id']
instance = db.instance_get(self.context, inst_id) instance = db.instance_get(self.context, inst_id)
@ -1057,6 +1058,63 @@ class IptablesFirewallTestCase(test.TestCase):
db.instance_destroy(admin_ctxt, instance_ref['id']) db.instance_destroy(admin_ctxt, instance_ref['id'])
def test_provider_firewall_rules(self):
# setup basic instance data
instance_ref = self._create_instance_ref()
nw_info = _create_network_info(1)
ip = '10.11.12.13'
network_ref = db.project_get_network(self.context, 'fake')
admin_ctxt = context.get_admin_context()
fixed_ip = {'address': ip, 'network_id': network_ref['id']}
db.fixed_ip_create(admin_ctxt, fixed_ip)
db.fixed_ip_update(admin_ctxt, ip, {'allocated': True,
'instance_id': instance_ref['id']})
# FRAGILE: peeks at how the firewall names chains
chain_name = 'inst-%s' % instance_ref['id']
# create a firewall via setup_basic_filtering like libvirt_conn.spawn
# should have a chain with 0 rules
self.fw.setup_basic_filtering(instance_ref, network_info=nw_info)
self.assertTrue('provider' in self.fw.iptables.ipv4['filter'].chains)
rules = [rule for rule in self.fw.iptables.ipv4['filter'].rules
if rule.chain == 'provider']
self.assertEqual(0, len(rules))
# add a rule and send the update message, check for 1 rule
provider_fw0 = db.provider_fw_rule_create(admin_ctxt,
{'protocol': 'tcp',
'cidr': '10.99.99.99/32',
'from_port': 1,
'to_port': 65535})
self.fw.refresh_provider_fw_rules()
rules = [rule for rule in self.fw.iptables.ipv4['filter'].rules
if rule.chain == 'provider']
self.assertEqual(1, len(rules))
# Add another, refresh, and make sure number of rules goes to two
provider_fw1 = db.provider_fw_rule_create(admin_ctxt,
{'protocol': 'udp',
'cidr': '10.99.99.99/32',
'from_port': 1,
'to_port': 65535})
self.fw.refresh_provider_fw_rules()
rules = [rule for rule in self.fw.iptables.ipv4['filter'].rules
if rule.chain == 'provider']
self.assertEqual(2, len(rules))
# create the instance filter and make sure it has a jump rule
self.fw.prepare_instance_filter(instance_ref, network_info=nw_info)
self.fw.apply_instance_filter(instance_ref)
inst_rules = [rule for rule in self.fw.iptables.ipv4['filter'].rules
if rule.chain == chain_name]
jump_rules = [rule for rule in inst_rules if '-j' in rule.rule]
provjump_rules = []
# IptablesTable doesn't make rules unique internally
for rule in jump_rules:
if 'provider' in rule.rule and rule not in provjump_rules:
provjump_rules.append(rule)
self.assertEqual(1, len(provjump_rules))
class NWFilterTestCase(test.TestCase): class NWFilterTestCase(test.TestCase):
def setUp(self): def setUp(self):

View File

@ -164,3 +164,33 @@ class IptablesManagerTestCase(test.TestCase):
self.assertTrue('-A %s -j run_tests.py-%s' \ self.assertTrue('-A %s -j run_tests.py-%s' \
% (chain, chain) in new_lines, % (chain, chain) in new_lines,
"Built-in chain %s not wrapped" % (chain,)) "Built-in chain %s not wrapped" % (chain,))
def test_will_empty_chain(self):
self.manager.ipv4['filter'].add_chain('test-chain')
self.manager.ipv4['filter'].add_rule('test-chain', '-j DROP')
old_count = len(self.manager.ipv4['filter'].rules)
self.manager.ipv4['filter'].empty_chain('test-chain')
self.assertEqual(old_count - 1, len(self.manager.ipv4['filter'].rules))
def test_will_empty_unwrapped_chain(self):
self.manager.ipv4['filter'].add_chain('test-chain', wrap=False)
self.manager.ipv4['filter'].add_rule('test-chain', '-j DROP',
wrap=False)
old_count = len(self.manager.ipv4['filter'].rules)
self.manager.ipv4['filter'].empty_chain('test-chain', wrap=False)
self.assertEqual(old_count - 1, len(self.manager.ipv4['filter'].rules))
def test_will_not_empty_wrapped_when_unwrapped(self):
self.manager.ipv4['filter'].add_chain('test-chain')
self.manager.ipv4['filter'].add_rule('test-chain', '-j DROP')
old_count = len(self.manager.ipv4['filter'].rules)
self.manager.ipv4['filter'].empty_chain('test-chain', wrap=False)
self.assertEqual(old_count, len(self.manager.ipv4['filter'].rules))
def test_will_not_empty_unwrapped_when_wrapped(self):
self.manager.ipv4['filter'].add_chain('test-chain', wrap=False)
self.manager.ipv4['filter'].add_rule('test-chain', '-j DROP',
wrap=False)
old_count = len(self.manager.ipv4['filter'].rules)
self.manager.ipv4['filter'].empty_chain('test-chain')
self.assertEqual(old_count, len(self.manager.ipv4['filter'].rules))

View File

@ -191,6 +191,10 @@ class ComputeDriver(object):
def refresh_security_group_members(self, security_group_id): def refresh_security_group_members(self, security_group_id):
raise NotImplementedError() raise NotImplementedError()
def refresh_provider_fw_rules(self, security_group_id):
"""See: nova/virt/fake.py for docs."""
raise NotImplementedError()
def reset_network(self, instance): def reset_network(self, instance):
"""reset networking for specified instance""" """reset networking for specified instance"""
raise NotImplementedError() raise NotImplementedError()

View File

@ -466,6 +466,22 @@ class FakeConnection(driver.ComputeDriver):
""" """
return True return True
def refresh_provider_fw_rules(self):
"""This triggers a firewall update based on database changes.
When this is called, rules have either been added or removed from the
datastore. You can retrieve rules with
:method:`nova.db.api.provider_fw_rule_get_all`.
Provider rules take precedence over security group rules. If an IP
would be allowed by a security group ingress rule, but blocked by
a provider rule, then packets from the IP are dropped. This includes
intra-project traffic in the case of the allow_project_net_traffic
flag for the libvirt-derived classes.
"""
pass
def update_available_resource(self, ctxt, host): def update_available_resource(self, ctxt, host):
"""This method is supported only by libvirt.""" """This method is supported only by libvirt."""
return return

View File

@ -1383,6 +1383,9 @@ class LibvirtConnection(driver.ComputeDriver):
def refresh_security_group_members(self, security_group_id): def refresh_security_group_members(self, security_group_id):
self.firewall_driver.refresh_security_group_members(security_group_id) self.firewall_driver.refresh_security_group_members(security_group_id)
def refresh_provider_fw_rules(self):
self.firewall_driver.refresh_provider_fw_rules()
def update_available_resource(self, ctxt, host): def update_available_resource(self, ctxt, host):
"""Updates compute manager resource info on ComputeNode table. """Updates compute manager resource info on ComputeNode table.

View File

@ -76,6 +76,15 @@ class FirewallDriver(object):
the security group.""" the security group."""
raise NotImplementedError() raise NotImplementedError()
def refresh_provider_fw_rules(self):
"""Refresh common rules for all hosts/instances from data store.
Gets called when a rule has been added to or removed from
the list of rules (via admin api).
"""
raise NotImplementedError()
def setup_basic_filtering(self, instance, network_info=None): def setup_basic_filtering(self, instance, network_info=None):
"""Create rules to block spoofing and allow dhcp. """Create rules to block spoofing and allow dhcp.
@ -207,6 +216,13 @@ class NWFilterFirewall(FirewallDriver):
[base_filter])) [base_filter]))
def _ensure_static_filters(self): def _ensure_static_filters(self):
"""Static filters are filters that have no need to be IP aware.
There is no configuration or tuneability of these filters, so they
can be set up once and forgotten about.
"""
if self.static_filters_configured: if self.static_filters_configured:
return return
@ -310,19 +326,21 @@ class NWFilterFirewall(FirewallDriver):
'for %(instance_name)s is not found.') % locals()) 'for %(instance_name)s is not found.') % locals())
def prepare_instance_filter(self, instance, network_info=None): def prepare_instance_filter(self, instance, network_info=None):
""" """Creates an NWFilter for the given instance.
Creates an NWFilter for the given instance. In the process,
it makes sure the filters for the security groups as well as In the process, it makes sure the filters for the provider blocks,
the base filter are all in place. security groups, and base filter are all in place.
""" """
if not network_info: if not network_info:
network_info = netutils.get_network_info(instance) network_info = netutils.get_network_info(instance)
self.refresh_provider_fw_rules()
ctxt = context.get_admin_context() ctxt = context.get_admin_context()
instance_secgroup_filter_name = \ instance_secgroup_filter_name = \
'%s-secgroup' % (self._instance_filter_name(instance)) '%s-secgroup' % (self._instance_filter_name(instance))
#% (instance_filter_name,)
instance_secgroup_filter_children = ['nova-base-ipv4', instance_secgroup_filter_children = ['nova-base-ipv4',
'nova-base-ipv6', 'nova-base-ipv6',
@ -366,7 +384,7 @@ class NWFilterFirewall(FirewallDriver):
for (_n, mapping) in network_info: for (_n, mapping) in network_info:
nic_id = mapping['mac'].replace(':', '') nic_id = mapping['mac'].replace(':', '')
instance_filter_name = self._instance_filter_name(instance, nic_id) instance_filter_name = self._instance_filter_name(instance, nic_id)
instance_filter_children = [base_filter, instance_filter_children = [base_filter, 'nova-provider-rules',
instance_secgroup_filter_name] instance_secgroup_filter_name]
if FLAGS.allow_project_net_traffic: if FLAGS.allow_project_net_traffic:
@ -388,6 +406,19 @@ class NWFilterFirewall(FirewallDriver):
return self._define_filter( return self._define_filter(
self.security_group_to_nwfilter_xml(security_group_id)) self.security_group_to_nwfilter_xml(security_group_id))
def refresh_provider_fw_rules(self):
"""Update rules for all instances.
This is part of the FirewallDriver API and is called when the
provider firewall rules change in the database. In the
`prepare_instance_filter` we add a reference to the
'nova-provider-rules' filter for each instance's firewall, and
by changing that filter we update them all.
"""
xml = self.provider_fw_to_nwfilter_xml()
return self._define_filter(xml)
def security_group_to_nwfilter_xml(self, security_group_id): def security_group_to_nwfilter_xml(self, security_group_id):
security_group = db.security_group_get(context.get_admin_context(), security_group = db.security_group_get(context.get_admin_context(),
security_group_id) security_group_id)
@ -426,6 +457,43 @@ class NWFilterFirewall(FirewallDriver):
xml += "chain='ipv4'>%s</filter>" % rule_xml xml += "chain='ipv4'>%s</filter>" % rule_xml
return xml return xml
def provider_fw_to_nwfilter_xml(self):
"""Compose a filter of drop rules from specified cidrs."""
rule_xml = ""
v6protocol = {'tcp': 'tcp-ipv6', 'udp': 'udp-ipv6', 'icmp': 'icmpv6'}
rules = db.provider_fw_rule_get_all(context.get_admin_context())
for rule in rules:
rule_xml += "<rule action='block' direction='in' priority='150'>"
version = netutils.get_ip_version(rule.cidr)
if(FLAGS.use_ipv6 and version == 6):
net, prefixlen = netutils.get_net_and_prefixlen(rule.cidr)
rule_xml += "<%s srcipaddr='%s' srcipmask='%s' " % \
(v6protocol[rule.protocol], net, prefixlen)
else:
net, mask = netutils.get_net_and_mask(rule.cidr)
rule_xml += "<%s srcipaddr='%s' srcipmask='%s' " % \
(rule.protocol, net, mask)
if rule.protocol in ['tcp', 'udp']:
rule_xml += "dstportstart='%s' dstportend='%s' " % \
(rule.from_port, rule.to_port)
elif rule.protocol == 'icmp':
LOG.info('rule.protocol: %r, rule.from_port: %r, '
'rule.to_port: %r', rule.protocol,
rule.from_port, rule.to_port)
if rule.from_port != -1:
rule_xml += "type='%s' " % rule.from_port
if rule.to_port != -1:
rule_xml += "code='%s' " % rule.to_port
rule_xml += '/>\n'
rule_xml += "</rule>\n"
xml = "<filter name='nova-provider-rules' "
if(FLAGS.use_ipv6):
xml += "chain='root'>%s</filter>" % rule_xml
else:
xml += "chain='ipv4'>%s</filter>" % rule_xml
return xml
def _instance_filter_name(self, instance, nic_id=None): def _instance_filter_name(self, instance, nic_id=None):
if not nic_id: if not nic_id:
return 'nova-instance-%s' % (instance['name']) return 'nova-instance-%s' % (instance['name'])
@ -453,6 +521,7 @@ class IptablesFirewallDriver(FirewallDriver):
self.iptables = linux_net.iptables_manager self.iptables = linux_net.iptables_manager
self.instances = {} self.instances = {}
self.nwfilter = NWFilterFirewall(kwargs['get_connection']) self.nwfilter = NWFilterFirewall(kwargs['get_connection'])
self.basicly_filtered = False
self.iptables.ipv4['filter'].add_chain('sg-fallback') self.iptables.ipv4['filter'].add_chain('sg-fallback')
self.iptables.ipv4['filter'].add_rule('sg-fallback', '-j DROP') self.iptables.ipv4['filter'].add_rule('sg-fallback', '-j DROP')
@ -460,10 +529,14 @@ class IptablesFirewallDriver(FirewallDriver):
self.iptables.ipv6['filter'].add_rule('sg-fallback', '-j DROP') self.iptables.ipv6['filter'].add_rule('sg-fallback', '-j DROP')
def setup_basic_filtering(self, instance, network_info=None): def setup_basic_filtering(self, instance, network_info=None):
"""Use NWFilter from libvirt for this.""" """Set up provider rules and basic NWFilter."""
if not network_info: if not network_info:
network_info = netutils.get_network_info(instance) network_info = netutils.get_network_info(instance)
return self.nwfilter.setup_basic_filtering(instance, network_info) self.nwfilter.setup_basic_filtering(instance, network_info)
if not self.basicly_filtered:
LOG.debug(_('iptables firewall: Setup Basic Filtering'))
self.refresh_provider_fw_rules()
self.basicly_filtered = True
def apply_instance_filter(self, instance): def apply_instance_filter(self, instance):
"""No-op. Everything is done in prepare_instance_filter""" """No-op. Everything is done in prepare_instance_filter"""
@ -543,6 +616,10 @@ class IptablesFirewallDriver(FirewallDriver):
ipv4_rules += ['-m state --state ESTABLISHED,RELATED -j ACCEPT'] ipv4_rules += ['-m state --state ESTABLISHED,RELATED -j ACCEPT']
ipv6_rules += ['-m state --state ESTABLISHED,RELATED -j ACCEPT'] ipv6_rules += ['-m state --state ESTABLISHED,RELATED -j ACCEPT']
# Pass through provider-wide drops
ipv4_rules += ['-j $provider']
ipv6_rules += ['-j $provider']
dhcp_servers = [network['gateway'] for (network, _m) in network_info] dhcp_servers = [network['gateway'] for (network, _m) in network_info]
for dhcp_server in dhcp_servers: for dhcp_server in dhcp_servers:
@ -560,7 +637,7 @@ class IptablesFirewallDriver(FirewallDriver):
# they're not worth the clutter. # they're not worth the clutter.
if FLAGS.use_ipv6: if FLAGS.use_ipv6:
# Allow RA responses # Allow RA responses
gateways_v6 = [network['gateway_v6'] for (network, _) in gateways_v6 = [network['gateway_v6'] for (network, _m) in
network_info] network_info]
for gateway_v6 in gateways_v6: for gateway_v6 in gateways_v6:
ipv6_rules.append( ipv6_rules.append(
@ -583,7 +660,7 @@ class IptablesFirewallDriver(FirewallDriver):
security_group['id']) security_group['id'])
for rule in rules: for rule in rules:
logging.info('%r', rule) LOG.debug(_('Adding security group rule: %r'), rule)
if not rule.cidr: if not rule.cidr:
# Eventually, a mechanism to grant access for security # Eventually, a mechanism to grant access for security
@ -592,9 +669,9 @@ class IptablesFirewallDriver(FirewallDriver):
version = netutils.get_ip_version(rule.cidr) version = netutils.get_ip_version(rule.cidr)
if version == 4: if version == 4:
rules = ipv4_rules fw_rules = ipv4_rules
else: else:
rules = ipv6_rules fw_rules = ipv6_rules
protocol = rule.protocol protocol = rule.protocol
if version == 6 and rule.protocol == 'icmp': if version == 6 and rule.protocol == 'icmp':
@ -629,7 +706,7 @@ class IptablesFirewallDriver(FirewallDriver):
icmp_type_arg] icmp_type_arg]
args += ['-j ACCEPT'] args += ['-j ACCEPT']
rules += [' '.join(args)] fw_rules += [' '.join(args)]
ipv4_rules += ['-j $sg-fallback'] ipv4_rules += ['-j $sg-fallback']
ipv6_rules += ['-j $sg-fallback'] ipv6_rules += ['-j $sg-fallback']
@ -657,6 +734,85 @@ class IptablesFirewallDriver(FirewallDriver):
network_info = netutils.get_network_info(instance) network_info = netutils.get_network_info(instance)
self.add_filters_for_instance(instance, network_info) self.add_filters_for_instance(instance, network_info)
def refresh_provider_fw_rules(self):
"""See class:FirewallDriver: docs."""
self._do_refresh_provider_fw_rules()
self.iptables.apply()
@utils.synchronized('iptables', external=True)
def _do_refresh_provider_fw_rules(self):
"""Internal, synchronized version of refresh_provider_fw_rules."""
self._purge_provider_fw_rules()
self._build_provider_fw_rules()
def _purge_provider_fw_rules(self):
"""Remove all rules from the provider chains."""
self.iptables.ipv4['filter'].empty_chain('provider')
if FLAGS.use_ipv6:
self.iptables.ipv6['filter'].empty_chain('provider')
def _build_provider_fw_rules(self):
"""Create all rules for the provider IP DROPs."""
self.iptables.ipv4['filter'].add_chain('provider')
if FLAGS.use_ipv6:
self.iptables.ipv6['filter'].add_chain('provider')
ipv4_rules, ipv6_rules = self._provider_rules()
for rule in ipv4_rules:
self.iptables.ipv4['filter'].add_rule('provider', rule)
if FLAGS.use_ipv6:
for rule in ipv6_rules:
self.iptables.ipv6['filter'].add_rule('provider', rule)
def _provider_rules(self):
"""Generate a list of rules from provider for IP4 & IP6."""
ctxt = context.get_admin_context()
ipv4_rules = []
ipv6_rules = []
rules = db.provider_fw_rule_get_all(ctxt)
for rule in rules:
LOG.debug(_('Adding provider rule: %s'), rule['cidr'])
version = netutils.get_ip_version(rule['cidr'])
if version == 4:
fw_rules = ipv4_rules
else:
fw_rules = ipv6_rules
protocol = rule['protocol']
if version == 6 and protocol == 'icmp':
protocol = 'icmpv6'
args = ['-p', protocol, '-s', rule['cidr']]
if protocol in ['udp', 'tcp']:
if rule['from_port'] == rule['to_port']:
args += ['--dport', '%s' % (rule['from_port'],)]
else:
args += ['-m', 'multiport',
'--dports', '%s:%s' % (rule['from_port'],
rule['to_port'])]
elif protocol == 'icmp':
icmp_type = rule['from_port']
icmp_code = rule['to_port']
if icmp_type == -1:
icmp_type_arg = None
else:
icmp_type_arg = '%s' % icmp_type
if not icmp_code == -1:
icmp_type_arg += '/%s' % icmp_code
if icmp_type_arg:
if version == 4:
args += ['-m', 'icmp', '--icmp-type',
icmp_type_arg]
elif version == 6:
args += ['-m', 'icmp6', '--icmpv6-type',
icmp_type_arg]
args += ['-j DROP']
fw_rules += [' '.join(args)]
return ipv4_rules, ipv6_rules
def _security_group_chain_name(self, security_group_id): def _security_group_chain_name(self, security_group_id):
return 'nova-sg-%s' % (security_group_id,) return 'nova-sg-%s' % (security_group_id,)