# vim: tabstop=4 shiftwidth=4 softtabstop=4
#
#    Copyright 2010 OpenStack LLC
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

"""Tests for nova.virt.libvirt_conn.

Covers the domain XML / connection URI produced by LibvirtConnection and the
rules produced by the IptablesFirewallDriver and NWFilterFirewall drivers.
"""

from xml.etree.ElementTree import fromstring as xml_to_tree
from xml.dom.minidom import parseString as xml_to_dom

from nova import context
from nova import db
from nova import flags
from nova import test
from nova import utils
from nova.api.ec2 import cloud
from nova.auth import manager
from nova.virt import libvirt_conn

FLAGS = flags.FLAGS
flags.DECLARE('instances_path', 'nova.compute.manager')


class LibvirtConnTestCase(test.TestCase):
    """Checks the XML and URI generated by LibvirtConnection."""

    def setUp(self):
        """Create the fake user, project and network manager."""
        super(LibvirtConnTestCase, self).setUp()
        libvirt_conn._late_load_cheetah()
        self.flags(fake_call=True)
        self.manager = manager.AuthManager()
        self.user = self.manager.create_user('fake', 'fake', 'fake',
                                             admin=True)
        self.project = self.manager.create_project('fake', 'fake', 'fake')
        self.network = utils.import_object(FLAGS.network_manager)
        FLAGS.instances_path = ''

    # Address assigned to the test instance; also the value expected in the
    # generated filterref 'IP' parameter.
    test_ip = '10.11.12.13'

    # Base instance attributes; each test copies this dict before adding
    # kernel_id / ramdisk_id.
    test_instance = {'memory_kb': '1024000',
                     'basepath': '/some/path',
                     'bridge_name': 'br100',
                     'mac_address': '02:12:34:46:56:67',
                     'vcpus': 2,
                     'project_id': 'fake',
                     'bridge': 'br101',
                     'instance_type': 'm1.small'}

    def test_xml_and_uri_no_ramdisk_no_kernel(self):
        """Neither kernel nor initrd element is expected."""
        instance_data = dict(self.test_instance)
        self._check_xml_and_uri(instance_data,
                                expect_kernel=False, expect_ramdisk=False)

    def test_xml_and_uri_no_ramdisk(self):
        """Only a kernel id is set, so only the kernel is expected."""
        instance_data = dict(self.test_instance)
        instance_data['kernel_id'] = 'aki-deadbeef'
        self._check_xml_and_uri(instance_data,
                                expect_kernel=True, expect_ramdisk=False)

    def test_xml_and_uri_no_kernel(self):
        """Only a ramdisk id is set; neither element is expected."""
        instance_data = dict(self.test_instance)
        instance_data['ramdisk_id'] = 'ari-deadbeef'
        self._check_xml_and_uri(instance_data,
                                expect_kernel=False, expect_ramdisk=False)

    def test_xml_and_uri(self):
        """Both kernel and ramdisk ids are set and both are expected."""
        instance_data = dict(self.test_instance)
        instance_data['ramdisk_id'] = 'ari-deadbeef'
        instance_data['kernel_id'] = 'aki-deadbeef'
        self._check_xml_and_uri(instance_data,
                                expect_kernel=True, expect_ramdisk=True)

    def test_xml_and_uri_rescue(self):
        """Rescue mode: the '.rescue' kernel, ramdisk and disk are expected."""
        instance_data = dict(self.test_instance)
        instance_data['ramdisk_id'] = 'ari-deadbeef'
        instance_data['kernel_id'] = 'aki-deadbeef'
        self._check_xml_and_uri(instance_data, expect_kernel=True,
                                expect_ramdisk=True, rescue=True)

    def _check_xml_and_uri(self, instance, expect_ramdisk, expect_kernel,
                           rescue=False):
        """Create an instance and verify the libvirt XML and URI for it.

        For every libvirt type in the local map the connection URI and the
        generated domain XML are checked; afterwards FLAGS.libvirt_uri is set
        and the connection is expected to report exactly that URI.

        :param instance: dict of values passed to db.instance_create
        :param expect_ramdisk: whether an ./os/initrd element is expected
        :param expect_kernel: whether an ./os/kernel element is expected
        :param rescue: passed to to_xml(); when true the '.rescue' kernel,
                       ramdisk and disk paths are expected instead
        """
        user_context = context.RequestContext(project=self.project,
                                              user=self.user)
        instance_ref = db.instance_create(user_context, instance)
        # The return value is not used; the call itself is kept.
        # NOTE(review): presumably needed for its side effect on the
        # project's network -- confirm.
        self.network.get_network_host(user_context.elevated())

        ctxt = context.get_admin_context()
        network_ref = db.project_get_network(ctxt, self.project.id)

        fixed_ip = {'address': self.test_ip,
                    'network_id': network_ref['id']}
        db.fixed_ip_create(ctxt, fixed_ip)
        db.fixed_ip_update(ctxt, self.test_ip,
                           {'allocated': True,
                            'instance_id': instance_ref['id']})

        # libvirt_type -> (expected uri, [(extractor, expected value), ...])
        type_uri_map = {
            'qemu': ('qemu:///system',
                     [(lambda t: t.find('.').get('type'), 'qemu'),
                      (lambda t: t.find('./os/type').text, 'hvm'),
                      (lambda t: t.find('./devices/emulator'), None)]),
            'kvm': ('qemu:///system',
                    [(lambda t: t.find('.').get('type'), 'kvm'),
                     (lambda t: t.find('./os/type').text, 'hvm'),
                     (lambda t: t.find('./devices/emulator'), None)]),
            'uml': ('uml:///system',
                    [(lambda t: t.find('.').get('type'), 'uml'),
                     (lambda t: t.find('./os/type').text, 'uml')]),
            'xen': ('xen:///',
                    [(lambda t: t.find('.').get('type'), 'xen'),
                     (lambda t: t.find('./os/type').text, 'linux')]),
        }

        # Kernel/ramdisk checks are only added for these three types;
        # 'uml' keeps just its type-specific checks.
        for hypervisor_type in ['qemu', 'kvm', 'xen']:
            check_list = type_uri_map[hypervisor_type][1]

            if rescue:
                check_list.append(
                    (lambda t: t.find('./os/kernel').text.split('/')[1],
                     'kernel.rescue'))
                check_list.append(
                    (lambda t: t.find('./os/initrd').text.split('/')[1],
                     'ramdisk.rescue'))
            else:
                if expect_kernel:
                    check = (lambda t: t.find('./os/kernel').text.split(
                        '/')[1], 'kernel')
                else:
                    check = (lambda t: t.find('./os/kernel'), None)
                check_list.append(check)

                if expect_ramdisk:
                    check = (lambda t: t.find('./os/initrd').text.split(
                        '/')[1], 'ramdisk')
                else:
                    check = (lambda t: t.find('./os/initrd'), None)
                check_list.append(check)

        # Checks applied to the XML of every libvirt type.
        common_checks = [
            (lambda t: t.find('.').tag, 'domain'),
            (lambda t: t.find(
                './devices/interface/filterref/parameter').get('name'),
             'IP'),
            (lambda t: t.find(
                './devices/interface/filterref/parameter').get('value'),
             self.test_ip),
            (lambda t: t.findall(
                './devices/interface/filterref/parameter')[1].get('name'),
             'DHCPSERVER'),
            (lambda t: t.findall(
                './devices/interface/filterref/parameter')[1].get('value'),
             '10.0.0.1'),
            (lambda t: t.find('./devices/serial/source').get(
                'path').split('/')[1], 'console.log'),
            (lambda t: t.find('./memory').text, '2097152')]

        if rescue:
            common_checks += [
                (lambda t: t.findall('./devices/disk/source')[0].get(
                    'file').split('/')[1], 'disk.rescue'),
                (lambda t: t.findall('./devices/disk/source')[1].get(
                    'file').split('/')[1], 'disk')]
        else:
            common_checks += [
                (lambda t: t.findall('./devices/disk/source')[0].get(
                    'file').split('/')[1], 'disk'),
                (lambda t: t.findall('./devices/disk/source')[1].get(
                    'file').split('/')[1], 'disk.local')]

        for (libvirt_type, (expected_uri, checks)) in type_uri_map.items():
            FLAGS.libvirt_type = libvirt_type
            conn = libvirt_conn.LibvirtConnection(True)

            uri = conn.get_uri()
            self.assertEqual(uri, expected_uri)

            xml = conn.to_xml(instance_ref, rescue)
            tree = xml_to_tree(xml)
            for i, (check, expected_result) in enumerate(checks):
                self.assertEqual(check(tree),
                                 expected_result,
                                 '%s failed check %d' % (xml, i))

            for i, (check, expected_result) in enumerate(common_checks):
                self.assertEqual(check(tree),
                                 expected_result,
                                 '%s failed common check %d' % (xml, i))

        # This test is supposed to make sure we don't override a specifically
        # set uri
        #
        # Deliberately not just assigning this string to FLAGS.libvirt_uri and
        # checking against that later on. This way we make sure the
        # implementation doesn't fiddle around with the FLAGS.
        testuri = 'something completely different'
        FLAGS.libvirt_uri = testuri
        for libvirt_type in type_uri_map:
            FLAGS.libvirt_type = libvirt_type
            conn = libvirt_conn.LibvirtConnection(True)
            uri = conn.get_uri()
            self.assertEqual(uri, testuri)

    def tearDown(self):
        """Delete the fake project and user, then run the base teardown."""
        # Clean up our own fixtures first and call the base class last, the
        # same order IptablesFirewallTestCase uses.
        self.manager.delete_project(self.project)
        self.manager.delete_user(self.user)
        super(LibvirtConnTestCase, self).tearDown()


class IptablesFirewallTestCase(test.TestCase):
    """Checks the iptables rules produced by IptablesFirewallDriver."""

    def setUp(self):
        """Create the fake user, project, context and firewall driver."""
        super(IptablesFirewallTestCase, self).setUp()
        self.manager = manager.AuthManager()
        self.user = self.manager.create_user('fake', 'fake', 'fake',
                                             admin=True)
        self.project = self.manager.create_project('fake', 'fake', 'fake')
        self.context = context.RequestContext('fake', 'fake')
        self.network = utils.import_object(FLAGS.network_manager)
        self.fw = libvirt_conn.IptablesFirewallDriver()

    def tearDown(self):
        """Delete the fake project and user, then run the base teardown."""
        self.manager.delete_project(self.project)
        self.manager.delete_user(self.user)
        super(IptablesFirewallTestCase, self).tearDown()

    # Canned output returned by the fake 'iptables-save' command.
    in_rules = [
        '# Generated by iptables-save v1.4.4 on Mon Dec 6 11:54:13 2010',
        '*filter',
        ':INPUT ACCEPT [969615:281627771]',
        ':FORWARD ACCEPT [0:0]',
        ':OUTPUT ACCEPT [915599:63811649]',
        ':nova-block-ipv4 - [0:0]',
        '-A INPUT -i virbr0 -p udp -m udp --dport 53 -j ACCEPT ',
        '-A INPUT -i virbr0 -p tcp -m tcp --dport 53 -j ACCEPT ',
        '-A INPUT -i virbr0 -p udp -m udp --dport 67 -j ACCEPT ',
        '-A INPUT -i virbr0 -p tcp -m tcp --dport 67 -j ACCEPT ',
        '-A FORWARD -d 192.168.122.0/24 -o virbr0 -m state --state RELATED'
        ',ESTABLISHED -j ACCEPT ',
        '-A FORWARD -s 192.168.122.0/24 -i virbr0 -j ACCEPT ',
        '-A FORWARD -i virbr0 -o virbr0 -j ACCEPT ',
        '-A FORWARD -o virbr0 -j REJECT --reject-with icmp-port-unreachable ',
        '-A FORWARD -i virbr0 -j REJECT --reject-with icmp-port-unreachable ',
        'COMMIT',
        '# Completed on Mon Dec 6 11:54:13 2010',
    ]

    # Canned output returned by the fake 'ip6tables-save' command.
    in6_rules = [
        '# Generated by ip6tables-save v1.4.4 on Tue Jan 18 23:47:56 2011',
        '*filter',
        ':INPUT ACCEPT [349155:75810423]',
        ':FORWARD ACCEPT [0:0]',
        ':OUTPUT ACCEPT [349256:75777230]',
        'COMMIT',
        '# Completed on Tue Jan 18 23:47:56 2011'
    ]

    def test_static_filters(self):
        """Existing rules survive and security group rules are added.

        An instance with one security group (two ICMP rules, one TCP rule)
        is run through prepare_instance_filter/apply_instance_filter with
        the driver's execute replaced by a fake; the text handed to
        'iptables-restore' is then inspected.
        """
        instance_ref = db.instance_create(self.context,
                                          {'user_id': 'fake',
                                           'project_id': 'fake',
                                           'mac_address': '56:12:12:12:12:12'})
        ip = '10.11.12.13'

        network_ref = db.project_get_network(self.context, 'fake')

        fixed_ip = {'address': ip,
                    'network_id': network_ref['id']}

        admin_ctxt = context.get_admin_context()
        db.fixed_ip_create(admin_ctxt, fixed_ip)
        db.fixed_ip_update(admin_ctxt, ip, {'allocated': True,
                                            'instance_id': instance_ref['id']})

        secgroup = db.security_group_create(admin_ctxt,
                                            {'user_id': 'fake',
                                             'project_id': 'fake',
                                             'name': 'testgroup',
                                             'description': 'test group'})

        db.security_group_rule_create(admin_ctxt,
                                      {'parent_group_id': secgroup['id'],
                                       'protocol': 'icmp',
                                       'from_port': -1,
                                       'to_port': -1,
                                       'cidr': '192.168.11.0/24'})

        db.security_group_rule_create(admin_ctxt,
                                      {'parent_group_id': secgroup['id'],
                                       'protocol': 'icmp',
                                       'from_port': 8,
                                       'to_port': -1,
                                       'cidr': '192.168.11.0/24'})

        db.security_group_rule_create(admin_ctxt,
                                      {'parent_group_id': secgroup['id'],
                                       'protocol': 'tcp',
                                       'from_port': 80,
                                       'to_port': 81,
                                       'cidr': '192.168.10.0/24'})

        db.instance_add_security_group(admin_ctxt, instance_ref['id'],
                                       secgroup['id'])
        instance_ref = db.instance_get(admin_ctxt, instance_ref['id'])

        def fake_iptables_execute(cmd, process_input=None):
            """Serve canned save output and capture restore input."""
            if cmd == 'sudo ip6tables-save -t filter':
                return '\n'.join(self.in6_rules), None
            if cmd == 'sudo iptables-save -t filter':
                return '\n'.join(self.in_rules), None
            if cmd == 'sudo iptables-restore':
                self.out_rules = process_input.split('\n')
                return '', ''
            if cmd == 'sudo ip6tables-restore':
                self.out6_rules = process_input.split('\n')
                return '', ''

        self.fw.execute = fake_iptables_execute

        self.fw.prepare_instance_filter(instance_ref)
        self.fw.apply_instance_filter(instance_ref)

        # Every pre-existing, non-comment, non-nova rule must be preserved.
        in_rules = [l for l in self.in_rules if not l.startswith('#')]
        for rule in in_rules:
            if 'nova' not in rule:
                self.assertTrue(rule in self.out_rules,
                                'Rule went missing: %s' % rule)

        instance_chain = None
        for rule in self.out_rules:
            # This is pretty crude, but it'll do for now
            if '-d 10.11.12.13 -j' in rule:
                instance_chain = rule.split(' ')[-1]
                break
        self.assertTrue(instance_chain, "The instance chain wasn't added")

        security_group_chain = None
        for rule in self.out_rules:
            # This is pretty crude, but it'll do for now
            if '-A %s -j' % instance_chain in rule:
                security_group_chain = rule.split(' ')[-1]
                break
        self.assertTrue(security_group_chain,
                        "The security group chain wasn't added")

        icmp_rule = ('-A %s -p icmp -s 192.168.11.0/24 -j ACCEPT'
                     % security_group_chain)
        self.assertTrue(icmp_rule in self.out_rules,
                        "ICMP acceptance rule wasn't added")

        icmp_echo_rule = ('-A %s -p icmp -s 192.168.11.0/24 -m icmp '
                          '--icmp-type 8 -j ACCEPT' % security_group_chain)
        self.assertTrue(icmp_echo_rule in self.out_rules,
                        "ICMP Echo Request acceptance rule wasn't added")

        tcp_rule = ('-A %s -p tcp -s 192.168.10.0/24 -m multiport '
                    '--dports 80:81 -j ACCEPT' % security_group_chain)
        self.assertTrue(tcp_rule in self.out_rules,
                        "TCP port 80/81 acceptance rule wasn't added")


class NWFilterTestCase(test.TestCase):
    """Checks the nwfilter XML produced by NWFilterFirewall."""

    def setUp(self):
        """Create the fake user/project and a firewall on a stub connection."""
        super(NWFilterTestCase, self).setUp()

        class Mock(object):
            pass

        self.manager = manager.AuthManager()
        self.user = self.manager.create_user('fake', 'fake', 'fake',
                                             admin=True)
        self.project = self.manager.create_project('fake', 'fake', 'fake')
        self.context = context.RequestContext(self.user, self.project)

        # Bare object; tests attach the libvirt methods they need to it.
        self.fake_libvirt_connection = Mock()

        self.fw = libvirt_conn.NWFilterFirewall(
            lambda: self.fake_libvirt_connection)

    def tearDown(self):
        """Delete the fake project and user, then run the base teardown."""
        self.manager.delete_project(self.project)
        self.manager.delete_user(self.user)
        # The base class teardown was previously never invoked here.
        super(NWFilterTestCase, self).tearDown()

    def test_cidr_rule_nwfilter_xml(self):
        """A tcp 80-81 ingress rule from 0.0.0.0/0 becomes one accept rule."""
        security_group = self.setup_and_return_security_group()

        # Remove the group even when an assertion below fails.
        try:
            xml = self.fw.security_group_to_nwfilter_xml(security_group.id)

            dom = xml_to_dom(xml)
            self.assertEqual(dom.firstChild.tagName, 'filter')

            rules = dom.getElementsByTagName('rule')
            self.assertEqual(len(rules), 1)

            # It's supposed to allow inbound traffic.
            self.assertEqual(rules[0].getAttribute('action'), 'accept')
            self.assertEqual(rules[0].getAttribute('direction'), 'in')

            # Must be lower priority than the base filter (which blocks
            # everything)
            self.assertTrue(int(rules[0].getAttribute('priority')) < 1000)

            ip_conditions = rules[0].getElementsByTagName('tcp')
            self.assertEqual(len(ip_conditions), 1)
            self.assertEqual(ip_conditions[0].getAttribute('srcipaddr'),
                             '0.0.0.0')
            self.assertEqual(ip_conditions[0].getAttribute('srcipmask'),
                             '0.0.0.0')
            self.assertEqual(ip_conditions[0].getAttribute('dstportstart'),
                             '80')
            self.assertEqual(ip_conditions[0].getAttribute('dstportend'),
                             '81')
        finally:
            self.teardown_security_group()

    def teardown_security_group(self):
        """Delete the 'testgroup' security group via the cloud controller."""
        cloud_controller = cloud.CloudController()
        cloud_controller.delete_security_group(self.context, 'testgroup')

    def setup_and_return_security_group(self):
        """Create 'testgroup' with a tcp 80-81 ingress rule and return it."""
        cloud_controller = cloud.CloudController()
        cloud_controller.create_security_group(self.context,
                                               'testgroup',
                                               'test group description')
        cloud_controller.authorize_security_group_ingress(self.context,
                                                          'testgroup',
                                                          from_port='80',
                                                          to_port='81',
                                                          ip_protocol='tcp',
                                                          cidr_ip='0.0.0.0/0')

        return db.security_group_get_by_name(self.context, 'fake',
                                             'testgroup')

    def test_creates_base_rule_first(self):
        """Filters must be defined before any filter that references them.

        nwfilterDefineXML is replaced by a mock that records each defined
        filter and asserts that every filterref it contains names a filter
        that is already defined.
        """
        # These come pre-defined by libvirt
        self.defined_filters = ['no-mac-spoofing',
                                'no-ip-spoofing',
                                'no-arp-spoofing',
                                'allow-dhcp-server']

        # filter name -> every filter it references, directly or indirectly
        self.recursive_depends = {}
        for f in self.defined_filters:
            self.recursive_depends[f] = []

        def _filterDefineXMLMock(xml):
            """Record the filter and verify its references already exist."""
            dom = xml_to_dom(xml)
            name = dom.firstChild.getAttribute('name')
            self.recursive_depends[name] = []
            for f in dom.getElementsByTagName('filterref'):
                ref = f.getAttribute('filter')
                self.assertTrue(ref in self.defined_filters,
                                ('%s referenced filter that does '
                                 'not yet exist: %s') % (name, ref))
                dependencies = [ref] + self.recursive_depends[ref]
                self.recursive_depends[name] += dependencies

            self.defined_filters.append(name)
            return True

        self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock

        instance_ref = db.instance_create(self.context,
                                          {'user_id': 'fake',
                                           'project_id': 'fake'})
        inst_id = instance_ref['id']

        def _ensure_all_called():
            """Assert the instance filter depends on all required filters."""
            instance_filter = 'nova-instance-%s' % instance_ref['name']
            secgroup_filter = 'nova-secgroup-%s' % self.security_group['id']
            for required in [secgroup_filter, 'allow-dhcp-server',
                             'no-arp-spoofing', 'no-ip-spoofing',
                             'no-mac-spoofing']:
                self.assertTrue(
                    required in self.recursive_depends[instance_filter],
                    "Instance's filter does not include %s" % required)

        self.security_group = self.setup_and_return_security_group()

        # Remove the group even when a step or assertion below fails.
        try:
            db.instance_add_security_group(self.context, inst_id,
                                           self.security_group.id)
            instance = db.instance_get(self.context, inst_id)

            self.fw.setup_basic_filtering(instance)
            self.fw.prepare_instance_filter(instance)
            self.fw.apply_instance_filter(instance)
            _ensure_all_called()
        finally:
            self.teardown_security_group()