Added more unit-test for multi-nic-nova libvirt

This commit is contained in:
Eldar Nugaev 2011-04-28 16:58:03 +00:00 committed by Tarmac
commit d0cfae6c59
2 changed files with 209 additions and 100 deletions

View File

@ -31,9 +31,7 @@ from nova import test
from nova import utils from nova import utils
from nova.api.ec2 import cloud from nova.api.ec2 import cloud
from nova.auth import manager from nova.auth import manager
from nova.compute import manager as compute_manager
from nova.compute import power_state from nova.compute import power_state
from nova.db.sqlalchemy import models
from nova.virt import libvirt_conn from nova.virt import libvirt_conn
libvirt = None libvirt = None
@ -46,6 +44,22 @@ def _concurrency(wait, done, target):
done.send() done.send()
def _create_network_info(count=1):
fake = 'fake'
fake_ip = '0.0.0.0/0'
fake_ip_2 = '0.0.0.1/0'
fake_ip_3 = '0.0.0.1/0'
network = {'gateway': fake,
'gateway_v6': fake,
'bridge': fake,
'cidr': fake_ip,
'cidr_v6': fake_ip}
mapping = {'mac': fake,
'ips': [{'ip': fake_ip}, {'ip': fake_ip}],
'ip6s': [{'ip': fake_ip}, {'ip': fake_ip_2}, {'ip': fake_ip_3}]}
return [(network, mapping) for x in xrange(0, count)]
class CacheConcurrencyTestCase(test.TestCase): class CacheConcurrencyTestCase(test.TestCase):
def setUp(self): def setUp(self):
super(CacheConcurrencyTestCase, self).setUp() super(CacheConcurrencyTestCase, self).setUp()
@ -194,6 +208,37 @@ class LibvirtConnTestCase(test.TestCase):
return db.service_create(context.get_admin_context(), service_ref) return db.service_create(context.get_admin_context(), service_ref)
def test_preparing_xml_info(self):
conn = libvirt_conn.LibvirtConnection(True)
instance_ref = db.instance_create(self.context, self.test_instance)
result = conn._prepare_xml_info(instance_ref, False)
self.assertFalse(result['nics'])
result = conn._prepare_xml_info(instance_ref, False,
_create_network_info())
self.assertTrue(len(result['nics']) == 1)
result = conn._prepare_xml_info(instance_ref, False,
_create_network_info(2))
self.assertTrue(len(result['nics']) == 2)
def test_get_nic_for_xml_v4(self):
conn = libvirt_conn.LibvirtConnection(True)
network, mapping = _create_network_info()[0]
self.flags(use_ipv6=False)
params = conn._get_nic_for_xml(network, mapping)['extra_params']
self.assertTrue(params.find('PROJNETV6') == -1)
self.assertTrue(params.find('PROJMASKV6') == -1)
def test_get_nic_for_xml_v6(self):
conn = libvirt_conn.LibvirtConnection(True)
network, mapping = _create_network_info()[0]
self.flags(use_ipv6=True)
params = conn._get_nic_for_xml(network, mapping)['extra_params']
self.assertTrue(params.find('PROJNETV6') > -1)
self.assertTrue(params.find('PROJMASKV6') > -1)
def test_xml_and_uri_no_ramdisk_no_kernel(self): def test_xml_and_uri_no_ramdisk_no_kernel(self):
instance_data = dict(self.test_instance) instance_data = dict(self.test_instance)
self._check_xml_and_uri(instance_data, self._check_xml_and_uri(instance_data,
@ -229,6 +274,22 @@ class LibvirtConnTestCase(test.TestCase):
instance_data = dict(self.test_instance) instance_data = dict(self.test_instance)
self._check_xml_and_container(instance_data) self._check_xml_and_container(instance_data)
def test_multi_nic(self):
instance_data = dict(self.test_instance)
network_info = _create_network_info(2)
conn = libvirt_conn.LibvirtConnection(True)
instance_ref = db.instance_create(self.context, instance_data)
xml = conn.to_xml(instance_ref, False, network_info)
tree = xml_to_tree(xml)
interfaces = tree.findall("./devices/interface")
self.assertEquals(len(interfaces), 2)
parameters = interfaces[0].findall('./filterref/parameter')
self.assertEquals(interfaces[0].get('type'), 'bridge')
self.assertEquals(parameters[0].get('name'), 'IP')
self.assertEquals(parameters[0].get('value'), '0.0.0.0/0')
self.assertEquals(parameters[1].get('name'), 'DHCPSERVER')
self.assertEquals(parameters[1].get('value'), 'fake')
def _check_xml_and_container(self, instance): def _check_xml_and_container(self, instance):
user_context = context.RequestContext(project=self.project, user_context = context.RequestContext(project=self.project,
user=self.user) user=self.user)
@ -327,19 +388,13 @@ class LibvirtConnTestCase(test.TestCase):
check = (lambda t: t.find('./os/initrd'), None) check = (lambda t: t.find('./os/initrd'), None)
check_list.append(check) check_list.append(check)
parameter = './devices/interface/filterref/parameter'
common_checks = [ common_checks = [
(lambda t: t.find('.').tag, 'domain'), (lambda t: t.find('.').tag, 'domain'),
(lambda t: t.find( (lambda t: t.find(parameter).get('name'), 'IP'),
'./devices/interface/filterref/parameter').get('name'), 'IP'), (lambda t: t.find(parameter).get('value'), '10.11.12.13'),
(lambda t: t.find( (lambda t: t.findall(parameter)[1].get('name'), 'DHCPSERVER'),
'./devices/interface/filterref/parameter').get( (lambda t: t.findall(parameter)[1].get('value'), '10.0.0.1'),
'value'), '10.11.12.13'),
(lambda t: t.findall(
'./devices/interface/filterref/parameter')[1].get(
'name'), 'DHCPSERVER'),
(lambda t: t.findall(
'./devices/interface/filterref/parameter')[1].get(
'value'), '10.0.0.1'),
(lambda t: t.find('./devices/serial/source').get( (lambda t: t.find('./devices/serial/source').get(
'path').split('/')[1], 'console.log'), 'path').split('/')[1], 'console.log'),
(lambda t: t.find('./memory').text, '2097152')] (lambda t: t.find('./memory').text, '2097152')]
@ -651,12 +706,15 @@ class IptablesFirewallTestCase(test.TestCase):
'# Completed on Tue Jan 18 23:47:56 2011', '# Completed on Tue Jan 18 23:47:56 2011',
] ]
def test_static_filters(self): def _create_instance_ref(self):
instance_ref = db.instance_create(self.context, return db.instance_create(self.context,
{'user_id': 'fake', {'user_id': 'fake',
'project_id': 'fake', 'project_id': 'fake',
'mac_address': '56:12:12:12:12:12', 'mac_address': '56:12:12:12:12:12',
'instance_type_id': 1}) 'instance_type_id': 1})
def test_static_filters(self):
instance_ref = self._create_instance_ref()
ip = '10.11.12.13' ip = '10.11.12.13'
network_ref = db.project_get_network(self.context, network_ref = db.project_get_network(self.context,
@ -767,6 +825,32 @@ class IptablesFirewallTestCase(test.TestCase):
"TCP port 80/81 acceptance rule wasn't added") "TCP port 80/81 acceptance rule wasn't added")
db.instance_destroy(admin_ctxt, instance_ref['id']) db.instance_destroy(admin_ctxt, instance_ref['id'])
def test_filters_for_instance(self):
network_info = _create_network_info()
rulesv4, rulesv6 = self.fw._filters_for_instance("fake", network_info)
self.assertEquals(len(rulesv4), 2)
self.assertEquals(len(rulesv6), 3)
def multinic_iptables_test(self):
ipv4_rules_per_network = 2
ipv6_rules_per_network = 3
networks_count = 5
instance_ref = self._create_instance_ref()
network_info = _create_network_info(networks_count)
ipv4_len = len(self.fw.iptables.ipv4['filter'].rules)
ipv6_len = len(self.fw.iptables.ipv6['filter'].rules)
inst_ipv4, inst_ipv6 = self.fw.instance_rules(instance_ref,
network_info)
self.fw.add_filters_for_instance(instance_ref, network_info)
ipv4 = self.fw.iptables.ipv4['filter'].rules
ipv6 = self.fw.iptables.ipv6['filter'].rules
ipv4_network_rules = len(ipv4) - len(inst_ipv4) - ipv4_len
ipv6_network_rules = len(ipv6) - len(inst_ipv6) - ipv6_len
self.assertEquals(ipv4_network_rules,
ipv4_rules_per_network * networks_count)
self.assertEquals(ipv6_network_rules,
ipv6_rules_per_network * networks_count)
class NWFilterTestCase(test.TestCase): class NWFilterTestCase(test.TestCase):
def setUp(self): def setUp(self):
@ -848,6 +932,28 @@ class NWFilterTestCase(test.TestCase):
return db.security_group_get_by_name(self.context, 'fake', 'testgroup') return db.security_group_get_by_name(self.context, 'fake', 'testgroup')
def _create_instance(self):
return db.instance_create(self.context,
{'user_id': 'fake',
'project_id': 'fake',
'mac_address': '00:A0:C9:14:C8:29',
'instance_type_id': 1})
def _create_instance_type(self, params={}):
"""Create a test instance"""
context = self.context.elevated()
inst = {}
inst['name'] = 'm1.small'
inst['memory_mb'] = '1024'
inst['vcpus'] = '1'
inst['local_gb'] = '20'
inst['flavorid'] = '1'
inst['swap'] = '2048'
inst['rxtx_quota'] = 100
inst['rxtx_cap'] = 200
inst.update(params)
return db.instance_type_create(context, inst)['id']
def test_creates_base_rule_first(self): def test_creates_base_rule_first(self):
# These come pre-defined by libvirt # These come pre-defined by libvirt
self.defined_filters = ['no-mac-spoofing', self.defined_filters = ['no-mac-spoofing',
@ -876,25 +982,18 @@ class NWFilterTestCase(test.TestCase):
self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock
instance_ref = db.instance_create(self.context, instance_ref = self._create_instance()
{'user_id': 'fake',
'project_id': 'fake',
'mac_address': '00:A0:C9:14:C8:29',
'instance_type_id': 1})
inst_id = instance_ref['id'] inst_id = instance_ref['id']
ip = '10.11.12.13' ip = '10.11.12.13'
network_ref = db.project_get_network(self.context, network_ref = db.project_get_network(self.context, 'fake')
'fake') fixed_ip = {'address': ip, 'network_id': network_ref['id']}
fixed_ip = {'address': ip,
'network_id': network_ref['id']}
admin_ctxt = context.get_admin_context() admin_ctxt = context.get_admin_context()
db.fixed_ip_create(admin_ctxt, fixed_ip) db.fixed_ip_create(admin_ctxt, fixed_ip)
db.fixed_ip_update(admin_ctxt, ip, {'allocated': True, db.fixed_ip_update(admin_ctxt, ip, {'allocated': True,
'instance_id': instance_ref['id']}) 'instance_id': inst_id})
def _ensure_all_called(): def _ensure_all_called():
instance_filter = 'nova-instance-%s-%s' % (instance_ref['name'], instance_filter = 'nova-instance-%s-%s' % (instance_ref['name'],
@ -920,3 +1019,11 @@ class NWFilterTestCase(test.TestCase):
_ensure_all_called() _ensure_all_called()
self.teardown_security_group() self.teardown_security_group()
db.instance_destroy(admin_ctxt, instance_ref['id']) db.instance_destroy(admin_ctxt, instance_ref['id'])
def test_create_network_filters(self):
instance_ref = self._create_instance()
network_info = _create_network_info(3)
result = self.fw._create_network_filters(instance_ref,
network_info,
"fake")
self.assertEquals(len(result), 3)

View File

@ -960,26 +960,16 @@ class LibvirtConnection(driver.ComputeDriver):
mac_id = mapping['mac'].replace(':', '') mac_id = mapping['mac'].replace(':', '')
if FLAGS.allow_project_net_traffic: if FLAGS.allow_project_net_traffic:
if FLAGS.use_ipv6: template = "<parameter name=\"%s\"value=\"%s\" />\n"
net, mask = _get_net_and_mask(network['cidr']) net, mask = _get_net_and_mask(network['cidr'])
values = [("PROJNET", net), ("PROJMASK", mask)]
if FLAGS.use_ipv6:
net_v6, prefixlen_v6 = _get_net_and_prefixlen( net_v6, prefixlen_v6 = _get_net_and_prefixlen(
network['cidr_v6']) network['cidr_v6'])
extra_params = ("<parameter name=\"PROJNET\" " values.extend([("PROJNETV6", net_v6),
"value=\"%s\" />\n" ("PROJMASKV6", prefixlen_v6)])
"<parameter name=\"PROJMASK\" "
"value=\"%s\" />\n" extra_params = "".join([template % value for value in values])
"<parameter name=\"PROJNETV6\" "
"value=\"%s\" />\n"
"<parameter name=\"PROJMASKV6\" "
"value=\"%s\" />\n") % \
(net, mask, net_v6, prefixlen_v6)
else:
net, mask = _get_net_and_mask(network['cidr'])
extra_params = ("<parameter name=\"PROJNET\" "
"value=\"%s\" />\n"
"<parameter name=\"PROJMASK\" "
"value=\"%s\" />\n") % \
(net, mask)
else: else:
extra_params = "\n" extra_params = "\n"
@ -997,10 +987,7 @@ class LibvirtConnection(driver.ComputeDriver):
return result return result
def to_xml(self, instance, rescue=False, network_info=None): def _prepare_xml_info(self, instance, rescue=False, network_info=None):
# TODO(termie): cache?
LOG.debug(_('instance %s: starting toXML method'), instance['name'])
# TODO(adiantum) remove network_info creation code # TODO(adiantum) remove network_info creation code
# when multinics will be completed # when multinics will be completed
if not network_info: if not network_info:
@ -1008,8 +995,7 @@ class LibvirtConnection(driver.ComputeDriver):
nics = [] nics = []
for (network, mapping) in network_info: for (network, mapping) in network_info:
nics.append(self._get_nic_for_xml(network, nics.append(self._get_nic_for_xml(network, mapping))
mapping))
# FIXME(vish): stick this in db # FIXME(vish): stick this in db
inst_type_id = instance['instance_type_id'] inst_type_id = instance['instance_type_id']
inst_type = instance_types.get_instance_type(inst_type_id) inst_type = instance_types.get_instance_type(inst_type_id)
@ -1041,10 +1027,14 @@ class LibvirtConnection(driver.ComputeDriver):
xml_info['ramdisk'] = xml_info['basepath'] + "/ramdisk" xml_info['ramdisk'] = xml_info['basepath'] + "/ramdisk"
xml_info['disk'] = xml_info['basepath'] + "/disk" xml_info['disk'] = xml_info['basepath'] + "/disk"
return xml_info
def to_xml(self, instance, rescue=False, network_info=None):
# TODO(termie): cache?
LOG.debug(_('instance %s: starting toXML method'), instance['name'])
xml_info = self._prepare_xml_info(instance, rescue, network_info)
xml = str(Template(self.libvirt_xml, searchList=[xml_info])) xml = str(Template(self.libvirt_xml, searchList=[xml_info]))
LOG.debug(_('instance %s: finished toXML method'), LOG.debug(_('instance %s: finished toXML method'), instance['name'])
instance['name'])
return xml return xml
def _lookup_by_name(self, instance_name): def _lookup_by_name(self, instance_name):
@ -1846,10 +1836,6 @@ class NWFilterFirewall(FirewallDriver):
""" """
if not network_info: if not network_info:
network_info = _get_network_info(instance) network_info = _get_network_info(instance)
if instance['image_id'] == str(FLAGS.vpn_image_id):
base_filter = 'nova-vpn'
else:
base_filter = 'nova-base'
ctxt = context.get_admin_context() ctxt = context.get_admin_context()
@ -1861,41 +1847,59 @@ class NWFilterFirewall(FirewallDriver):
'nova-base-ipv6', 'nova-base-ipv6',
'nova-allow-dhcp-server'] 'nova-allow-dhcp-server']
if FLAGS.use_ipv6:
networks = [network for (network, _m) in network_info if
network['gateway_v6']]
if networks:
instance_secgroup_filter_children.\
append('nova-allow-ra-server')
for security_group in \ for security_group in \
db.security_group_get_by_instance(ctxt, instance['id']): db.security_group_get_by_instance(ctxt, instance['id']):
self.refresh_security_group_rules(security_group['id']) self.refresh_security_group_rules(security_group['id'])
instance_secgroup_filter_children += [('nova-secgroup-%s' % instance_secgroup_filter_children.append('nova-secgroup-%s' %
security_group['id'])] security_group['id'])
self._define_filter( self._define_filter(
self._filter_container(instance_secgroup_filter_name, self._filter_container(instance_secgroup_filter_name,
instance_secgroup_filter_children)) instance_secgroup_filter_children))
for (network, mapping) in network_info: network_filters = self.\
_create_network_filters(instance, network_info,
instance_secgroup_filter_name)
for (name, children) in network_filters:
self._define_filters(name, children)
def _create_network_filters(self, instance, network_info,
instance_secgroup_filter_name):
if instance['image_id'] == str(FLAGS.vpn_image_id):
base_filter = 'nova-vpn'
else:
base_filter = 'nova-base'
result = []
for (_n, mapping) in network_info:
nic_id = mapping['mac'].replace(':', '') nic_id = mapping['mac'].replace(':', '')
instance_filter_name = self._instance_filter_name(instance, nic_id) instance_filter_name = self._instance_filter_name(instance, nic_id)
instance_filter_children = \ instance_filter_children = [base_filter,
[base_filter, instance_secgroup_filter_name] instance_secgroup_filter_name]
if FLAGS.use_ipv6:
gateway_v6 = network['gateway_v6']
if gateway_v6:
instance_secgroup_filter_children += \
['nova-allow-ra-server']
if FLAGS.allow_project_net_traffic: if FLAGS.allow_project_net_traffic:
instance_filter_children += ['nova-project'] instance_filter_children.append('nova-project')
if FLAGS.use_ipv6: if FLAGS.use_ipv6:
instance_filter_children += ['nova-project-v6'] instance_filter_children.append('nova-project-v6')
self._define_filter( result.append((instance_filter_name, instance_filter_children))
self._filter_container(instance_filter_name,
instance_filter_children))
return return result
def _define_filters(self, filter_name, filter_children):
self._define_filter(self._filter_container(filter_name,
filter_children))
def refresh_security_group_rules(self, security_group_id): def refresh_security_group_rules(self, security_group_id):
return self._define_filter( return self._define_filter(
@ -1997,34 +2001,21 @@ class IptablesFirewallDriver(FirewallDriver):
self.add_filters_for_instance(instance, network_info) self.add_filters_for_instance(instance, network_info)
self.iptables.apply() self.iptables.apply()
def add_filters_for_instance(self, instance, network_info=None): def _create_filter(self, ips, chain_name):
if not network_info: return ['-d %s -j $%s' % (ip, chain_name) for ip in ips]
network_info = _get_network_info(instance)
chain_name = self._instance_chain_name(instance)
self.iptables.ipv4['filter'].add_chain(chain_name) def _filters_for_instance(self, chain_name, network_info):
ips_v4 = [ip['ip'] for (_n, mapping) in network_info
ips_v4 = [ip['ip'] for (_, mapping) in network_info
for ip in mapping['ips']] for ip in mapping['ips']]
ipv4_rules = self._create_filter(ips_v4, chain_name)
for ipv4_address in ips_v4: ips_v6 = [ip['ip'] for (_n, mapping) in network_info
self.iptables.ipv4['filter'].add_rule('local',
'-d %s -j $%s' %
(ipv4_address, chain_name))
if FLAGS.use_ipv6:
self.iptables.ipv6['filter'].add_chain(chain_name)
ips_v6 = [ip['ip'] for (_, mapping) in network_info
for ip in mapping['ip6s']] for ip in mapping['ip6s']]
for ipv6_address in ips_v6: ipv6_rules = self._create_filter(ips_v6, chain_name)
self.iptables.ipv6['filter'].add_rule('local', return ipv4_rules, ipv6_rules
'-d %s -j $%s' %
(ipv6_address,
chain_name))
ipv4_rules, ipv6_rules = self.instance_rules(instance, network_info)
def _add_filters(self, chain_name, ipv4_rules, ipv6_rules):
for rule in ipv4_rules: for rule in ipv4_rules:
self.iptables.ipv4['filter'].add_rule(chain_name, rule) self.iptables.ipv4['filter'].add_rule(chain_name, rule)
@ -2032,6 +2023,17 @@ class IptablesFirewallDriver(FirewallDriver):
for rule in ipv6_rules: for rule in ipv6_rules:
self.iptables.ipv6['filter'].add_rule(chain_name, rule) self.iptables.ipv6['filter'].add_rule(chain_name, rule)
def add_filters_for_instance(self, instance, network_info=None):
chain_name = self._instance_chain_name(instance)
if FLAGS.use_ipv6:
self.iptables.ipv6['filter'].add_chain(chain_name)
self.iptables.ipv4['filter'].add_chain(chain_name)
ipv4_rules, ipv6_rules = self._filters_for_instance(chain_name,
network_info)
self._add_filters('local', ipv4_rules, ipv6_rules)
ipv4_rules, ipv6_rules = self.instance_rules(instance, network_info)
self._add_filters(chain_name, ipv4_rules, ipv6_rules)
def remove_filters_for_instance(self, instance): def remove_filters_for_instance(self, instance):
chain_name = self._instance_chain_name(instance) chain_name = self._instance_chain_name(instance)