diff --git a/cloudbaseinit/tests/utils/template_engine/test_base_template.py b/cloudbaseinit/tests/utils/template_engine/test_base_template.py index 10717cbd..cf6e6514 100644 --- a/cloudbaseinit/tests/utils/template_engine/test_base_template.py +++ b/cloudbaseinit/tests/utils/template_engine/test_base_template.py @@ -22,6 +22,7 @@ from cloudbaseinit.utils.template_engine import base_template as bt class TestBaseTemplateEngine(unittest.TestCase): @ddt.data((b'', b''), + (None, None), (b'## template:jinja test', b''), (b'## template:jinja \ntest', b'test')) @ddt.unpack diff --git a/cloudbaseinit/tests/utils/template_engine/test_jinja2_template.py b/cloudbaseinit/tests/utils/template_engine/test_jinja2_template.py index 5d9f52a7..0bf19f3c 100644 --- a/cloudbaseinit/tests/utils/template_engine/test_jinja2_template.py +++ b/cloudbaseinit/tests/utils/template_engine/test_jinja2_template.py @@ -13,6 +13,7 @@ # under the License. +import ddt import unittest try: import unittest.mock as mock @@ -23,6 +24,7 @@ from cloudbaseinit.utils.template_engine.jinja2_template import ( Jinja2TemplateEngine) +@ddt.ddt class TestJinja2TemplateEngine(unittest.TestCase): @mock.patch('cloudbaseinit.utils.template_engine.base_template' @@ -80,3 +82,13 @@ class TestJinja2TemplateEngine(unittest.TestCase): fake_instance_data=fake_instance_data, expected_result=expected_result, fake_template=fake_template) + + @ddt.data((b'', None), + (None, None), + (b'## template:jinja \n#ps1 \nmkdir', True), + (b'## template:jinja test', None), + (b'## template:jinjanone \ntest', None)) + @ddt.unpack + def test_load_template_definition(self, userdata, expected_output): + output = Jinja2TemplateEngine().load(userdata) + self.assertEqual(expected_output, output) diff --git a/cloudbaseinit/utils/template_engine/base_template.py b/cloudbaseinit/utils/template_engine/base_template.py index d3ae7aaa..0a8ec31e 100644 --- a/cloudbaseinit/utils/template_engine/base_template.py +++ b/cloudbaseinit/utils/template_engine/base_template.py @@ -42,13 +42,23 @@ class BaseTemplateEngine(object): def load(self, data): """Returns True if the template header matches, False otherwise""" + if not data: + return + template_type_matcher = self._template_matcher.match(data.decode()) + if not template_type_matcher: + return + template_type = template_type_matcher.group(1).lower().strip() if self.get_template_type() == template_type: return True @staticmethod def remove_template_definition(raw_template): + # return the raw template as is if it is None or empty array / dict + if not raw_template: + return raw_template + # Remove the first line, as it contains the template definition template_split = raw_template.split(b"\n", 1)