# cloudbase-init/cloudbaseinit/tests/utils/windows/test_licensing.py
#
# 232 lines
# 9.4 KiB
# Python

# Copyright (c) 2017 Cloudbase Solutions Srl
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import importlib
import unittest
try:
import unittest.mock as mock
except ImportError:
import mock
from cloudbaseinit import exception
from cloudbaseinit.tests import testutils
MODPATH = "cloudbaseinit.utils.windows.licensing"
class LicensingTest(unittest.TestCase):
    """Tests for the LicensingManager and LicensingManagerV2 classes.

    The ``wmi`` module is replaced in ``sys.modules`` by a ``MagicMock``
    before the module under test is imported, so the tests do not need the
    real package to be installed.
    """

    def setUp(self):
        """Import the licensing module with ``wmi`` mocked out.

        One instance of each manager class is created for the tests.
        """
        self._wmi_mock = mock.MagicMock()
        self._module_patcher = mock.patch.dict(
            'sys.modules', {
                'wmi': self._wmi_mock})
        self.snatcher = testutils.LogSnatcher(MODPATH)
        self._module_patcher.start()
        # Registered right after start(): cleanups still run when the rest
        # of setUp raises, whereas tearDown would be skipped in that case
        # and the patched sys.modules entry would leak into other tests.
        self.addCleanup(self._module_patcher.stop)
        self.licensing = importlib.import_module(MODPATH)
        self._licensing = self.licensing.LicensingManager()
        self._licensing_v2 = self.licensing.LicensingManagerV2()

    @mock.patch('cloudbaseinit.osutils.factory.get_os_utils')
    def _test_run_slmgr(self, mock_get_os_utils, ret_val=0,
                        sysnative=True):
        """Drive ``LicensingManager._run_slmgr`` against a mocked osutils.

        :param mock_get_os_utils: mock injected by the ``mock.patch``
            decorator.
        :param ret_val: exit code reported by the mocked
            ``execute_process``; a truthy value is expected to make
            ``_run_slmgr`` raise ``CloudbaseInitException``.
        :param sysnative: value returned by the mocked
            ``check_sysnative_dir_exists``.
        """
        mock_args = [mock.sentinel.args]
        mock_outval = b"fake-out"
        mock_cscriptdir = r"fake\cscript\dir"

        mock_osutils = mock.Mock()
        mock_osutils.get_sysnative_dir.return_value = mock_cscriptdir
        mock_osutils.get_system32_dir.return_value = r"fakedir"
        mock_osutils.check_sysnative_dir_exists.return_value = sysnative
        mock_osutils.execute_process.return_value = (
            mock_outval, mock.sentinel.err, ret_val)
        mock_get_os_utils.return_value = mock_osutils

        if ret_val:
            with self.assertRaises(exception.CloudbaseInitException):
                self._licensing._run_slmgr(mock_args)
        else:
            res_out = self._licensing._run_slmgr(mock_args)
            # NOTE(review): the stubbed output is bytes while the expected
            # value is a text literal; whether these compare equal depends
            # on the interpreter and on how _run_slmgr handles the output,
            # which is not visible from this file - confirm.
            self.assertEqual(res_out, "fake-out")
        # The process must be launched exactly once on both paths.
        self.assertEqual(mock_osutils.execute_process.call_count, 1)

    def test_run_slmgr_sys_native(self):
        """Run the helper with the sysnative directory reported present."""
        self._test_run_slmgr()

    def test_run_slmgr_system32(self):
        """Run the helper with the sysnative directory reported absent."""
        self._test_run_slmgr(sysnative=False)

    def test_run_slmgr_fail(self):
        """Run the helper with a non-zero exit code from the process."""
        self._test_run_slmgr(ret_val=1)

    @mock.patch(MODPATH + ".LicensingManager._run_slmgr")
    def test_get_licensing_info(self, mock_run_slmgr):
        """``get_licensing_info`` runs slmgr with ``/dlv``."""
        mock_out = mock.sentinel.out_val
        mock_run_slmgr.return_value = mock_out

        res = self._licensing.get_licensing_info()

        mock_run_slmgr.assert_called_once_with(['/dlv'])
        self.assertEqual(res, mock_out)

    @mock.patch(MODPATH + ".LicensingManager._run_slmgr")
    def test_activate_windows(self, mock_run_slmgr):
        """``activate_windows`` runs slmgr with ``/ato``."""
        mock_out = mock.sentinel.out_val
        mock_run_slmgr.return_value = mock_out

        res = self._licensing.activate_windows()

        mock_run_slmgr.assert_called_once_with(['/ato'])
        self.assertEqual(res, mock_out)

    @mock.patch(MODPATH + ".LicensingManager._run_slmgr")
    def test_set_kms_host(self, mock_run_slmgr):
        """``set_kms_host`` runs ``/skms`` with ``host:DEFAULT_KMS_PORT``."""
        mock_out = mock.sentinel.out_val
        mock_kms = mock.sentinel.kms_host
        mock_run_slmgr.return_value = mock_out

        res = self._licensing.set_kms_host(mock_kms)

        expected_host = "%s:%s" % (mock_kms, self.licensing.DEFAULT_KMS_PORT)
        mock_run_slmgr.assert_called_once_with(['/skms', expected_host])
        self.assertEqual(res, mock_out)

    @mock.patch(MODPATH + ".LicensingManager._run_slmgr")
    def test_set_kms_auto_discovery(self, mock_run_slmgr):
        """``set_kms_auto_discovery`` runs slmgr with ``/ckms``."""
        mock_out = mock.sentinel.out_val
        mock_run_slmgr.return_value = mock_out

        res = self._licensing.set_kms_auto_discovery()

        mock_run_slmgr.assert_called_once_with(['/ckms'])
        self.assertEqual(res, mock_out)

    @mock.patch(MODPATH + ".LicensingManager._run_slmgr")
    def test_set_product_key(self, mock_run_slmgr):
        """``set_product_key`` runs slmgr with ``/ipk`` and the key."""
        mock_out = mock.sentinel.out_val
        mock_product_key = mock.sentinel.product_key
        mock_run_slmgr.return_value = mock_out

        res = self._licensing.set_product_key(mock_product_key)

        mock_run_slmgr.assert_called_once_with(['/ipk', mock_product_key])
        self.assertEqual(res, mock_out)

    def test_is_eval_v1(self):
        """``LicensingManager.is_eval`` raises ``NotImplementedError``."""
        with self.assertRaises(NotImplementedError):
            self._licensing.is_eval()

    def test_get_kms_product_v1(self):
        """``LicensingManager.get_kms_product`` is not implemented."""
        with self.assertRaises(NotImplementedError):
            self._licensing.get_kms_product()

    def test_get_volume_activation_product_key_v1(self):
        """The v1 volume activation key lookup is not implemented."""
        with self.assertRaises(NotImplementedError):
            self._licensing.get_volume_activation_product_key('fake')

    def test_get_service(self):
        """``_get_service`` leaves ``_service`` set on the v2 manager."""
        mock_result = mock.Mock()
        # NOTE(review): this stubs an attribute of the ``WMI`` mock itself,
        # not of ``WMI.return_value``. If the module under test calls
        # ``wmi.WMI(...)`` first, ``mock_result`` is never seen - confirm.
        conn = self._wmi_mock.WMI
        conn.SoftwareLicensingService.return_value = [mock_result]

        self._licensing_v2._get_service()

        self.assertIsNotNone(self._licensing_v2._service)

    @mock.patch(MODPATH + '.LicensingManagerV2._get_service')
    def test_set_product_key_v2(self, mock_get_service):
        """The v2 ``set_product_key`` calls ``InstallProductKey``."""
        mock_product_key = mock.Mock()
        mock_service = mock.Mock()
        mock_get_service.return_value = mock_service

        self._licensing_v2.set_product_key(mock_product_key)

        mock_get_service.assert_called_once_with()
        mock_service.InstallProductKey.assert_called_once_with(
            mock_product_key)

    @mock.patch(MODPATH + '.LicensingManagerV2._get_service')
    def test_set_kms_auto_discovery_v2(self, mock_get_service):
        """The v2 auto discovery clears the KMS machine and port."""
        mock_service = mock.Mock()
        mock_get_service.return_value = mock_service

        self._licensing_v2.set_kms_auto_discovery()

        mock_get_service.assert_called_once_with()
        mock_service.ClearKeyManagementServiceMachine.assert_called_once_with()
        mock_service.ClearKeyManagementServicePort.assert_called_once_with()

    @mock.patch(MODPATH + '.LicensingManagerV2._get_service')
    def test_set_kms_host_v2(self, mock_get_service):
        """The v2 ``set_kms_host`` sets the KMS machine and port."""
        mock_service = mock.Mock()
        mock_host = mock.sentinel.host
        mock_port = mock.sentinel.port
        mock_get_service.return_value = mock_service

        self._licensing_v2.set_kms_host(mock_host, mock_port)

        mock_get_service.assert_called_once_with()
        mock_service.SetKeyManagementServiceMachine.assert_called_once_with(
            mock_host)
        mock_service.SetKeyManagementServicePort.assert_called_once_with(
            mock_port)

    @mock.patch(MODPATH + '.LicensingManagerV2._get_service')
    def test_refresh_status_v2(self, mock_get_service):
        """``refresh_status`` calls ``RefreshLicenseStatus`` once."""
        mock_service = mock.Mock()
        mock_get_service.return_value = mock_service

        self._licensing_v2.refresh_status()

        mock_get_service.assert_called_once_with()
        mock_service.RefreshLicenseStatus.assert_called_once_with()

    def test_is_current_product(self):
        """A product with a ``PartialProductKey`` is reported as current."""
        mock_product = mock.Mock()
        mock_product.PartialProductKey = "fake-key"

        res = self._licensing_v2._is_current_product(mock_product)

        self.assertTrue(res)

    def test_get_products(self):
        """``_get_products`` returns what it stores in ``_products``."""
        mock_result = mock.Mock()
        # NOTE(review): as in test_get_service, ``query`` is stubbed on the
        # ``WMI`` mock itself rather than on ``WMI.return_value`` - confirm
        # that this is the object the module under test queries.
        conn = self._wmi_mock.WMI
        conn.query.return_value = mock_result

        res = self._licensing_v2._get_products()

        self.assertEqual(res, self._licensing_v2._products)

    @mock.patch(MODPATH + ".LicensingManagerV2._get_products")
    def test_is_eval(self, mock_get_products):
        """``is_eval`` returns the matching product's end date."""
        mock_product = mock.Mock()
        mock_product.ApplicationId = self.licensing.WINDOWS_APP_ID
        mock_product.Description = u"TIMEBASED_EVAL"
        mock_product.EvaluationEndDate = "fake"
        mock_get_products.return_value = [mock_product]

        res = self._licensing_v2.is_eval()

        self.assertEqual(res, "fake")

    @mock.patch(MODPATH + ".LicensingManagerV2._get_products")
    def _test_get_kms_product(self, mock_get_products, products=()):
        """Check ``get_kms_product`` for the given product list.

        :param mock_get_products: mock injected by the ``mock.patch``
            decorator.
        :param products: value returned by the mocked ``_get_products``;
            when empty, ``ItemNotFoundException`` is expected.
        """
        mock_get_products.return_value = products

        if not products:
            with self.assertRaises(exception.ItemNotFoundException):
                self._licensing_v2.get_kms_product()
            return

        res = self._licensing_v2.get_kms_product()
        self.assertIsNotNone(res)

    def test_get_kms_product_no_keys(self):
        """No products available: the lookup is expected to raise."""
        self._test_get_kms_product()

    def test_get_kms_product(self):
        """A ``VOLUME_KMSCLIENT`` Windows product yields a result."""
        mock_product = mock.Mock()
        mock_product.ApplicationId = self.licensing.WINDOWS_APP_ID
        mock_product.Description = u"VOLUME_KMSCLIENT"
        self._test_get_kms_product(products=[mock_product])

    @mock.patch('cloudbaseinit.osutils.factory.get_os_utils')
    def test_get_volume_activation_product_key(self, mock_get_os_utils):
        """OS version 10.0 with ``ServerStandard`` maps to a fixed key."""
        mock_os_version = {'major_version': 10, 'minor_version': 0}
        expected_key = "WC2BQ-8NRM3-FDDYY-2BFGV-KHKQY"
        mock_osutils = mock.Mock()
        mock_osutils.get_os_version.return_value = mock_os_version
        mock_get_os_utils.return_value = mock_osutils

        res = self._licensing_v2.get_volume_activation_product_key(
            license_family="ServerStandard")

        self.assertEqual(res, expected_key)