Add the layout for a couple of Windows OS utils, especially networking

This patch adds the routes implementation, general.check_os_version
and network.default_gateway, which are used by the first data source
layout.

Change-Id: If8ede3c41e834d62cfb2d341d88bc1fbaef947b6
This commit is contained in:
Claudiu Popa 2015-06-04 15:02:49 +03:00
parent f92c13ba45
commit e59fe5bb27
14 changed files with 929 additions and 1 deletions

8
cloudinit/exceptions.py Normal file
View File

@ -0,0 +1,8 @@
# Copyright 2015 Canonical Ltd.
# This file is part of cloud-init. See LICENCE file for license information.
#
# vi: ts=4 expandtab
class CloudInitError(Exception):
pass

View File

@ -74,6 +74,10 @@ class Route(object):
self.use = use
self.expire = expire
def __repr__(self):
return ("Route(destination={!r}, gateway={!r}, netmask={!r})"
.format(self.destination, self.gateway, self.netmask))
@abc.abstractproperty
def is_static(self):
"""Check if this route is static."""

View File

@ -51,7 +51,7 @@ class User(object):
"""Base class for an user."""
@classmethod
def create(self, username, password, **kwargs):
def create(cls, username, password, **kwargs):
"""Create a new user."""
@abc.abstractmethod

View File

View File

@ -0,0 +1,26 @@
# Copyright 2015 Canonical Ltd.
# This file is part of cloud-init. See LICENCE file for license information.
#
# vi: ts=4 expandtab
from cloudinit.osys import base
from cloudinit.osys.windows import general as general_module
from cloudinit.osys.windows import network as network_module
__all__ = ('OSUtils', )
class OSUtils(base.OSUtils):
"""The OS utils namespace for the Windows platform."""
name = "windows"
network = network_module.Network()
general = general_module.General()
route_class = network_module.Route
# These aren't yet implemented, use `None` for them
# so that we could instantiate the class.
filesystem = user_class = users = None
interface_class = None

View File

@ -0,0 +1,59 @@
# Copyright 2015 Canonical Ltd.
# This file is part of cloud-init. See LICENCE file for license information.
#
# vi: ts=4 expandtab
"""General utilities for Windows platform."""
import ctypes
from cloudinit import exceptions
from cloudinit.osys import general
from cloudinit.osys.windows.util import kernel32
class General(general.General):
"""General utilities namespace for Windows."""
@staticmethod
def check_os_version(major, minor, build=0):
"""Check if this OS version is equal or higher than (major, minor)"""
version_info = kernel32.Win32_OSVERSIONINFOEX_W()
version_info.dwOSVersionInfoSize = ctypes.sizeof(
kernel32.Win32_OSVERSIONINFOEX_W)
version_info.dwMajorVersion = major
version_info.dwMinorVersion = minor
version_info.dwBuildNumber = build
mask = 0
for type_mask in [kernel32.VER_MAJORVERSION,
kernel32.VER_MINORVERSION,
kernel32.VER_BUILDNUMBER]:
mask = kernel32.VerSetConditionMask(mask, type_mask,
kernel32.VER_GREATER_EQUAL)
type_mask = (kernel32.VER_MAJORVERSION |
kernel32.VER_MINORVERSION |
kernel32.VER_BUILDNUMBER)
ret_val = kernel32.VerifyVersionInfoW(ctypes.byref(version_info),
type_mask, mask)
if ret_val:
return True
else:
err = kernel32.GetLastError()
if err == kernel32.ERROR_OLD_WIN_VERSION:
return False
else:
raise exceptions.CloudInitError(
"VerifyVersionInfo failed with error: %s" % err)
def reboot(self):
raise NotImplementedError
def set_locale(self):
raise NotImplementedError
def set_timezone(self):
raise NotImplementedError

View File

@ -0,0 +1,159 @@
# Copyright 2015 Canonical Ltd.
# This file is part of cloud-init. See LICENCE file for license information.
#
# vi: ts=4 expandtab
"""Network utilities for Windows."""
import contextlib
import ctypes
from ctypes import wintypes
import logging
import subprocess
from cloudinit import exceptions
from cloudinit.osys import network
from cloudinit.osys.windows.util import iphlpapi
from cloudinit.osys.windows.util import kernel32
from cloudinit.osys.windows.util import ws2_32
MIB_IPPROTO_NETMGMT = 3
_FW_IP_PROTOCOL_TCP = 6
_FW_IP_PROTOCOL_UDP = 17
_FW_SCOPE_ALL = 0
_PROTOCOL_TCP = "TCP"
_PROTOCOL_UDP = "UDP"
_ERROR_FILE_NOT_FOUND = 2
_ComputerNamePhysicalDnsHostname = 5
LOG = logging.getLogger(__file__)
def _heap_alloc(heap, size):
table_mem = kernel32.HeapAlloc(heap, 0, ctypes.c_size_t(size.value))
if not table_mem:
raise exceptions.CloudInitError(
'Unable to allocate memory for the IP forward table')
return table_mem
class Network(network.Network):
"""Network namespace object tailored for the Windows platform."""
@staticmethod
@contextlib.contextmanager
def _get_forward_table():
heap = kernel32.GetProcessHeap()
forward_table_size = ctypes.sizeof(iphlpapi.Win32_MIB_IPFORWARDTABLE)
size = wintypes.ULONG(forward_table_size)
table_mem = _heap_alloc(heap, size)
p_forward_table = ctypes.cast(
table_mem, ctypes.POINTER(iphlpapi.Win32_MIB_IPFORWARDTABLE))
try:
err = iphlpapi.GetIpForwardTable(p_forward_table,
ctypes.byref(size), 0)
if err == iphlpapi.ERROR_INSUFFICIENT_BUFFER:
kernel32.HeapFree(heap, 0, p_forward_table)
table_mem = _heap_alloc(heap, size)
p_forward_table = ctypes.cast(
table_mem,
ctypes.POINTER(iphlpapi.Win32_MIB_IPFORWARDTABLE))
err = iphlpapi.GetIpForwardTable(p_forward_table,
ctypes.byref(size), 0)
if err and err != kernel32.ERROR_NO_DATA:
raise exceptions.CloudInitError(
'Unable to get IP forward table. Error: %s' % err)
yield p_forward_table
finally:
kernel32.HeapFree(heap, 0, p_forward_table)
def routes(self):
"""Get a collection of the available routes."""
routing_table = []
with self._get_forward_table() as p_forward_table:
forward_table = p_forward_table.contents
table = ctypes.cast(
ctypes.addressof(forward_table.table),
ctypes.POINTER(iphlpapi.Win32_MIB_IPFORWARDROW *
forward_table.dwNumEntries)).contents
for row in table:
destination = ws2_32.Ws2_32.inet_ntoa(
row.dwForwardDest).decode()
netmask = ws2_32.Ws2_32.inet_ntoa(
row.dwForwardMask).decode()
gateway = ws2_32.Ws2_32.inet_ntoa(
row.dwForwardNextHop).decode()
index = row.dwForwardIfIndex
flags = row.dwForwardProto
metric = row.dwForwardMetric1
route = Route(destination=destination,
gateway=gateway,
netmask=netmask,
interface=index,
metric=metric,
flags=flags)
routing_table.append(route)
return routing_table
def default_gateway(self):
"""Get the default gateway.
This will actually return a :class:`Route` instance. The gateway
can be accessed with the :attr:`gateway` attribute.
"""
return next((r for r in self.routes() if r.destination == '0.0.0.0'),
None)
# These are not required by the Windows version for now,
# but we provide them as noop version.
def hosts(self):
"""Grab the content of the hosts file."""
raise NotImplementedError
def interfaces(self):
raise NotImplementedError
def set_hostname(self, hostname):
raise NotImplementedError
def set_static_network_config(self, adapter_name, address, netmask,
broadcast, gateway, dnsnameservers):
raise NotImplementedError
class Route(network.Route):
"""Windows route class."""
@property
def is_static(self):
return self.flags == MIB_IPPROTO_NETMGMT
@classmethod
def add(cls, route):
"""Add a new route in the underlying OS.
The function should expect an instance of :class:`Route`.
"""
args = ['ROUTE', 'ADD',
route.destination,
'MASK', route.netmask, route.gateway]
popen = subprocess.Popen(args, shell=False,
stderr=subprocess.PIPE)
_, stderr = popen.communicate()
if popen.returncode or stderr:
# Cannot use the return value to determine the outcome
raise exceptions.CloudInitError('Unable to add route: %s' % stderr)
@classmethod
def delete(cls, _):
"""Delete a route from the underlying OS.
This function should expect an instance of :class:`Route`.
"""
raise NotImplementedError

View File

View File

@ -0,0 +1,210 @@
# Copyright 2015 Canonical Ltd.
# This file is part of cloud-init. See LICENCE file for license information.
#
# vi: ts=4 expandtab
import ctypes
from ctypes import windll
from ctypes import wintypes
from cloudinit.osys.windows.util import kernel32
from cloudinit.osys.windows.util import ws2_32
ERROR_INSUFFICIENT_BUFFER = 122
MAX_ADAPTER_NAME_LENGTH = 256
MAX_ADAPTER_DESCRIPTION_LENGTH = 128
MAX_ADAPTER_ADDRESS_LENGTH = 8
# Do not return IPv6 anycast addresses.
GAA_FLAG_SKIP_ANYCAST = 2
GAA_FLAG_SKIP_MULTICAST = 4
IP_ADAPTER_DHCP_ENABLED = 4
IP_ADAPTER_IPV4_ENABLED = 0x80
IP_ADAPTER_IPV6_ENABLED = 0x0100
MAX_DHCPV6_DUID_LENGTH = 130
IF_TYPE_ETHERNET_CSMACD = 6
IF_TYPE_SOFTWARE_LOOPBACK = 24
IF_TYPE_IEEE80211 = 71
IF_TYPE_TUNNEL = 131
IP_ADAPTER_ADDRESSES_SIZE_2003 = 144
class SOCKET_ADDRESS(ctypes.Structure):
_fields_ = [
('lpSockaddr', ctypes.POINTER(ws2_32.SOCKADDR)),
('iSockaddrLength', wintypes.INT),
]
class IP_ADAPTER_ADDRESSES_Struct1(ctypes.Structure):
_fields_ = [
('Length', wintypes.ULONG),
('IfIndex', wintypes.DWORD),
]
class IP_ADAPTER_ADDRESSES_Union1(ctypes.Union):
_fields_ = [
('Alignment', wintypes.ULARGE_INTEGER),
('Struct1', IP_ADAPTER_ADDRESSES_Struct1),
]
class IP_ADAPTER_UNICAST_ADDRESS(ctypes.Structure):
_fields_ = [
('Union1', IP_ADAPTER_ADDRESSES_Union1),
('Next', wintypes.LPVOID),
('Address', SOCKET_ADDRESS),
('PrefixOrigin', wintypes.DWORD),
('SuffixOrigin', wintypes.DWORD),
('DadState', wintypes.DWORD),
('ValidLifetime', wintypes.ULONG),
('PreferredLifetime', wintypes.ULONG),
('LeaseLifetime', wintypes.ULONG),
]
class IP_ADAPTER_DNS_SERVER_ADDRESS_Struct1(ctypes.Structure):
_fields_ = [
('Length', wintypes.ULONG),
('Reserved', wintypes.DWORD),
]
class IP_ADAPTER_DNS_SERVER_ADDRESS_Union1(ctypes.Union):
_fields_ = [
('Alignment', wintypes.ULARGE_INTEGER),
('Struct1', IP_ADAPTER_DNS_SERVER_ADDRESS_Struct1),
]
class IP_ADAPTER_DNS_SERVER_ADDRESS(ctypes.Structure):
_fields_ = [
('Union1', IP_ADAPTER_DNS_SERVER_ADDRESS_Union1),
('Next', wintypes.LPVOID),
('Address', SOCKET_ADDRESS),
]
class IP_ADAPTER_PREFIX_Struct1(ctypes.Structure):
_fields_ = [
('Length', wintypes.ULONG),
('Flags', wintypes.DWORD),
]
class IP_ADAPTER_PREFIX_Union1(ctypes.Union):
_fields_ = [
('Alignment', wintypes.ULARGE_INTEGER),
('Struct1', IP_ADAPTER_PREFIX_Struct1),
]
class IP_ADAPTER_PREFIX(ctypes.Structure):
_fields_ = [
('Union1', IP_ADAPTER_PREFIX_Union1),
('Next', wintypes.LPVOID),
('Address', SOCKET_ADDRESS),
('PrefixLength', wintypes.ULONG),
]
class NET_LUID_LH(ctypes.Union):
_fields_ = [
('Value', wintypes.ULARGE_INTEGER),
('Info', wintypes.ULARGE_INTEGER),
]
class IP_ADAPTER_ADDRESSES(ctypes.Structure):
_fields_ = [
('Union1', IP_ADAPTER_ADDRESSES_Union1),
('Next', wintypes.LPVOID),
('AdapterName', ctypes.c_char_p),
('FirstUnicastAddress',
ctypes.POINTER(IP_ADAPTER_UNICAST_ADDRESS)),
('FirstAnycastAddress',
ctypes.POINTER(IP_ADAPTER_DNS_SERVER_ADDRESS)),
('FirstMulticastAddress',
ctypes.POINTER(IP_ADAPTER_DNS_SERVER_ADDRESS)),
('FirstDnsServerAddress',
ctypes.POINTER(IP_ADAPTER_DNS_SERVER_ADDRESS)),
('DnsSuffix', wintypes.LPWSTR),
('Description', wintypes.LPWSTR),
('FriendlyName', wintypes.LPWSTR),
('PhysicalAddress', ctypes.c_ubyte * MAX_ADAPTER_ADDRESS_LENGTH),
('PhysicalAddressLength', wintypes.DWORD),
('Flags', wintypes.DWORD),
('Mtu', wintypes.DWORD),
('IfType', wintypes.DWORD),
('OperStatus', wintypes.DWORD),
('Ipv6IfIndex', wintypes.DWORD),
('ZoneIndices', wintypes.DWORD * 16),
('FirstPrefix', ctypes.POINTER(IP_ADAPTER_PREFIX)),
# kernel >= 6.0
('TransmitLinkSpeed', wintypes.ULARGE_INTEGER),
('ReceiveLinkSpeed', wintypes.ULARGE_INTEGER),
('FirstWinsServerAddress',
ctypes.POINTER(IP_ADAPTER_DNS_SERVER_ADDRESS)),
('FirstGatewayAddress',
ctypes.POINTER(IP_ADAPTER_DNS_SERVER_ADDRESS)),
('Ipv4Metric', wintypes.ULONG),
('Ipv6Metric', wintypes.ULONG),
('Luid', NET_LUID_LH),
('Dhcpv4Server', SOCKET_ADDRESS),
('CompartmentId', wintypes.DWORD),
('NetworkGuid', kernel32.GUID),
('ConnectionType', wintypes.DWORD),
('TunnelType', wintypes.DWORD),
('Dhcpv6Server', SOCKET_ADDRESS),
('Dhcpv6ClientDuid', ctypes.c_ubyte * MAX_DHCPV6_DUID_LENGTH),
('Dhcpv6ClientDuidLength', wintypes.ULONG),
('Dhcpv6Iaid', wintypes.ULONG),
]
class Win32_MIB_IPFORWARDROW(ctypes.Structure):
_fields_ = [
('dwForwardDest', wintypes.DWORD),
('dwForwardMask', wintypes.DWORD),
('dwForwardPolicy', wintypes.DWORD),
('dwForwardNextHop', wintypes.DWORD),
('dwForwardIfIndex', wintypes.DWORD),
('dwForwardType', wintypes.DWORD),
('dwForwardProto', wintypes.DWORD),
('dwForwardAge', wintypes.DWORD),
('dwForwardNextHopAS', wintypes.DWORD),
('dwForwardMetric1', wintypes.DWORD),
('dwForwardMetric2', wintypes.DWORD),
('dwForwardMetric3', wintypes.DWORD),
('dwForwardMetric4', wintypes.DWORD),
('dwForwardMetric5', wintypes.DWORD)
]
class Win32_MIB_IPFORWARDTABLE(ctypes.Structure):
_fields_ = [
('dwNumEntries', wintypes.DWORD),
('table', Win32_MIB_IPFORWARDROW * 1)
]
GetAdaptersAddresses = windll.Iphlpapi.GetAdaptersAddresses
GetAdaptersAddresses.argtypes = [
wintypes.ULONG, wintypes.ULONG, wintypes.LPVOID,
ctypes.POINTER(IP_ADAPTER_ADDRESSES),
ctypes.POINTER(wintypes.ULONG)]
GetAdaptersAddresses.restype = wintypes.ULONG
GetIpForwardTable = windll.Iphlpapi.GetIpForwardTable
GetIpForwardTable.argtypes = [
ctypes.POINTER(Win32_MIB_IPFORWARDTABLE),
ctypes.POINTER(wintypes.ULONG),
wintypes.BOOL]
GetIpForwardTable.restype = wintypes.DWORD

View File

@ -0,0 +1,85 @@
# Copyright 2015 Canonical Ltd.
# This file is part of cloud-init. See LICENCE file for license information.
#
# vi: ts=4 expandtab
import ctypes
from ctypes import windll
from ctypes import wintypes
ERROR_BUFFER_OVERFLOW = 111
ERROR_NO_DATA = 232
class GUID(ctypes.Structure):
_fields_ = [
("data1", wintypes.DWORD),
("data2", wintypes.WORD),
("data3", wintypes.WORD),
("data4", wintypes.BYTE * 8)]
def __init__(self, l, w1, w2, b1, b2, b3, b4, b5, b6, b7, b8):
self.data1 = l
self.data2 = w1
self.data3 = w2
self.data4[0] = b1
self.data4[1] = b2
self.data4[2] = b3
self.data4[3] = b4
self.data4[4] = b5
self.data4[5] = b6
self.data4[6] = b7
self.data4[7] = b8
class Win32_OSVERSIONINFOEX_W(ctypes.Structure):
_fields_ = [
('dwOSVersionInfoSize', wintypes.DWORD),
('dwMajorVersion', wintypes.DWORD),
('dwMinorVersion', wintypes.DWORD),
('dwBuildNumber', wintypes.DWORD),
('dwPlatformId', wintypes.DWORD),
('szCSDVersion', wintypes.WCHAR * 128),
('wServicePackMajor', wintypes.DWORD),
('wServicePackMinor', wintypes.DWORD),
('wSuiteMask', wintypes.DWORD),
('wProductType', wintypes.BYTE),
('wReserved', wintypes.BYTE)
]
GetLastError = windll.kernel32.GetLastError
GetProcessHeap = windll.kernel32.GetProcessHeap
GetProcessHeap.argtypes = []
GetProcessHeap.restype = wintypes.HANDLE
HeapAlloc = windll.kernel32.HeapAlloc
# Note: wintypes.ULONG must be replaced with a 64 bit variable on x64
HeapAlloc.argtypes = [wintypes.HANDLE, wintypes.DWORD, wintypes.ULONG]
HeapAlloc.restype = wintypes.LPVOID
HeapFree = windll.kernel32.HeapFree
HeapFree.argtypes = [wintypes.HANDLE, wintypes.DWORD, wintypes.LPVOID]
HeapFree.restype = wintypes.BOOL
SetComputerNameExW = windll.kernel32.SetComputerNameExW
VerifyVersionInfoW = windll.kernel32.VerifyVersionInfoW
VerSetConditionMask = windll.kernel32.VerSetConditionMask
VerifyVersionInfoW.argtypes = [
ctypes.POINTER(Win32_OSVERSIONINFOEX_W),
wintypes.DWORD, wintypes.ULARGE_INTEGER]
VerifyVersionInfoW.restype = wintypes.BOOL
VerSetConditionMask.argtypes = [wintypes.ULARGE_INTEGER,
wintypes.DWORD,
wintypes.BYTE]
VerSetConditionMask.restype = wintypes.ULARGE_INTEGER
ERROR_OLD_WIN_VERSION = 1150
VER_MAJORVERSION = 1
VER_MINORVERSION = 2
VER_BUILDNUMBER = 4
VER_GREATER_EQUAL = 3

View File

@ -0,0 +1,54 @@
# Copyright 2015 Canonical Ltd.
# This file is part of cloud-init. See LICENCE file for license information.
#
# vi: ts=4 expandtab
import ctypes
from ctypes import windll
from ctypes import wintypes
AF_UNSPEC = 0
AF_INET = 2
AF_INET6 = 23
VERSION_2_2 = (2 << 8) + 2
class SOCKADDR(ctypes.Structure):
_fields_ = [
('sa_family', wintypes.USHORT),
('sa_data', ctypes.c_char * 14),
]
class WSADATA(ctypes.Structure):
_fields_ = [
('opaque_data', wintypes.BYTE * 400),
]
WSAGetLastError = windll.Ws2_32.WSAGetLastError
WSAGetLastError.argtypes = []
WSAGetLastError.restype = wintypes.INT
WSAStartup = windll.Ws2_32.WSAStartup
WSAStartup.argtypes = [wintypes.WORD, ctypes.POINTER(WSADATA)]
WSAStartup.restype = wintypes.INT
WSACleanup = windll.Ws2_32.WSACleanup
WSACleanup.argtypes = []
WSACleanup.restype = wintypes.INT
WSAAddressToStringW = windll.Ws2_32.WSAAddressToStringW
WSAAddressToStringW.argtypes = [
ctypes.POINTER(SOCKADDR), wintypes.DWORD, wintypes.LPVOID,
wintypes.LPWSTR, ctypes.POINTER(wintypes.DWORD)]
WSAAddressToStringW.restype = wintypes.INT
Ws2_32 = windll.Ws2_32
Ws2_32.inet_ntoa.restype = ctypes.c_char_p
def init_wsa(version=VERSION_2_2):
wsadata = WSADATA()
WSAStartup(version, ctypes.byref(wsadata))

View File

View File

@ -0,0 +1,70 @@
# Copyright 2015 Canonical Ltd.
# This file is part of cloud-init. See LICENCE file for license information.
#
# vi: ts=4 expandtab
import importlib
import unittest
from cloudinit import exceptions
from cloudinit.tests.util import mock
class TestWindowsGeneral(unittest.TestCase):
def setUp(self):
self._ctypes_mock = mock.Mock()
self._util_mock = mock.MagicMock()
self._module_patcher = mock.patch.dict(
'sys.modules',
{'ctypes': self._ctypes_mock,
'cloudinit.osys.windows.util': self._util_mock})
self._module_patcher.start()
self._general_module = importlib.import_module(
"cloudinit.osys.windows.general")
self._kernel32 = self._general_module.kernel32
self._general = self._general_module.General()
def tearDown(self):
self._module_patcher.stop()
def _test_check_os_version(self, ret_value, error_value=None):
verset_return = 2
self._kernel32.VerSetConditionMask.return_value = (
verset_return)
self._kernel32.VerifyVersionInfoW.return_value = ret_value
self._kernel32.GetLastError.return_value = error_value
old_version = self._kernel32.ERROR_OLD_WIN_VERSION
if error_value and error_value is not old_version:
self.assertRaises(exceptions.CloudInitError,
self._general.check_os_version, 3, 1, 2)
self._kernel32.GetLastError.assert_called_once_with()
else:
response = self._general.check_os_version(3, 1, 2)
self._ctypes_mock.sizeof.assert_called_once_with(
self._kernel32.Win32_OSVERSIONINFOEX_W)
self.assertEqual(
3, self._kernel32.VerSetConditionMask.call_count)
mask = (self._kernel32.VER_MAJORVERSION |
self._kernel32.VER_MINORVERSION |
self._kernel32.VER_BUILDNUMBER)
self._kernel32.VerifyVersionInfoW.assert_called_with(
self._ctypes_mock.byref.return_value, mask, verset_return)
if error_value is old_version:
self._kernel32.GetLastError.assert_called_with()
self.assertFalse(response)
else:
self.assertTrue(response)
def test_check_os_version(self):
m = mock.MagicMock()
self._test_check_os_version(ret_value=m)
def test_check_os_version_expect_false(self):
self._test_check_os_version(
ret_value=None, error_value=self._kernel32.ERROR_OLD_WIN_VERSION)

View File

@ -0,0 +1,253 @@
# Copyright 2015 Canonical Ltd.
# This file is part of cloud-init. See LICENCE file for license information.
#
# vi: ts=4 expandtab
import importlib
import subprocess
import unittest
from cloudinit import exceptions
from cloudinit.tests.util import mock
class TestNetworkWindows(unittest.TestCase):
def setUp(self):
self._ctypes_mock = mock.MagicMock()
self._moves_mock = mock.Mock()
self._win32com_mock = mock.Mock()
self._wmi_mock = mock.Mock()
self._module_patcher = mock.patch.dict(
'sys.modules',
{'ctypes': self._ctypes_mock,
'win32com': self._win32com_mock,
'wmi': self._wmi_mock,
'six.moves': self._moves_mock})
self._module_patcher.start()
self._iphlpapi = mock.Mock()
self._kernel32 = mock.Mock()
self._ws2_32 = mock.Mock()
self._network_module = importlib.import_module(
'cloudinit.osys.windows.network')
self._network_module.iphlpapi = self._iphlpapi
self._network_module.kernel32 = self._kernel32
self._network_module.ws2_32 = self._ws2_32
self._network = self._network_module.Network()
def tearDown(self):
self._module_patcher.stop()
def _test__heap_alloc(self, fail):
mock_heap = mock.Mock()
mock_size = mock.Mock()
if fail:
self._kernel32.HeapAlloc.return_value = None
with self.assertRaises(exceptions.CloudInitError) as cm:
self._network_module._heap_alloc(mock_heap, mock_size)
self.assertEqual('Unable to allocate memory for the IP '
'forward table',
str(cm.exception))
else:
result = self._network_module._heap_alloc(mock_heap, mock_size)
self.assertEqual(self._kernel32.HeapAlloc.return_value, result)
self._kernel32.HeapAlloc.assert_called_once_with(
mock_heap, 0, self._ctypes_mock.c_size_t(mock_size.value))
def test__heap_alloc_error(self):
self._test__heap_alloc(fail=True)
def test__heap_alloc_no_error(self):
self._test__heap_alloc(fail=False)
def test__get_forward_table_no_memory(self):
self._network_module._heap_alloc = mock.Mock()
error_msg = 'Unable to allocate memory for the IP forward table'
exc = exceptions.CloudInitError(error_msg)
self._network_module._heap_alloc.side_effect = exc
with self.assertRaises(exceptions.CloudInitError) as cm:
with self._network._get_forward_table():
pass
self.assertEqual(error_msg, str(cm.exception))
self._network_module._heap_alloc.assert_called_once_with(
self._kernel32.GetProcessHeap.return_value,
self._ctypes_mock.wintypes.ULONG.return_value)
def test__get_forward_table_insufficient_buffer_no_memory(self):
self._kernel32.HeapAlloc.side_effect = (mock.sentinel.table_mem, None)
self._iphlpapi.GetIpForwardTable.return_value = (
self._iphlpapi.ERROR_INSUFFICIENT_BUFFER)
with self.assertRaises(exceptions.CloudInitError):
with self._network._get_forward_table():
pass
table = self._ctypes_mock.cast.return_value
self._iphlpapi.GetIpForwardTable.assert_called_once_with(
table,
self._ctypes_mock.byref.return_value, 0)
heap_calls = [
mock.call(self._kernel32.GetProcessHeap.return_value, 0, table),
mock.call(self._kernel32.GetProcessHeap.return_value, 0, table)
]
self.assertEqual(heap_calls, self._kernel32.HeapFree.mock_calls)
def _test__get_forward_table(self, reallocation=False,
insufficient_buffer=False,
fail=False):
if fail:
with self.assertRaises(exceptions.CloudInitError) as cm:
with self._network._get_forward_table():
pass
msg = ('Unable to get IP forward table. Error: %s'
% mock.sentinel.error)
self.assertEqual(msg, str(cm.exception))
else:
with self._network._get_forward_table() as table:
pass
pointer = self._ctypes_mock.POINTER(
self._iphlpapi.Win32_MIB_IPFORWARDTABLE)
expected_forward_table = self._ctypes_mock.cast(
self._kernel32.HeapAlloc.return_value, pointer)
self.assertEqual(expected_forward_table, table)
heap_calls = [
mock.call(self._kernel32.GetProcessHeap.return_value, 0,
self._ctypes_mock.cast.return_value)
]
forward_calls = [
mock.call(self._ctypes_mock.cast.return_value,
self._ctypes_mock.byref.return_value, 0),
]
if insufficient_buffer:
# We expect two calls for GetIpForwardTable
forward_calls.append(forward_calls[0])
if reallocation:
heap_calls.append(heap_calls[0])
self.assertEqual(heap_calls, self._kernel32.HeapFree.mock_calls)
self.assertEqual(forward_calls,
self._iphlpapi.GetIpForwardTable.mock_calls)
def test__get_forward_table_sufficient_buffer(self):
self._iphlpapi.GetIpForwardTable.return_value = None
self._test__get_forward_table()
def test__get_forward_table_insufficient_buffer_reallocate(self):
self._kernel32.HeapAlloc.side_effect = (
mock.sentinel.table_mem, mock.sentinel.table_mem)
self._iphlpapi.GetIpForwardTable.side_effect = (
self._iphlpapi.ERROR_INSUFFICIENT_BUFFER, None)
self._test__get_forward_table(reallocation=True,
insufficient_buffer=True)
def test__get_forward_table_insufficient_buffer_other_error(self):
self._kernel32.HeapAlloc.side_effect = (
mock.sentinel.table_mem, mock.sentinel.table_mem)
self._iphlpapi.GetIpForwardTable.side_effect = (
self._iphlpapi.ERROR_INSUFFICIENT_BUFFER, mock.sentinel.error)
self._test__get_forward_table(reallocation=True,
insufficient_buffer=True,
fail=True)
@mock.patch('cloudinit.osys.windows.network.Network.routes')
def test_default_gateway_no_gateway(self, mock_routes):
mock_routes.return_value = iter((mock.Mock(), mock.Mock()))
self.assertIsNone(self._network.default_gateway())
mock_routes.assert_called_once_with()
@mock.patch('cloudinit.osys.windows.network.Network.routes')
def test_default_gateway(self, mock_routes):
default_gateway = mock.Mock()
default_gateway.destination = '0.0.0.0'
mock_routes.return_value = iter((mock.Mock(), default_gateway))
gateway = self._network.default_gateway()
self.assertEqual(default_gateway, gateway)
def test_route_is_static(self):
bad_route = self._network_module.Route(
destination=None, netmask=None,
gateway=None, interface=None, metric=None,
flags=404)
good_route = self._network_module.Route(
destination=None, netmask=None,
gateway=None, interface=None, metric=None,
flags=self._network_module.MIB_IPPROTO_NETMGMT)
self.assertTrue(good_route.is_static)
self.assertFalse(bad_route.is_static)
@mock.patch('subprocess.Popen')
def _test_route_add(self, mock_popen, err):
mock_route = mock.Mock()
mock_route.destination = mock.sentinel.destination
mock_route.netmask = mock.sentinel.netmask
mock_route.gateway = mock.sentinel.gateway
args = ['ROUTE', 'ADD', mock.sentinel.destination,
'MASK', mock.sentinel.netmask,
mock.sentinel.gateway]
mock_popen.return_value.returncode = err
mock_popen.return_value.communicate.return_value = (None, err)
if err:
with self.assertRaises(exceptions.CloudInitError) as cm:
self._network_module.Route.add(mock_route)
msg = "Unable to add route: %s" % err
self.assertEqual(msg, str(cm.exception))
else:
self._network_module.Route.add(mock_route)
mock_popen.assert_called_once_with(args, shell=False,
stderr=subprocess.PIPE)
def test_route_add_fails(self):
self._test_route_add(err=1)
def test_route_add_works(self):
self._test_route_add(err=0)
@mock.patch('cloudinit.osys.windows.network.Network._get_forward_table')
def test_routes(self, mock_forward_table):
def _same(arg):
return arg._mock_name.encode()
route = mock.MagicMock()
mock_cast_result = mock.Mock()
mock_cast_result.contents = [route]
self._ctypes_mock.cast.return_value = mock_cast_result
self._network_module.ws2_32.Ws2_32.inet_ntoa.side_effect = _same
route.dwForwardIfIndex = 'dwForwardIfIndex'
route.dwForwardProto = 'dwForwardProto'
route.dwForwardMetric1 = 'dwForwardMetric1'
routes = self._network.routes()
mock_forward_table.assert_called_once_with()
enter = mock_forward_table.return_value.__enter__
enter.assert_called_once_with()
exit_ = mock_forward_table.return_value.__exit__
exit_.assert_called_once_with(None, None, None)
self.assertEqual(1, len(routes))
given_route = routes[0]
self.assertEqual('dwForwardDest', given_route.destination)
self.assertEqual('dwForwardNextHop', given_route.gateway)
self.assertEqual('dwForwardMask', given_route.netmask)
self.assertEqual('dwForwardIfIndex', given_route.interface)
self.assertEqual('dwForwardMetric1', given_route.metric)
self.assertEqual('dwForwardProto', given_route.flags)