diff --git a/oslotest/base.py b/oslotest/base.py index 213b28c..5b85b16 100644 --- a/oslotest/base.py +++ b/oslotest/base.py @@ -20,6 +20,7 @@ import os import tempfile import fixtures +import six from six.moves import mock import testtools @@ -121,17 +122,29 @@ class BaseTestCase(testtools.TestCase): else: logging.basicConfig(format=_LOG_FORMAT, level=level) - def create_tempfiles(self, files, ext='.conf'): + def create_tempfiles(self, files, ext='.conf', default_encoding='utf-8'): """Safely create temporary files. :param files: Sequence of tuples containing (filename, file_contents). :type files: list of tuple :param ext: File name extension for the temporary file. :type ext: str + :param default_encoding: Default file content encoding when it is + not provided, used to decode the tempfile + contents from a text string into a binary + string. + :type default_encoding: str :return: A list of str with the names of the files created. """ tempfiles = [] - for (basename, contents) in files: + for f in files: + if len(f) == 3: + basename, contents, encoding = f + else: + basename, contents = f + encoding = default_encoding + if isinstance(contents, six.text_type): + contents = contents.encode(encoding) if not os.path.isabs(basename): (fd, path) = tempfile.mkstemp(prefix=basename, suffix=ext) else: diff --git a/tests/unit/test_base.py b/tests/unit/test_base.py index 0816041..3ba0fa7 100644 --- a/tests/unit/test_base.py +++ b/tests/unit/test_base.py @@ -1,3 +1,5 @@ +# -*- coding: utf-8 -*- + # Copyright 2014 Deutsche Telekom AG # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -13,8 +15,10 @@ # under the License. import logging +import os import unittest +import six from six.moves import mock import testtools @@ -121,3 +125,52 @@ class TestManualMock(base.BaseTestCase): patcher = mock.patch('os.environ.get') patcher.start() self.addCleanup(patcher.stop) + + +class TestTempFiles(base.BaseTestCase): + def test_create_unicode_files(self): + files = [["no_approve", u'ಠ_ಠ']] + temps = self.create_tempfiles(files) + self.assertEqual(1, len(temps)) + with open(temps[0], 'rb') as f: + contents = f.read() + self.assertEqual(u'ಠ_ಠ', six.text_type(contents, encoding='utf-8')) + + def test_create_unicode_files_encoding(self): + files = [["embarrassed", u'⊙﹏⊙', 'utf-8']] + temps = self.create_tempfiles(files) + self.assertEqual(1, len(temps)) + with open(temps[0], 'rb') as f: + contents = f.read() + self.assertEqual(u'⊙﹏⊙', six.text_type(contents, encoding='utf-8')) + + def test_create_unicode_files_multi_encoding(self): + files = [ + ["embarrassed", u'⊙﹏⊙', 'utf-8'], + ['abc', 'abc', 'ascii'], + ] + temps = self.create_tempfiles(files) + self.assertEqual(2, len(temps)) + for i, (basename, raw_contents, raw_encoding) in enumerate(files): + with open(temps[i], 'rb') as f: + contents = f.read() + if not isinstance(raw_contents, six.text_type): + raw_contents = six.text_type(raw_contents, + encoding=raw_encoding) + self.assertEqual(raw_contents, + six.text_type(contents, encoding=raw_encoding)) + + def test_create_bad_encoding(self): + files = [["hrm", u'ಠ~ಠ', 'ascii']] + self.assertRaises(UnicodeError, self.create_tempfiles, files) + + def test_prefix(self): + files = [["testing", '']] + temps = self.create_tempfiles(files) + self.assertEqual(1, len(temps)) + basename = os.path.basename(temps[0]) + self.assertTrue(basename.startswith('testing')) + + def test_wrong_length(self): + files = [["testing"]] + self.assertRaises(ValueError, self.create_tempfiles, files)