diff --git a/taskflow/persistence/backends/impl_memory.py b/taskflow/persistence/backends/impl_memory.py index b9489835..31913a6b 100644 --- a/taskflow/persistence/backends/impl_memory.py +++ b/taskflow/persistence/backends/impl_memory.py @@ -27,6 +27,13 @@ from taskflow.types import tree from taskflow.utils import lock_utils +class FakeInode(tree.Node): + """A in-memory filesystem inode-like object.""" + + def __init__(self, item, path, value=None): + super(FakeInode, self).__init__(item, path=path, value=value) + + class FakeFilesystem(object): """An in-memory filesystem-like structure. @@ -79,7 +86,10 @@ class FakeFilesystem(object): return pp.sep.join(pieces) def __init__(self, deep_copy=True): - self._root = tree.Node(self.root_path, value=None) + self._root = FakeInode(self.root_path, self.root_path) + self._reverse_mapping = { + self.root_path: self._root, + } if deep_copy: self._copier = copy.deepcopy else: @@ -97,21 +107,33 @@ class FakeFilesystem(object): child_node = node.find(piece, only_direct=True, include_self=False) if child_node is None: - child_node = tree.Node(piece, value=None) - node.add(child_node) + child_node = self._insert_child(node, piece) node = child_node - def _fetch_node(self, path): - node = self._root - path = self.normpath(path) - if path == self._root.item: - return node - for piece in self._iter_pieces(path): - node = node.find(piece, only_direct=True, - include_self=False) - if node is None: - raise exc.NotFound("Path '%s' not found" % path) - return node + def _insert_child(self, parent_node, basename, value=None): + child_path = self.join(parent_node.metadata['path'], basename) + # This avoids getting '//a/b' (duplicated sep at start)... + # + # Which can happen easily if something like the following is given. + # >>> x = ['/', 'b'] + # >>> pp.sep.join(x) + # '//b' + if child_path.startswith(pp.sep * 2): + child_path = child_path[1:] + child_node = FakeInode(basename, child_path, value=value) + parent_node.add(child_node) + self._reverse_mapping[child_path] = child_node + return child_node + + def _fetch_node(self, path, normalized=False): + if not normalized: + normed_path = self.normpath(path) + else: + normed_path = path + try: + return self._reverse_mapping[normed_path] + except KeyError: + raise exc.NotFound("Path '%s' not found" % path) def get(self, path, default=None): """Fetch the value of given path (and return default if not found).""" @@ -142,23 +164,14 @@ class FakeFilesystem(object): if not recursive: return [node.item for node in self._fetch_node(path)] else: - paths = [] node = self._fetch_node(path) - for child in node.bfs_iter(): - # Reconstruct the child's path... - hops = [child.item] - for parent in child.path_iter(include_self=False): - hops.append(parent.item) - hops.reverse() - # This avoids getting '//a/b' (duplicated sep at start)... - child_path = self.join(*hops) - if child_path.startswith("//"): - child_path = child_path[1:] - paths.append(child_path) - return paths + return [child.metadata['path'] for child in node.bfs_iter()] def clear(self): """Remove all nodes (except the root) from this filesystem.""" + self._reverse_mapping = { + self.root_path: self._root, + } for node in list(self._root.reverse_iter()): node.disassociate() @@ -179,9 +192,14 @@ class FakeFilesystem(object): yield piece def __delitem__(self, path): - node = self._fetch_node(path) + path = self.normpath(path) + node = self._fetch_node(path, normalized=True) if node is self._root: raise ValueError("Can not delete '%s'" % self._root.item) + removals = [path] + removals.extend(child.metadata['path'] for child in node.bfs_iter()) + for path in removals: + self._reverse_mapping.pop(path, None) node.disassociate() @staticmethod @@ -200,13 +218,12 @@ class FakeFilesystem(object): dest_path = self.normpath(dest_path) src_path = self.normpath(src_path) dirname, basename = pp.split(dest_path) - parent_node = self._fetch_node(dirname) + parent_node = self._fetch_node(dirname, normalized=True) child_node = parent_node.find(basename, only_direct=True, include_self=False) if child_node is None: - child_node = tree.Node(basename, value=None) - parent_node.add(child_node) + child_node = self._insert_child(parent_node, basename) child_node.metadata['target'] = src_path def __getitem__(self, path): @@ -216,12 +233,12 @@ class FakeFilesystem(object): path = self.normpath(path) value = self._copier(value) try: - item_node = self._fetch_node(path) + item_node = self._fetch_node(path, normalized=True) item_node.metadata.update(value=value) except exc.NotFound: dirname, basename = pp.split(path) - parent_node = self._fetch_node(dirname) - parent_node.add(tree.Node(basename, value=value)) + parent_node = self._fetch_node(dirname, normalized=True) + self._insert_child(parent_node, basename, value=value) class MemoryBackend(path_based.PathBasedBackend):