Implement graph builder

This is a generic graph builder. It will later be integrated to the
protection service to allow the building of the resource tree from
protectables.

Change-Id: Id5ac431cc6b5edff06c312eaae22ea5e1d4d8eda
This commit is contained in:
Saggi Mizrahi 2016-02-09 17:54:05 +02:00
parent 505941c061
commit 94a5581983
2 changed files with 162 additions and 0 deletions

View File

@ -0,0 +1,84 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from collections import namedtuple
from oslo_log import log as logging
from smaug.i18n import _
_GraphBuilderContext = namedtuple("_GraphBuilderContext", (
"source_set",
"encounterd_set",
"finished_nodes",
"get_child_nodes",
))
GraphNode = namedtuple("GraphNode", (
"value",
"child_nodes",
))
LOG = logging.getLogger(__name__)
class FoundLoopError(RuntimeError):
def __init__(self):
super(FoundLoopError, self).__init__(
_("A loop was found in the graph"))
def _build_graph_rec(context, node):
LOG.trace("Entered node: %s", node)
source_set = context.source_set
encounterd_set = context.encounterd_set
finished_nodes = context.finished_nodes
LOG.trace("Gray set is %s", encounterd_set)
if node in encounterd_set:
raise FoundLoopError()
LOG.trace("Black set is %s", finished_nodes.keys())
if node in finished_nodes:
return finished_nodes[node]
LOG.trace("Change to gray: %s", node)
encounterd_set.add(node)
child_nodes = context.get_child_nodes(node)
LOG.trace("child nodes are ", child_nodes)
# If we found a parent than this is not a source
source_set.difference_update(child_nodes)
child_list = []
for child_node in child_nodes:
child_list.append(_build_graph_rec(context, child_node))
LOG.trace("Change to black: ", node)
encounterd_set.discard(node)
graph_node = GraphNode(value=node, child_nodes=tuple(child_list))
finished_nodes[node] = graph_node
return graph_node
def build_graph(start_nodes, get_child_nodes_func):
context = _GraphBuilderContext(
source_set=set(start_nodes),
encounterd_set=set(),
finished_nodes={},
get_child_nodes=get_child_nodes_func,
)
result = []
for node in start_nodes:
result.append(_build_graph_rec(context, node))
assert(len(context.encounterd_set) == 0)
return [item for item in result if item.value in context.source_set]

View File

@ -0,0 +1,78 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import smaug.services.protection.graph as graph
from smaug.tests import base
class GraphBuilderTest(base.TestCase):
def test_source_set(self):
"""Test that the source set only contains sources"""
test_matrix = (
({
"A": ["B"],
"B": ["C"],
"C": [],
}, {"A"}),
({
"A": [],
"B": ["C"],
"C": [],
}, {"A", "B"}),
({
"A": ["C"],
"B": ["C"],
"C": [],
}, {"A", "B"}),
)
for g, expected_result in test_matrix:
result = graph.build_graph(g.keys(), g.__getitem__)
self.assertEqual({node.value for node in result}, expected_result)
def test_detect_cyclic_graph(self):
"""Test that cyclic graphs are detected"""
test_matrix = (
({
"A": ["B"],
"B": ["C"],
"C": [],
}, False),
({
"A": [],
"B": ["C"],
"C": [],
}, False),
({
"A": ["C"],
"B": ["C"],
"C": ["A"],
}, True),
({
"A": ["B"],
"B": ["C"],
"C": ["A"],
}, True),
)
for g, expected_result in test_matrix:
if expected_result:
self.assertRaises(
graph.FoundLoopError,
graph.build_graph,
g.keys(), g.__getitem__,
)
else:
graph.build_graph(g.keys(), g.__getitem__)