Ensure state space can also pass on_enter/exit callbacks

Change-Id: If455f9799b9a3f1a5489d50f8cac8c75143bbb58
This commit is contained in:
Joshua Harlow 2016-01-26 18:08:20 -08:00
parent ad4b42c963
commit 818b7998b4
2 changed files with 46 additions and 2 deletions

View File

@ -33,12 +33,18 @@ class State(object):
:ivar name: The name of the state. :ivar name: The name of the state.
:ivar is_terminal: Whether this state is terminal (or not). :ivar is_terminal: Whether this state is terminal (or not).
:ivar next_states: Dictionary of 'event' -> 'next state name' (or none). :ivar next_states: Dictionary of 'event' -> 'next state name' (or none).
:ivar on_enter: callback that will be called when the state is entered.
:ivar on_exit: callback that will be called when the state is exited.
""" """
def __init__(self, name, is_terminal=False, next_states=None): def __init__(self, name,
is_terminal=False, next_states=None,
on_enter=None, on_exit=None):
self.name = name self.name = name
self.is_terminal = bool(is_terminal) self.is_terminal = bool(is_terminal)
self.next_states = next_states self.next_states = next_states
self.on_enter = on_enter
self.on_exit = on_exit
def _convert_to_states(state_space): def _convert_to_states(state_space):
@ -141,7 +147,10 @@ class FiniteMachine(object):
state_space = list(_convert_to_states(state_space)) state_space = list(_convert_to_states(state_space))
m = cls() m = cls()
for state in state_space: for state in state_space:
m.add_state(state.name, terminal=state.is_terminal) m.add_state(state.name,
terminal=state.is_terminal,
on_enter=state.on_enter,
on_exit=state.on_exit)
for state in state_space: for state in state_space:
if state.next_states: if state.next_states:
for event, next_state in six.iteritems(state.next_states): for event, next_state in six.iteritems(state.next_states):

View File

@ -14,6 +14,7 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import collections
import functools import functools
import random import random
@ -69,6 +70,40 @@ class FSMTest(testcase.TestCase):
expected = [('down', 'jump', 'up'), ('up', 'fall', 'down')] expected = [('down', 'jump', 'up'), ('up', 'fall', 'down')]
self.assertEqual(expected, list(m)) self.assertEqual(expected, list(m))
def test_build_transitions_with_callbacks(self):
entered = collections.defaultdict(list)
exitted = collections.defaultdict(list)
def on_enter(state, event):
entered[state].append(event)
def on_exit(state, event):
exitted[state].append(event)
space = [
machines.State('down', is_terminal=False,
next_states={'jump': 'up'},
on_enter=on_enter, on_exit=on_exit),
machines.State('up', is_terminal=False,
next_states={'fall': 'down'},
on_enter=on_enter, on_exit=on_exit),
]
m = machines.FiniteMachine.build(space)
m.default_start_state = 'down'
expected = [('down', 'jump', 'up'), ('up', 'fall', 'down')]
self.assertEqual(expected, list(m))
m.initialize()
m.process_event('jump')
self.assertEqual({'down': ['jump']}, dict(exitted))
self.assertEqual({'up': ['jump']}, dict(entered))
m.process_event('fall')
self.assertEqual({'down': ['jump'], 'up': ['fall']}, dict(exitted))
self.assertEqual({'up': ['jump'], 'down': ['fall']}, dict(entered))
def test_build_transitions_dct(self): def test_build_transitions_dct(self):
space = [ space = [
{ {