diff --git a/translator/toscalib/elements/capabilitytype.py b/translator/toscalib/elements/capabilitytype.py index 5340b39..624e5fa 100644 --- a/translator/toscalib/elements/capabilitytype.py +++ b/translator/toscalib/elements/capabilitytype.py @@ -17,11 +17,11 @@ from translator.toscalib.elements.statefulentitytype import StatefulEntityType class CapabilityTypeDef(StatefulEntityType): '''TOSCA built-in capabilities type.''' - def __init__(self, name, ctype, ntype): + def __init__(self, name, ctype, ntype, custom_def=None): self.name = name - super(CapabilityTypeDef, self).__init__(ctype, self.CAPABILITY_PREFIX) + super(CapabilityTypeDef, self).__init__(ctype, self.CAPABILITY_PREFIX, + custom_def) self.nodetype = ntype - self.defs = self.TOSCA_DEF[ctype] self.properties = None if self.PROPERTIES in self.defs: self.properties = self.defs[self.PROPERTIES] diff --git a/translator/toscalib/elements/nodetype.py b/translator/toscalib/elements/nodetype.py index add4d74..a15fe80 100644 --- a/translator/toscalib/elements/nodetype.py +++ b/translator/toscalib/elements/nodetype.py @@ -139,7 +139,8 @@ class NodeType(StatefulEntityType): if caps: for name, value in caps.items(): ctype = value.get('type') - cap = CapabilityTypeDef(name, ctype, self.type) + cap = CapabilityTypeDef(name, ctype, self.type, + self.custom_def) typecapabilities.append(cap) return typecapabilities diff --git a/translator/toscalib/tests/test_toscatpl.py b/translator/toscalib/tests/test_toscatpl.py index 15e44a2..800360f 100644 --- a/translator/toscalib/tests/test_toscatpl.py +++ b/translator/toscalib/tests/test_toscatpl.py @@ -11,7 +11,9 @@ # under the License. import os +import six +from translator.toscalib.common import exception import translator.toscalib.elements.interfaces as ifaces from translator.toscalib.elements.nodetype import NodeType from translator.toscalib.functions import GetInput @@ -319,3 +321,56 @@ class ToscaTemplateTest(TestCase): self.assertRaises( NotImplementedError, lambda: NodeTemplate(tpl_name, nodetemplates).relationships) + + def test_custom_capability_type_definition(self): + tpl_snippet = ''' + node_templates: + test_app: + type: tosca.nodes.WebApplication.TestApp + capabilities: + test_cap: + properties: + test: 1 + ''' + #custom definition with capability type definition + custom_def = ''' + tosca.nodes.WebApplication.TestApp: + derived_from: tosca.nodes.WebApplication + capabilities: + test_cap: + type: tosca.capabilities.TestCapability + tosca.capabilities.TestCapability: + derived_from: tosca.capabilities.Root + properties: + test: + type: integer + required: no + ''' + expected_capabilities = ['test_cap'] + nodetemplates = (translator.toscalib.utils.yamlparser. + simple_parse(tpl_snippet))['node_templates'] + custom_def = (translator.toscalib.utils.yamlparser. + simple_parse(custom_def)) + name = list(nodetemplates.keys())[0] + tpl = NodeTemplate(name, nodetemplates, custom_def) + self.assertEqual( + expected_capabilities, + sorted([c.name for c in tpl.capabilities])) + + #custom definition without capability type definition + custom_def = ''' + tosca.nodes.WebApplication.TestApp: + derived_from: tosca.nodes.WebApplication + capabilities: + test_cap: + type: tosca.capabilities.TestCapability + ''' + custom_def = (translator.toscalib.utils.yamlparser. + simple_parse(custom_def)) + tpl = NodeTemplate(name, nodetemplates, custom_def) + err = self.assertRaises( + exception.InvalidTypeError, + lambda: NodeTemplate(name, nodetemplates, + custom_def).capabilities) + self.assertEqual('Type "tosca.capabilities.TestCapability" is not ' + 'a valid type.', six.text_type(err)) diff --git a/translator/toscalib/tosca_template.py b/translator/toscalib/tosca_template.py index 4cd2ccb..196a940 100644 --- a/translator/toscalib/tosca_template.py +++ b/translator/toscalib/tosca_template.py @@ -78,7 +78,9 @@ class ToscaTemplate(object): data_types = self._get_custom_types(DATATYPE_DEFINITIONS) if data_types: custom_defs.update(data_types) - + capability_types = self._get_custom_types(CAPABILITY_TYPES) + if capability_types: + custom_defs.update(capability_types) nodetemplates = [] tpls = self._tpl_nodetemplates() for name in tpls: