Add state to NodeBase class

Making the global state reference a defined part of the node makes
some parts of the block device processing easier and removes the need
for other global values.

The state is passed to PluginNodeBase.__init__() and expected to be
passed into all nodes as they are created.  NodeBase.__init__() is
updated with the new paramater 'state'.

The parameter is removed from the create() call as nodes can simply
reference it at any point as "self.state".

This is similar to 1cdc8b20373c5d582ea928cfd7334469ff36dbce, except it
is based on I68840594a34af28d41d9522addcfd830bd203b97 which loads the
node-list from pickled state for later cmd_* calls.  Thus we only
build the state *once*, at cmd_create() time as we build the node
list.

Change-Id: I468dbf5134947629f125504513703d6f2cdace59
This commit is contained in:
Ian Wienand 2017-06-01 14:31:49 +10:00
parent e82e0097a9
commit 824a9e91c4
14 changed files with 117 additions and 99 deletions

View File

@ -44,18 +44,19 @@ def _load_json(file_name):
class BlockDeviceState(collections.MutableMapping): class BlockDeviceState(collections.MutableMapping):
"""The global state singleton """The global state singleton
An reference to an instance of this object is passed between nodes An reference to an instance of this object is saved into nodes as
as a global repository. It wraps a single dictionary "state" a global repository. It wraps a single dictionary "state" and
and provides a few helper functions. provides a few helper functions.
This is used in two contexts: The state ends up used in two contexts:
- The state is built by the :func:`NodeBase.create` commands as - The node list (including this state) is pickled and dumped
called during :func:`BlockDevice.cmd_create`. It is then between cmd_create() and later cmd_* calls that need to call
persisted to disk by :func:`save_state` the nodes.
- Later calls (cleanup, umount, etc) load the state dictionary - Some other cmd_* calls, such as cmd_writefstab, only need
from disk and are thus passed the full state. access to values inside the state and not the whole node list,
and load it from the json dump created after cmd_create()
""" """
# XXX: # XXX:
# - we could implement getters/setters such that if loaded from # - we could implement getters/setters such that if loaded from
@ -373,9 +374,9 @@ class BlockDevice(object):
# Create a new, empty state # Create a new, empty state
state = BlockDeviceState() state = BlockDeviceState()
try: try:
dg, call_order = create_graph(self.config, self.params) dg, call_order = create_graph(self.config, self.params, state)
for node in call_order: for node in call_order:
node.create(state, rollback) node.create(rollback)
except Exception: except Exception:
logger.exception("Create failed; rollback initiated") logger.exception("Create failed; rollback initiated")
for rollback_cb in reversed(rollback): for rollback_cb in reversed(rollback):

View File

@ -142,13 +142,15 @@ def config_tree_to_graph(config):
return output return output
def create_graph(config, default_config): def create_graph(config, default_config, state):
"""Generate configuration digraph """Generate configuration digraph
Generate the configuration digraph from the config Generate the configuration digraph from the config
:param config: graph configuration file :param config: graph configuration file
:param default_config: default parameters (from --params) :param default_config: default parameters (from --params)
:param state: reference to global state dictionary.
Passed to :func:`PluginBase.__init__`
:return: tuple with the graph object (a :class:`nx.Digraph`), :return: tuple with the graph object (a :class:`nx.Digraph`),
ordered list of :class:`NodeBase` objects ordered list of :class:`NodeBase` objects
@ -175,7 +177,7 @@ def create_graph(config, default_config):
("Config element [%s] is not implemented" % cfg_obj_name)) ("Config element [%s] is not implemented" % cfg_obj_name))
plugin = _extensions[cfg_obj_name].plugin plugin = _extensions[cfg_obj_name].plugin
assert issubclass(plugin, PluginBase) assert issubclass(plugin, PluginBase)
cfg_obj = plugin(cfg_obj_val, default_config) cfg_obj = plugin(cfg_obj_val, default_config, state)
# Ask the plugin for the nodes it would like to insert # Ask the plugin for the nodes it would like to insert
# into the graph. Some plugins, such as partitioning, # into the graph. Some plugins, such as partitioning,

View File

@ -80,10 +80,10 @@ class LocalLoopNode(NodeBase):
This class handles local loop devices that can be used This class handles local loop devices that can be used
for VM image installation. for VM image installation.
""" """
def __init__(self, config, default_config): def __init__(self, config, default_config, state):
logger.debug("Creating LocalLoop object; config [%s] " logger.debug("Creating LocalLoop object; config [%s] "
"default_config [%s]", config, default_config) "default_config [%s]", config, default_config)
super(LocalLoopNode, self).__init__(config['name']) super(LocalLoopNode, self).__init__(config['name'], state)
if 'size' in config: if 'size' in config:
self.size = parse_abs_size_spec(config['size']) self.size = parse_abs_size_spec(config['size'])
logger.debug("Image size [%s]", self.size) logger.debug("Image size [%s]", self.size)
@ -100,7 +100,7 @@ class LocalLoopNode(NodeBase):
"""Because this is created without base, there are no edges.""" """Because this is created without base, there are no edges."""
return ([], []) return ([], [])
def create(self, state, rollback): def create(self, rollback):
logger.debug("[%s] Creating loop on [%s] with size [%d]", logger.debug("[%s] Creating loop on [%s] with size [%d]",
self.name, self.filename, self.size) self.name, self.filename, self.size)
@ -110,11 +110,11 @@ class LocalLoopNode(NodeBase):
block_device = loopdev_attach(self.filename) block_device = loopdev_attach(self.filename)
rollback.append(lambda: loopdev_detach(block_device)) rollback.append(lambda: loopdev_detach(block_device))
if 'blockdev' not in state: if 'blockdev' not in self.state:
state['blockdev'] = {} self.state['blockdev'] = {}
state['blockdev'][self.name] = {"device": block_device, self.state['blockdev'][self.name] = {"device": block_device,
"image": self.filename} "image": self.filename}
logger.debug("Created loop name [%s] device [%s] image [%s]", logger.debug("Created loop name [%s] device [%s] image [%s]",
self.name, block_device, self.filename) self.name, block_device, self.filename)
return return
@ -131,9 +131,9 @@ class LocalLoopNode(NodeBase):
class LocalLoop(PluginBase): class LocalLoop(PluginBase):
def __init__(self, config, defaults): def __init__(self, config, defaults, state):
super(LocalLoop, self).__init__() super(LocalLoop, self).__init__()
self.node = LocalLoopNode(config, defaults) self.node = LocalLoopNode(config, defaults, state)
def get_nodes(self): def get_nodes(self):
return [self.node] return [self.node]

View File

@ -25,9 +25,9 @@ class PartitionNode(NodeBase):
flag_boot = 1 flag_boot = 1
flag_primary = 2 flag_primary = 2
def __init__(self, config, parent, prev_partition): def __init__(self, config, state, parent, prev_partition):
super(PartitionNode, self).__init__(config['name']) super(PartitionNode, self).__init__(config['name'], state)
self.base = config['base'] self.base = config['base']
self.partitioning = parent self.partitioning = parent
@ -65,5 +65,5 @@ class PartitionNode(NodeBase):
edge_from.append(self.prev_partition.name) edge_from.append(self.prev_partition.name)
return (edge_from, edge_to) return (edge_from, edge_to)
def create(self, state, rollback): def create(self, rollback):
self.partitioning.create(state, rollback) self.partitioning.create(rollback)

View File

@ -32,10 +32,15 @@ logger = logging.getLogger(__name__)
class Partitioning(PluginBase): class Partitioning(PluginBase):
def __init__(self, config, default_config): def __init__(self, config, default_config, state):
logger.debug("Creating Partitioning object; config [%s]", config) logger.debug("Creating Partitioning object; config [%s]", config)
super(Partitioning, self).__init__() super(Partitioning, self).__init__()
# Unlike other PluginBase we are somewhat persistent, as the
# partition nodes call back to us (see create() below). We
# need to keep this reference.
self.state = state
# Because using multiple partitions of one base is done # Because using multiple partitions of one base is done
# within one object, there is the need to store a flag if the # within one object, there is the need to store a flag if the
# creation of the partitions was already done. # creation of the partitions was already done.
@ -76,7 +81,7 @@ class Partitioning(PluginBase):
prev_partition = None prev_partition = None
for part_cfg in config['partitions']: for part_cfg in config['partitions']:
np = PartitionNode(part_cfg, self, prev_partition) np = PartitionNode(part_cfg, state, self, prev_partition)
self.partitions.append(np) self.partitions.append(np)
prev_partition = np prev_partition = np
@ -127,12 +132,12 @@ class Partitioning(PluginBase):
exec_sudo(["kpartx", "-avs", device_path]) exec_sudo(["kpartx", "-avs", device_path])
def create(self, state, rollback): def create(self, rollback):
# not this is NOT a node and this is not called directly! The # not this is NOT a node and this is not called directly! The
# create() calls in the partition nodes this plugin has # create() calls in the partition nodes this plugin has
# created are calling back into this. # created are calling back into this.
image_path = state['blockdev'][self.base]['image'] image_path = self.state['blockdev'][self.base]['image']
device_path = state['blockdev'][self.base]['device'] device_path = self.state['blockdev'][self.base]['device']
logger.info("Creating partition on [%s] [%s]", self.base, image_path) logger.info("Creating partition on [%s] [%s]", self.base, image_path)
# This is a bit of a hack. Each of the partitions is actually # This is a bit of a hack. Each of the partitions is actually
@ -166,7 +171,7 @@ class Partitioning(PluginBase):
logger.debug("Create partition [%s] [%d]", logger.debug("Create partition [%s] [%d]",
part_name, part_no) part_name, part_no)
partition_device_name = device_path + "p%d" % part_no partition_device_name = device_path + "p%d" % part_no
state['blockdev'][part_name] \ self.state['blockdev'][part_name] \
= {'device': partition_device_name} = {'device': partition_device_name}
partition_devices.add(partition_device_name) partition_devices.add(partition_device_name)

View File

@ -43,9 +43,9 @@ file_system_max_label_length = {
class FilesystemNode(NodeBase): class FilesystemNode(NodeBase):
def __init__(self, config): def __init__(self, config, state):
logger.debug("Create filesystem object; config [%s]", config) logger.debug("Create filesystem object; config [%s]", config)
super(FilesystemNode, self).__init__(config['name']) super(FilesystemNode, self).__init__(config['name'], state)
# Parameter check (mandatory) # Parameter check (mandatory)
for pname in ['base', 'type']: for pname in ['base', 'type']:
@ -102,7 +102,7 @@ class FilesystemNode(NodeBase):
edge_to = [] edge_to = []
return (edge_from, edge_to) return (edge_from, edge_to)
def create(self, state, rollback): def create(self, rollback):
cmd = ["mkfs"] cmd = ["mkfs"]
cmd.extend(['-t', self.type]) cmd.extend(['-t', self.type])
@ -121,17 +121,17 @@ class FilesystemNode(NodeBase):
if self.type in ('ext2', 'ext3', 'ext4', 'xfs'): if self.type in ('ext2', 'ext3', 'ext4', 'xfs'):
cmd.append('-q') cmd.append('-q')
if 'blockdev' not in state: if 'blockdev' not in self.state:
state['blockdev'] = {} self.state['blockdev'] = {}
device = state['blockdev'][self.base]['device'] device = self.state['blockdev'][self.base]['device']
cmd.append(device) cmd.append(device)
logger.debug("Creating fs command [%s]", cmd) logger.debug("Creating fs command [%s]", cmd)
exec_sudo(cmd) exec_sudo(cmd)
if 'filesys' not in state: if 'filesys' not in self.state:
state['filesys'] = {} self.state['filesys'] = {}
state['filesys'][self.name] \ self.state['filesys'][self.name] \
= {'uuid': self.uuid, 'label': self.label, = {'uuid': self.uuid, 'label': self.label,
'fstype': self.type, 'opts': self.opts, 'fstype': self.type, 'opts': self.opts,
'device': device} 'device': device}
@ -144,10 +144,10 @@ class Mkfs(PluginBase):
systems. systems.
""" """
def __init__(self, config, defaults): def __init__(self, config, defaults, state):
super(Mkfs, self).__init__() super(Mkfs, self).__init__()
self.filesystems = {} self.filesystems = {}
fs = FilesystemNode(config) fs = FilesystemNode(config, state)
self.filesystems[fs.get_name()] = fs self.filesystems[fs.get_name()] = fs
def get_nodes(self): def get_nodes(self):

View File

@ -31,8 +31,8 @@ sorted_mount_points = []
class MountPointNode(NodeBase): class MountPointNode(NodeBase):
def __init__(self, mount_base, config): def __init__(self, mount_base, config, state):
super(MountPointNode, self).__init__(config['name']) super(MountPointNode, self).__init__(config['name'], state)
# Parameter check # Parameter check
self.mount_base = mount_base self.mount_base = mount_base
@ -72,7 +72,7 @@ class MountPointNode(NodeBase):
edge_from.append(self.base) edge_from.append(self.base)
return (edge_from, edge_to) return (edge_from, edge_to)
def create(self, state, rollback): def create(self, rollback):
logger.debug("mount called [%s]", self.mount_point) logger.debug("mount called [%s]", self.mount_point)
rel_mp = self.mount_point if self.mount_point[0] != '/' \ rel_mp = self.mount_point if self.mount_point[0] != '/' \
else self.mount_point[1:] else self.mount_point[1:]
@ -82,17 +82,17 @@ class MountPointNode(NodeBase):
# file system tree. # file system tree.
exec_sudo(['mkdir', '-p', mount_point]) exec_sudo(['mkdir', '-p', mount_point])
logger.info("Mounting [%s] to [%s]", self.name, mount_point) logger.info("Mounting [%s] to [%s]", self.name, mount_point)
exec_sudo(["mount", state['filesys'][self.base]['device'], exec_sudo(["mount", self.state['filesys'][self.base]['device'],
mount_point]) mount_point])
if 'mount' not in state: if 'mount' not in self.state:
state['mount'] = {} self.state['mount'] = {}
state['mount'][self.mount_point] \ self.state['mount'][self.mount_point] \
= {'name': self.name, 'base': self.base, 'path': mount_point} = {'name': self.name, 'base': self.base, 'path': mount_point}
if 'mount_order' not in state: if 'mount_order' not in self.state:
state['mount_order'] = [] self.state['mount_order'] = []
state['mount_order'].append(self.mount_point) self.state['mount_order'].append(self.mount_point)
def umount(self, state): def umount(self, state):
logger.info("Called for [%s]", self.name) logger.info("Called for [%s]", self.name)
@ -103,13 +103,13 @@ class MountPointNode(NodeBase):
class Mount(PluginBase): class Mount(PluginBase):
def __init__(self, config, defaults): def __init__(self, config, defaults, state):
super(Mount, self).__init__() super(Mount, self).__init__()
if 'mount-base' not in defaults: if 'mount-base' not in defaults:
raise BlockDeviceSetupException( raise BlockDeviceSetupException(
"Mount default config needs 'mount-base'") "Mount default config needs 'mount-base'")
self.node = MountPointNode(defaults['mount-base'], config) self.node = MountPointNode(defaults['mount-base'], config, state)
# save this new node to the global mount-point list and # save this new node to the global mount-point list and
# re-order it. # re-order it.

View File

@ -22,8 +22,8 @@ logger = logging.getLogger(__name__)
class FstabNode(NodeBase): class FstabNode(NodeBase):
def __init__(self, config, params): def __init__(self, config, state):
super(FstabNode, self).__init__(config['name']) super(FstabNode, self).__init__(config['name'], state)
self.base = config['base'] self.base = config['base']
self.options = config.get('options', 'defaults') self.options = config.get('options', 'defaults')
self.dump_freq = config.get('dump-freq', 0) self.dump_freq = config.get('dump-freq', 0)
@ -34,13 +34,13 @@ class FstabNode(NodeBase):
edge_to = [] edge_to = []
return (edge_from, edge_to) return (edge_from, edge_to)
def create(self, state, rollback): def create(self, rollback):
logger.debug("fstab create called [%s]", self.name) logger.debug("fstab create called [%s]", self.name)
if 'fstab' not in state: if 'fstab' not in self.state:
state['fstab'] = {} self.state['fstab'] = {}
state['fstab'][self.base] = { self.state['fstab'][self.base] = {
'name': self.name, 'name': self.name,
'base': self.base, 'base': self.base,
'options': self.options, 'options': self.options,
@ -50,10 +50,10 @@ class FstabNode(NodeBase):
class Fstab(PluginBase): class Fstab(PluginBase):
def __init__(self, config, defaults): def __init__(self, config, defaults, state):
super(Fstab, self).__init__() super(Fstab, self).__init__()
self.node = FstabNode(config, defaults) self.node = FstabNode(config, state)
def get_nodes(self): def get_nodes(self):
return [self.node] return [self.node]

View File

@ -32,7 +32,7 @@ class NodeBase(object):
Every node has a unique string ``name``. This is its key in the Every node has a unique string ``name``. This is its key in the
graph and used for edge relationships. Implementations must graph and used for edge relationships. Implementations must
ensure they initialize it; e.g. ensure they initalize it; e.g.
.. code-block:: python .. code-block:: python
@ -41,8 +41,9 @@ class NodeBase(object):
super(FooNode, self).__init__(name) super(FooNode, self).__init__(name)
""" """
def __init__(self, name): def __init__(self, name, state):
self.name = name self.name = name
self.state = state
def get_name(self): def get_name(self):
return self.name return self.name
@ -74,7 +75,7 @@ class NodeBase(object):
return return
@abc.abstractmethod @abc.abstractmethod
def create(self, state, rollback): def create(self, rollback):
"""Main creation driver """Main creation driver
This is the main driver function. After the graph is This is the main driver function. After the graph is
@ -82,12 +83,6 @@ class NodeBase(object):
Arguments: Arguments:
:param state: A shared dictionary of prior results. This
dictionary is passed by reference to each call, meaning any
entries inserted will be available to subsequent :func:`create`
calls of following nodes. The ``state`` dictionary will be
saved and available to other calls.
:param rollback: A shared list of functions to be called in :param rollback: A shared list of functions to be called in
the failure case. Nodes should only append to this list. the failure case. Nodes should only append to this list.
On failure, the callbacks will be processed in reverse On failure, the callbacks will be processed in reverse
@ -164,13 +159,16 @@ class PluginBase(object):
argument_a: bar argument_a: bar
argument_b: baz argument_b: baz
The ``__init__`` function will be passed two arguments: The ``__init__`` function will be passed three arguments:
``config`` ``config``
The full configuration dictionary for the entry. The full configuration dictionary for the entry.
A unique ``name`` entry can be assumed. In most cases A unique ``name`` entry can be assumed. In most cases
a ``base`` entry will be present giving the parent node a ``base`` entry will be present giving the parent node
(see :func:`NodeBase.get_edges`). (see :func:`NodeBase.get_edges`).
``state``
A reference to the gobal state dictionary. This should be
passed to :func:`NodeBase.__init__` on node creation
``defaults`` ``defaults``
The global defaults dictionary (see ``--params``) The global defaults dictionary (see ``--params``)
@ -183,9 +181,9 @@ class PluginBase(object):
class Foo(PluginBase): class Foo(PluginBase):
def __init__(self, config, defaults): def __init__(self, config, defaults, state):
super(Foo, self).__init__() super(Foo, self).__init__()
self.node = FooNode(config.name, ...) self.node = FooNode(config.name, state, ...)
def get_nodes(self): def get_nodes(self):
return [self.node] return [self.node]

View File

@ -22,20 +22,22 @@ logger = logging.getLogger(__name__)
class TestANode(NodeBase): class TestANode(NodeBase):
def __init__(self, name): def __init__(self, name, state):
logger.debug("Create test 1") logger.debug("Create test 1")
super(TestANode, self).__init__(name) super(TestANode, self).__init__(name, state)
# put something in the state for test_b to check for
state['test_init_state'] = 'here'
def get_edges(self): def get_edges(self):
# this is like the loop node; it's a root and doesn't have a # this is like the loop node; it's a root and doesn't have a
# base # base
return ([], []) return ([], [])
def create(self, state, rollback): def create(self, rollback):
# put some fake entries into state # put some fake entries into state
state['test_a'] = {} self.state['test_a'] = {}
state['test_a']['value'] = 'foo' self.state['test_a']['value'] = 'foo'
state['test_a']['value2'] = 'bar' self.state['test_a']['value2'] = 'bar'
return return
def umount(self, state): def umount(self, state):
@ -45,9 +47,9 @@ class TestANode(NodeBase):
class TestA(PluginBase): class TestA(PluginBase):
def __init__(self, config, defaults): def __init__(self, config, defaults, state):
super(TestA, self).__init__() super(TestA, self).__init__()
self.node = TestANode(config['name']) self.node = TestANode(config['name'], state)
def get_nodes(self): def get_nodes(self):
return [self.node] return [self.node]

View File

@ -22,17 +22,20 @@ logger = logging.getLogger(__name__)
class TestBNode(NodeBase): class TestBNode(NodeBase):
def __init__(self, name, base): def __init__(self, name, state, base):
logger.debug("Create test 1") logger.debug("Create test 1")
super(TestBNode, self).__init__(name) super(TestBNode, self).__init__(name, state)
self.base = base self.base = base
def get_edges(self): def get_edges(self):
# this should have been inserted by test_a before
# we are called
assert self.state['test_init_state'] == 'here'
return ([self.base], []) return ([self.base], [])
def create(self, state, rollback): def create(self, rollback):
state['test_b'] = {} self.state['test_b'] = {}
state['test_b']['value'] = 'baz' self.state['test_b']['value'] = 'baz'
return return
def umount(self, state): def umount(self, state):
@ -44,9 +47,10 @@ class TestBNode(NodeBase):
class TestB(PluginBase): class TestB(PluginBase):
def __init__(self, config, defaults): def __init__(self, config, defaults, state):
super(TestB, self).__init__() super(TestB, self).__init__()
self.node = TestBNode(config['name'], self.node = TestBNode(config['name'],
state,
config['base']) config['base'])
def get_nodes(self): def get_nodes(self):

View File

@ -104,7 +104,7 @@ class TestCreateGraph(TestGraphGeneration):
self.assertRaisesRegex(BlockDeviceSetupException, self.assertRaisesRegex(BlockDeviceSetupException,
"Edge not defined: this_is_not_a_node", "Edge not defined: this_is_not_a_node",
create_graph, create_graph,
config, self.fake_default_config) config, self.fake_default_config, {})
# Test a graph with bad edge pointing to an invalid node # Test a graph with bad edge pointing to an invalid node
def test_duplicate_name(self): def test_duplicate_name(self):
@ -113,13 +113,13 @@ class TestCreateGraph(TestGraphGeneration):
"Duplicate node name: " "Duplicate node name: "
"this_is_a_duplicate", "this_is_a_duplicate",
create_graph, create_graph,
config, self.fake_default_config) config, self.fake_default_config, {})
# Test digraph generation from deep_graph config file # Test digraph generation from deep_graph config file
def test_deep_graph_generator(self): def test_deep_graph_generator(self):
config = self.load_config_file('deep_graph.yaml') config = self.load_config_file('deep_graph.yaml')
graph, call_order = create_graph(config, self.fake_default_config) graph, call_order = create_graph(config, self.fake_default_config, {})
call_order_list = [n.name for n in call_order] call_order_list = [n.name for n in call_order]
@ -136,7 +136,7 @@ class TestCreateGraph(TestGraphGeneration):
def test_multiple_partitions_graph_generator(self): def test_multiple_partitions_graph_generator(self):
config = self.load_config_file('multiple_partitions_graph.yaml') config = self.load_config_file('multiple_partitions_graph.yaml')
graph, call_order = create_graph(config, self.fake_default_config) graph, call_order = create_graph(config, self.fake_default_config, {})
call_order_list = [n.name for n in call_order] call_order_list = [n.name for n in call_order]
# The sort creating call_order_list is unstable. # The sort creating call_order_list is unstable.

View File

@ -28,9 +28,16 @@ class TestMountOrder(tc.TestGraphGeneration):
config = self.load_config_file('multiple_partitions_graph.yaml') config = self.load_config_file('multiple_partitions_graph.yaml')
graph, call_order = create_graph(config, self.fake_default_config)
state = {} state = {}
graph, call_order = create_graph(config, self.fake_default_config,
state)
rollback = []
# build up some fake state so that we don't have to mock out
# all the parent calls that would really make these values, as
# we just want to test MountPointNode
state['filesys'] = {} state['filesys'] = {}
state['filesys']['mkfs_root'] = {} state['filesys']['mkfs_root'] = {}
state['filesys']['mkfs_root']['device'] = 'fake' state['filesys']['mkfs_root']['device'] = 'fake'
@ -39,14 +46,12 @@ class TestMountOrder(tc.TestGraphGeneration):
state['filesys']['mkfs_var_log'] = {} state['filesys']['mkfs_var_log'] = {}
state['filesys']['mkfs_var_log']['device'] = 'fake' state['filesys']['mkfs_var_log']['device'] = 'fake'
rollback = []
for node in call_order: for node in call_order:
if isinstance(node, MountPointNode): if isinstance(node, MountPointNode):
# XXX: do we even need to create? We could test the # XXX: do we even need to create? We could test the
# sudo arguments from the mock in the below asserts # sudo arguments from the mock in the below asserts
# too # too
node.create(state, rollback) node.create(rollback)
# ensure that partitions are mounted in order root->var->var/log # ensure that partitions are mounted in order root->var->var/log
self.assertListEqual(state['mount_order'], ['/', '/var', '/var/log']) self.assertListEqual(state['mount_order'], ['/', '/var', '/var/log'])

View File

@ -72,7 +72,8 @@ class TestState(TestStateBase):
self.assertDictEqual(state, self.assertDictEqual(state,
{'test_a': {'value': 'foo', {'test_a': {'value': 'foo',
'value2': 'bar'}, 'value2': 'bar'},
'test_b': {'value': 'baz'}}) 'test_b': {'value': 'baz'},
'test_init_state': 'here'})
pickle_file = bd_obj.node_pickle_file_name pickle_file = bd_obj.node_pickle_file_name
self.assertThat(pickle_file, FileExists()) self.assertThat(pickle_file, FileExists())