diff --git a/trove/guestagent/datastore/cassandra/service.py b/trove/guestagent/datastore/cassandra/service.py index f819508058..2533510b71 100644 --- a/trove/guestagent/datastore/cassandra/service.py +++ b/trove/guestagent/datastore/cassandra/service.py @@ -14,6 +14,7 @@ # under the License. import os +import tempfile import yaml from trove.common import cfg from trove.common import utils @@ -117,22 +118,34 @@ class CassandraApp(object): packager.pkg_install(packages, None, system.INSTALL_TIMEOUT) LOG.debug("Finished installing Cassandra server") - def write_config(self, config_contents): - LOG.debug('Defining temp config holder at %s.' % - system.CASSANDRA_TEMP_CONF) + def write_config(self, config_contents, + execute_function=utils.execute_with_timeout, + mkstemp_function=tempfile.mkstemp, + unlink_function=os.unlink): + # first securely create a temp file. mkstemp() will set + # os.O_EXCL on the open() call, and we get a file with + # permissions of 600 by default. + (conf_fd, conf_path) = mkstemp_function() + + LOG.debug('Storing temporary configuration at %s.' % conf_path) + + # write config and close the file, delete it if there is an + # error. only unlink if there is a problem. In normal course, + # we move the file. try: - with open(system.CASSANDRA_TEMP_CONF, 'w+') as conf: - conf.write(config_contents) - - LOG.info(_('Writing new config.')) - - utils.execute_with_timeout("sudo", "mv", - system.CASSANDRA_TEMP_CONF, - system.CASSANDRA_CONF) + os.write(conf_fd, config_contents) + execute_function("sudo", "mv", conf_path, system.CASSANDRA_CONF) except Exception: - os.unlink(system.CASSANDRA_TEMP_CONF) + LOG.exception( + _("Exception generating Cassandra configuration %s.") % + conf_path) + unlink_function(conf_path) raise + finally: + os.close(conf_fd) + + LOG.info(_('Wrote new Cassandra configuration.')) def read_conf(self): """Returns cassandra.yaml in dict structure.""" diff --git a/trove/tests/unittests/guestagent/test_dbaas.py b/trove/tests/unittests/guestagent/test_dbaas.py index dc3425ba06..721a1c9b29 100644 --- a/trove/tests/unittests/guestagent/test_dbaas.py +++ b/trove/tests/unittests/guestagent/test_dbaas.py @@ -13,6 +13,7 @@ # under the License. import os +import tempfile from uuid import uuid4 import time from mock import Mock @@ -39,6 +40,7 @@ from trove.guestagent.datastore.redis import service as rservice from trove.guestagent.datastore.redis.service import RedisApp from trove.guestagent.datastore.redis import system as RedisSystem from trove.guestagent.datastore.cassandra import service as cass_service +from trove.guestagent.datastore.cassandra import system as cass_system from trove.guestagent.datastore.mysql.service import MySqlAdmin from trove.guestagent.datastore.mysql.service import MySqlRootAccess from trove.guestagent.datastore.mysql.service import MySqlApp @@ -1448,7 +1450,6 @@ class CassandraDBAppTest(testtools.TestCase): rd_instance.ServiceStatuses.NEW) self.cassandra = cass_service.CassandraApp(self.appStatus) self.orig_unlink = os.unlink - os.unlink = Mock() def tearDown(self): @@ -1459,7 +1460,6 @@ class CassandraDBAppTest(testtools.TestCase): cass_service.packager.pkg_version = self.pkg_version cass_service.packager = self.pkg InstanceServiceStatus.find_by(instance_id=self.FAKE_ID).delete() - os.unlink = self.orig_unlink def assert_reported_status(self, expected_status): service_status = InstanceServiceStatus.find_by( @@ -1574,27 +1574,64 @@ class CassandraDBAppTest(testtools.TestCase): self.assert_reported_status(rd_instance.ServiceStatuses.NEW) def test_cassandra_error_in_write_config_verify_unlink(self): + # this test verifies not only that the write_config + # method properly invoked execute, but also that it properly + # attempted to unlink the file (as a result of the exception) from trove.common.exception import ProcessExecutionError - cass_service.utils.execute_with_timeout = ( - Mock(side_effect=ProcessExecutionError('some exception'))) + execute_with_timeout = Mock( + side_effect=ProcessExecutionError('some exception')) + + mock_unlink = Mock(return_value=0) + + # We call tempfile.mkstemp() here and Mock() the mkstemp() + # parameter to write_config for testability. + (temp_handle, temp_config_name) = tempfile.mkstemp() + mock_mkstemp = MagicMock(return_value=(temp_handle, temp_config_name)) + configuration = 'this is my configuration' self.assertRaises(ProcessExecutionError, self.cassandra.write_config, - config_contents=configuration) - self.assertEqual(cass_service.utils.execute_with_timeout.call_count, 1) - self.assertEqual(os.unlink.call_count, 1) + config_contents=configuration, + execute_function=execute_with_timeout, + mkstemp_function=mock_mkstemp, + unlink_function=mock_unlink) - def test_cassandra_error_in_write_config(self): - from trove.common.exception import ProcessExecutionError - cass_service.utils.execute_with_timeout = ( - Mock(side_effect=ProcessExecutionError('some exception'))) - configuration = 'this is my configuration' + self.assertEqual(mock_unlink.call_count, 1) - self.assertRaises(ProcessExecutionError, - self.cassandra.write_config, - config_contents=configuration) - self.assertEqual(cass_service.utils.execute_with_timeout.call_count, 1) + # really delete the temporary_config_file + os.unlink(temp_config_name) + + def test_cassandra_write_config(self): + # ensure that write_config creates a temporary file, and then + # moves the file to the final place. Also validate the + # contents of the file written. + + # We call tempfile.mkstemp() here and Mock() the mkstemp() + # parameter to write_config for testability. + (temp_handle, temp_config_name) = tempfile.mkstemp() + mock_mkstemp = MagicMock(return_value=(temp_handle, temp_config_name)) + + configuration = 'some arbitrary configuration text' + + mock_execute = MagicMock(return_value=('', '')) + + self.cassandra.write_config(configuration, + execute_function=mock_execute, + mkstemp_function=mock_mkstemp) + + mock_execute.assert_called_with("sudo", "mv", + temp_config_name, + cass_system.CASSANDRA_CONF) + mock_mkstemp.assert_called_once() + + with open(temp_config_name, 'r') as config_file: + configuration_data = config_file.read() + + self.assertEqual(configuration, configuration_data) + + # really delete the temporary_config_file + os.unlink(temp_config_name) class CouchbaseAppTest(testtools.TestCase):