Ensure state space can also pass on_enter/exit callbacks
Change-Id: If455f9799b9a3f1a5489d50f8cac8c75143bbb58
This commit is contained in:
parent
ad4b42c963
commit
818b7998b4
@ -33,12 +33,18 @@ class State(object):
|
||||
:ivar name: The name of the state.
|
||||
:ivar is_terminal: Whether this state is terminal (or not).
|
||||
:ivar next_states: Dictionary of 'event' -> 'next state name' (or none).
|
||||
:ivar on_enter: callback that will be called when the state is entered.
|
||||
:ivar on_exit: callback that will be called when the state is exited.
|
||||
"""
|
||||
|
||||
def __init__(self, name, is_terminal=False, next_states=None):
|
||||
def __init__(self, name,
|
||||
is_terminal=False, next_states=None,
|
||||
on_enter=None, on_exit=None):
|
||||
self.name = name
|
||||
self.is_terminal = bool(is_terminal)
|
||||
self.next_states = next_states
|
||||
self.on_enter = on_enter
|
||||
self.on_exit = on_exit
|
||||
|
||||
|
||||
def _convert_to_states(state_space):
|
||||
@ -141,7 +147,10 @@ class FiniteMachine(object):
|
||||
state_space = list(_convert_to_states(state_space))
|
||||
m = cls()
|
||||
for state in state_space:
|
||||
m.add_state(state.name, terminal=state.is_terminal)
|
||||
m.add_state(state.name,
|
||||
terminal=state.is_terminal,
|
||||
on_enter=state.on_enter,
|
||||
on_exit=state.on_exit)
|
||||
for state in state_space:
|
||||
if state.next_states:
|
||||
for event, next_state in six.iteritems(state.next_states):
|
||||
|
@ -14,6 +14,7 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import random
|
||||
|
||||
@ -69,6 +70,40 @@ class FSMTest(testcase.TestCase):
|
||||
expected = [('down', 'jump', 'up'), ('up', 'fall', 'down')]
|
||||
self.assertEqual(expected, list(m))
|
||||
|
||||
def test_build_transitions_with_callbacks(self):
|
||||
entered = collections.defaultdict(list)
|
||||
exitted = collections.defaultdict(list)
|
||||
|
||||
def on_enter(state, event):
|
||||
entered[state].append(event)
|
||||
|
||||
def on_exit(state, event):
|
||||
exitted[state].append(event)
|
||||
|
||||
space = [
|
||||
machines.State('down', is_terminal=False,
|
||||
next_states={'jump': 'up'},
|
||||
on_enter=on_enter, on_exit=on_exit),
|
||||
machines.State('up', is_terminal=False,
|
||||
next_states={'fall': 'down'},
|
||||
on_enter=on_enter, on_exit=on_exit),
|
||||
]
|
||||
m = machines.FiniteMachine.build(space)
|
||||
m.default_start_state = 'down'
|
||||
expected = [('down', 'jump', 'up'), ('up', 'fall', 'down')]
|
||||
self.assertEqual(expected, list(m))
|
||||
|
||||
m.initialize()
|
||||
m.process_event('jump')
|
||||
|
||||
self.assertEqual({'down': ['jump']}, dict(exitted))
|
||||
self.assertEqual({'up': ['jump']}, dict(entered))
|
||||
|
||||
m.process_event('fall')
|
||||
|
||||
self.assertEqual({'down': ['jump'], 'up': ['fall']}, dict(exitted))
|
||||
self.assertEqual({'up': ['jump'], 'down': ['fall']}, dict(entered))
|
||||
|
||||
def test_build_transitions_dct(self):
|
||||
space = [
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user