Add typing

Signed-off-by: Stephen Finucane <stephenfin@redhat.com>
Change-Id: Ib34fab9de71cdadb0e4a8a211f86b9e6dc362007
This commit is contained in:
Stephen Finucane
2025-12-22 20:01:51 +00:00
parent d4417df0b9
commit b73bb3b67c
11 changed files with 231 additions and 108 deletions

View File

@@ -2,15 +2,22 @@
Automaton
=========
.. image:: https://img.shields.io/pypi/v/automaton.svg
.. image:: https://governance.openstack.org/tc/badges/automaton
:target: https://governance.openstack.org/tc/reference/projects/oslo.html
.. image:: https://img.shields.io/pypi/v/automaton
:target: https://pypi.org/project/automaton/
:alt: Latest Version
.. image:: https://img.shields.io/pypi/dm/automaton.svg
.. image:: https://img.shields.io/pypi/dm/automaton
:target: https://pypi.org/project/automaton/
:alt: Downloads
Friendly state machines for python. The goal of this library is to provide
.. image:: https://img.shields.io/pypi/types/automaton
:target: https://pypi.org/project/automaton/
:alt: Typing Status
Friendly state machines for Python. The goal of this library is to provide
well documented state machine classes and associated utilities. The state
machine pattern (or the implemented variation there-of) is a commonly
used pattern and has a multitude of various usages. Some of the usages

View File

@@ -12,8 +12,10 @@
# License for the specific language governing permissions and limitations
# under the License.
from typing import Any
def get_callback_name(cb):
def get_callback_name(cb: Any) -> str:
"""Tries to get a callbacks fully-qualified name."""
segments = [cb.__qualname__]

View File

@@ -12,6 +12,11 @@
# License for the specific language governing permissions and limitations
# under the License.
from collections.abc import Callable, Mapping
from typing import Any
from automaton import machines
try:
import pydot
@@ -19,37 +24,35 @@ try:
except ImportError:
PYDOT_AVAILABLE = False
NodeAttrsCallbackT = Callable[[str], Mapping[str, Any]]
EdgeAttrsCallbackT = Callable[[str, str, str], Mapping[str, Any]]
def convert(
machine,
graph_name,
graph_attrs=None,
node_attrs_cb=None,
edge_attrs_cb=None,
add_start_state=True,
name_translations=None,
):
machine: machines.FiniteMachine,
graph_name: str,
graph_attrs: Mapping[str, Any] | None = None,
node_attrs_cb: NodeAttrsCallbackT | None = None,
edge_attrs_cb: EdgeAttrsCallbackT | None = None,
add_start_state: bool = True,
name_translations: Mapping[str, str] | None = None,
) -> Any:
"""Translates the state machine into a pydot graph.
:param machine: state machine to convert
:type machine: FiniteMachine
:param graph_name: name of the graph to be created
:type graph_name: string
:param graph_attrs: any initial graph attributes to set
(see http://www.graphviz.org/doc/info/attrs.html for
what these can be)
:type graph_attrs: dict
:param node_attrs_cb: a callback that takes one argument ``state``
and is expected to return a dict of node attributes
(see http://www.graphviz.org/doc/info/attrs.html for
what these can be)
:type node_attrs_cb: callback
:param edge_attrs_cb: a callback that takes three arguments ``start_state,
event, end_state`` and is expected to return a dict
of edge attributes (see
http://www.graphviz.org/doc/info/attrs.html for
what these can be)
:type edge_attrs_cb: callback
:param add_start_state: when enabled this creates a *private* start state
with the name ``__start__`` that will be a point
node that will have a dotted edge to the
@@ -57,10 +60,8 @@ def convert(
defined (if your machine has no actively defined
``default_start_state`` then this does nothing,
even if enabled)
:type add_start_state: bool
:param name_translations: a dict that provides alternative ``state``
string names for each state
:type name_translations: dict
"""
if not PYDOT_AVAILABLE:
raise RuntimeError(
@@ -83,7 +84,7 @@ def convert(
graph_kwargs.update(graph_attrs)
graph_kwargs['graph_name'] = graph_name
g = pydot.Dot(**graph_kwargs)
node_attrs = {
node_attrs: dict[str, Any] = {
'fontsize': '11',
}
nodes = {}
@@ -106,7 +107,7 @@ def convert(
pretty_end_state = name_translations.get(end_state, end_state)
nodes[end_state] = pydot.Node(pretty_end_state, **end_node_attrs)
g.add_node(nodes[end_state])
edge_attrs = {}
edge_attrs: dict[str, Any] = {}
if edge_attrs_cb is not None:
edge_attrs.update(edge_attrs_cb(start_state, event, end_state))
g.add_edge(

View File

@@ -36,5 +36,5 @@ class Duplicate(AutomatonException):
class FrozenMachine(AutomatonException):
"""Exception raised when a frozen machine is modified."""
def __init__(self):
def __init__(self) -> None:
super().__init__("Frozen machine can't be modified")

View File

@@ -13,12 +13,18 @@
# under the License.
import collections
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, TypedDict
import prettytable
from typing_extensions import NotRequired, Self
from automaton import _utils as utils
from automaton import exceptions as excp
OnEnterCallbackT = Callable[[str, str], None] | None
OnExitCallbackT = Callable[[str, str], None] | None
class State:
"""Container that defines needed components of a single state.
@@ -35,12 +41,12 @@ class State:
def __init__(
self,
name,
is_terminal=False,
next_states=None,
on_enter=None,
on_exit=None,
):
name: str,
is_terminal: bool = False,
next_states: Mapping[str, str] | None = None,
on_enter: OnEnterCallbackT | None = None,
on_exit: OnExitCallbackT | None = None,
) -> None:
self.name = name
self.is_terminal = bool(is_terminal)
self.next_states = next_states
@@ -48,15 +54,26 @@ class State:
self.on_exit = on_exit
def _convert_to_states(state_space):
class StateDict(TypedDict, total=False):
name: str
is_terminal: bool
next_states: dict[str, str] | None
on_enter: OnEnterCallbackT | None
on_exit: OnExitCallbackT | None
def _convert_to_states(
state_space: Sequence[State | StateDict],
) -> Generator[State, None, None]:
# NOTE(harlowja): if provided dicts, convert them...
for state in state_space:
if isinstance(state, dict):
state = State(**state)
yield state
yield State(**state)
else:
yield state
def _orderedkeys(data, sort=True):
def _orderedkeys(data: Mapping[str, Any], sort: bool = True) -> list[str]:
if sort:
return sorted(data)
else:
@@ -66,12 +83,24 @@ def _orderedkeys(data, sort=True):
class _Jump:
"""A FSM transition tracks this data while jumping."""
def __init__(self, name, on_enter, on_exit):
def __init__(
self, name: str, on_enter: OnEnterCallbackT, on_exit: OnExitCallbackT
) -> None:
self.name = name
self.on_enter = on_enter
self.on_exit = on_exit
class _TrackedState(TypedDict):
terminal: bool
reactions: dict[
str, tuple[Callable[..., Any], tuple[Any, ...], dict[str, Any]]
]
on_enter: OnEnterCallbackT
on_exit: OnExitCallbackT
machine: NotRequired['FiniteMachine']
class FiniteMachine:
"""A finite state machine.
@@ -102,20 +131,22 @@ class FiniteMachine:
Effect = collections.namedtuple('Effect', 'reaction,terminal')
@classmethod
def _effect_builder(cls, new_state, event):
def _effect_builder(
cls, new_state: Mapping[str, Any], event: str
) -> Effect:
return cls.Effect(
new_state['reactions'].get(event), new_state["terminal"]
)
def __init__(self):
self._transitions = {}
self._states = {}
self._default_start_state = None
self._current = None
def __init__(self) -> None:
self._transitions: dict[str, dict[str, _Jump]] = {}
self._states: dict[str, _TrackedState] = {}
self._default_start_state: str | None = None
self._current: _Jump | None = None
self.frozen = False
@property
def default_start_state(self):
def default_start_state(self) -> str | None:
"""Sets the *default* start state that the machine should use.
NOTE(harlowja): this will be used by ``initialize`` but only if that
@@ -125,34 +156,38 @@ class FiniteMachine:
return self._default_start_state
@default_start_state.setter
def default_start_state(self, state):
def default_start_state(self, state: str) -> None:
if self.frozen:
raise excp.FrozenMachine()
if state not in self._states:
raise excp.NotFound(
f"Can not set the default start state to undefined state "
f"'{state}'"
)
self._default_start_state = state
@classmethod
def build(cls, state_space):
def build(
cls, state_space: Sequence[State | StateDict]
) -> 'FiniteMachine':
"""Builds a machine from a state space listing.
Each element of this list must be an instance
of :py:class:`.State` or a ``dict`` with equivalent keys that
can be used to construct a :py:class:`.State` instance.
"""
state_space = list(_convert_to_states(state_space))
normalized_states = list(_convert_to_states(state_space))
m = cls()
for state in state_space:
for state in normalized_states:
m.add_state(
state.name,
terminal=state.is_terminal,
on_enter=state.on_enter,
on_exit=state.on_exit,
)
for state in state_space:
for state in normalized_states:
if state.next_states:
for event, next_state in state.next_states.items():
if isinstance(next_state, State):
@@ -161,20 +196,26 @@ class FiniteMachine:
return m
@property
def current_state(self):
def current_state(self) -> str | None:
"""The current state the machine is in (or none if not initialized)."""
if self._current is not None:
return self._current.name
return None
@property
def terminated(self):
def terminated(self) -> bool:
"""Returns whether the state machine is in a terminal state."""
if self._current is None:
return False
return self._states[self._current.name]['terminal']
return bool(self._states[self._current.name]['terminal'])
def add_state(self, state, terminal=False, on_enter=None, on_exit=None):
def add_state(
self,
state: str,
terminal: bool = False,
on_enter: OnEnterCallbackT = None,
on_exit: OnExitCallbackT = None,
) -> None:
"""Adds a given state to the state machine.
The ``on_enter`` and ``on_exit`` callbacks, if provided will be
@@ -201,7 +242,7 @@ class FiniteMachine:
}
self._transitions[state] = {}
def is_actionable_event(self, event):
def is_actionable_event(self, event: str) -> bool:
"""Check whether the event is actionable in the current state."""
current = self._current
if current is None:
@@ -210,7 +251,14 @@ class FiniteMachine:
return False
return True
def add_reaction(self, state, event, reaction, *args, **kwargs):
def add_reaction(
self,
state: str,
event: str,
reaction: Callable[..., Any],
*args: Any,
**kwargs: Any,
) -> None:
"""Adds a reaction that may get triggered by the given event & state.
Reaction callbacks may (depending on how the state machine is ran) be
@@ -232,21 +280,26 @@ class FiniteMachine:
"""
if self.frozen:
raise excp.FrozenMachine()
if state not in self._states:
raise excp.NotFound(
f"Can not add a reaction to event '{event}' for an "
f"undefined state '{state}'"
)
if not callable(reaction):
raise ValueError("Reaction callback must be callable")
if event not in self._states[state]['reactions']:
self._states[state]['reactions'][event] = (reaction, args, kwargs)
else:
if event in self._states[state]['reactions']:
raise excp.Duplicate(
f"State '{state}' reaction to event '{event}' already defined"
)
def add_transition(self, start, end, event, replace=False):
self._states[state]['reactions'][event] = (reaction, args, kwargs)
def add_transition(
self, start: str, end: str, event: str, replace: bool = False
) -> None:
"""Adds an allowed transition from start -> end for the given event.
:param start: starting state
@@ -290,7 +343,7 @@ class FiniteMachine:
)
self._transitions[start][event] = target
def _pre_process_event(self, event):
def _pre_process_event(self, event: str) -> None:
current = self._current
if current is None:
raise excp.NotInitialized(
@@ -308,10 +361,10 @@ class FiniteMachine:
f"'{event}' (no defined transition)"
)
def _post_process_event(self, event, result):
def _post_process_event(self, event: str, result: Effect) -> Effect:
return result
def process_event(self, event):
def process_event(self, event: str) -> Effect:
"""Trigger a state change in response to the provided event.
:returns: Effect this is either a :py:class:`.FiniteMachine.Effect` or
@@ -325,6 +378,8 @@ class FiniteMachine:
"""
self._pre_process_event(event)
current = self._current
# narrow type (_pre_process_event ensures this)
assert current is not None # noqa: S101
replacement = self._transitions[current.name][event]
if current.on_exit is not None:
current.on_exit(current.name, event)
@@ -334,7 +389,7 @@ class FiniteMachine:
result = self._effect_builder(self._states[replacement.name], event)
return self._post_process_event(event, result)
def initialize(self, start_state=None):
def initialize(self, start_state: str | None = None) -> None:
"""Sets up the state machine (sets current state to start state...).
:param start_state: explicit start state to use to initialize the
@@ -360,7 +415,7 @@ class FiniteMachine:
start_state, None, self._states[start_state]['on_exit']
)
def copy(self, shallow=False, unfreeze=False):
def copy(self, shallow: bool = False, unfreeze: bool = False) -> Self:
"""Copies the current state machine.
NOTE(harlowja): the copy will be left in an *uninitialized* state.
@@ -379,45 +434,45 @@ class FiniteMachine:
else:
c.frozen = self.frozen
if not shallow:
for state, data in self._states.items():
copied_data = data.copy()
copied_data['reactions'] = copied_data['reactions'].copy()
c._states[state] = copied_data
for state, data in self._transitions.items():
c._transitions[state] = data.copy()
for state_name, state in self._states.items():
copied_state = state.copy()
copied_state['reactions'] = copied_state['reactions'].copy()
c._states[state_name] = copied_state
for state_name, transition in self._transitions.items():
c._transitions[state_name] = transition.copy()
else:
c._transitions = self._transitions
c._states = self._states
return c
def __contains__(self, state):
def __contains__(self, state: str) -> bool:
"""Returns if this state exists in the machines known states."""
return state in self._states
def freeze(self):
def freeze(self) -> None:
"""Freezes & stops addition of states, transitions, reactions..."""
self.frozen = True
@property
def states(self):
def states(self) -> list[str]:
"""Returns the state names."""
return list(self._states)
@property
def events(self):
def events(self) -> int:
"""Returns how many events exist."""
c = 0
for state in self._states:
c += len(self._transitions[state])
return c
def __iter__(self):
def __iter__(self) -> Generator[tuple[str, str, str], None, None]:
"""Iterates over (start, event, end) transition tuples."""
for state in self._states:
for event, target in self._transitions[state].items():
yield (state, event, target.name)
def pformat(self, sort=True, empty='.'):
def pformat(self, sort: bool = True, empty: str = '.') -> str:
"""Pretty formats the state + transition table into a string.
NOTE(harlowja): the sort parameter can be provided to sort the states
@@ -453,14 +508,14 @@ class FiniteMachine:
row.append(empty)
tbl.add_row(row)
else:
on_enter = self._states[state]['on_enter']
if on_enter is not None:
on_enter = utils.get_callback_name(on_enter)
on_enter_cb = self._states[state]['on_enter']
if on_enter_cb is not None:
on_enter = utils.get_callback_name(on_enter_cb)
else:
on_enter = empty
on_exit = self._states[state]['on_exit']
if on_exit is not None:
on_exit = utils.get_callback_name(on_exit)
on_exit_cb = self._states[state]['on_exit']
if on_exit_cb is not None:
on_exit = utils.get_callback_name(on_exit_cb)
else:
on_exit = empty
tbl.add_row([pretty_state, empty, empty, on_enter, on_exit])
@@ -473,12 +528,14 @@ class HierarchicalFiniteMachine(FiniteMachine):
#: The result of processing an event (cause and effect...)
Effect = collections.namedtuple('Effect', 'reaction,terminal,machine')
def __init__(self):
def __init__(self) -> None:
super().__init__()
self._nested_machines = {}
self._nested_machines: dict[str, FiniteMachine] = {}
@classmethod
def _effect_builder(cls, new_state, event):
def _effect_builder( # type: ignore[override]
cls, new_state: Mapping[str, Any], event: str
) -> Effect:
return cls.Effect(
new_state['reactions'].get(event),
new_state["terminal"],
@@ -486,13 +543,17 @@ class HierarchicalFiniteMachine(FiniteMachine):
)
def add_state(
self, state, terminal=False, on_enter=None, on_exit=None, machine=None
):
self,
state: str,
terminal: bool = False,
on_enter: OnEnterCallbackT = None,
on_exit: OnExitCallbackT = None,
machine: 'FiniteMachine | None' = None,
) -> None:
"""Adds a given state to the state machine.
:param machine: the nested state machine that will be transitioned
into when this state is entered
:type machine: :py:class:`.FiniteMachine`
Further arguments are interpreted as
for :py:meth:`.FiniteMachine.add_state`.
@@ -508,7 +569,7 @@ class HierarchicalFiniteMachine(FiniteMachine):
self._states[state]['machine'] = machine
self._nested_machines[state] = machine
def copy(self, shallow=False, unfreeze=False):
def copy(self, shallow: bool = False, unfreeze: bool = False) -> Self:
c = super().copy(shallow=shallow, unfreeze=unfreeze)
if shallow:
c._nested_machines = self._nested_machines
@@ -516,7 +577,12 @@ class HierarchicalFiniteMachine(FiniteMachine):
c._nested_machines = self._nested_machines.copy()
return c
def initialize(self, start_state=None, nested_start_state_fetcher=None):
def initialize(
self,
start_state: str | None = None,
nested_start_state_fetcher: Callable[['FiniteMachine'], str | None]
| None = None,
) -> None:
"""Sets up the state machine (sets current state to start state...).
:param start_state: explicit start state to use to initialize the
@@ -555,6 +621,6 @@ class HierarchicalFiniteMachine(FiniteMachine):
nested_machine.initialize(start_state=nested_start_state)
@property
def nested_machines(self):
def nested_machines(self) -> dict[str, 'FiniteMachine']:
"""Dictionary of **all** nested state machines this machine may use."""
return self._nested_machines

0
automaton/py.typed Normal file
View File

View File

@@ -13,6 +13,8 @@
# under the License.
import abc
from collections.abc import Generator
from typing import Any
from automaton import exceptions as excp
from automaton import machines
@@ -34,15 +36,17 @@ class Runner(metaclass=abc.ABCMeta):
the same time).
"""
def __init__(self, machine):
def __init__(self, machine: machines.FiniteMachine) -> None:
self._machine = machine
@abc.abstractmethod
def run(self, event, initialize=True):
def run(self, event: str, initialize: bool = True) -> None:
"""Runs the state machine, using reactions only."""
@abc.abstractmethod
def run_iter(self, event, initialize=True):
def run_iter(
self, event: str, initialize: bool = True
) -> Generator[tuple[str | None, str | None], Any, None]:
"""Returns a iterator/generator that will run the state machine.
NOTE(harlowja): only one runner iterator/generator should be active for
@@ -60,17 +64,19 @@ class FiniteRunner(Runner):
the same time).
"""
def __init__(self, machine):
def __init__(self, machine: machines.FiniteMachine) -> None:
"""Create a runner for the given machine."""
if not isinstance(machine, (machines.FiniteMachine,)):
raise TypeError("FiniteRunner only works with FiniteMachine(s)")
super().__init__(machine)
def run(self, event, initialize=True):
def run(self, event: str, initialize: bool = True) -> None:
for transition in self.run_iter(event, initialize=initialize):
pass
def run_iter(self, event, initialize=True):
def run_iter(
self, event: str, initialize: bool = True
) -> Generator[tuple[str | None, str | None], Any, None]:
if initialize:
self._machine.initialize()
while True:
@@ -102,7 +108,7 @@ class HierarchicalRunner(Runner):
the same time).
"""
def __init__(self, machine):
def __init__(self, machine: machines.HierarchicalFiniteMachine) -> None:
"""Create a runner for the given machine."""
if not isinstance(machine, (machines.HierarchicalFiniteMachine,)):
raise TypeError(
@@ -111,12 +117,14 @@ class HierarchicalRunner(Runner):
)
super().__init__(machine)
def run(self, event, initialize=True):
def run(self, event: str, initialize: bool = True) -> None:
for transition in self.run_iter(event, initialize=initialize):
pass
@staticmethod
def _process_event(machines, event):
def _process_event(
machines: list[machines.FiniteMachine], event: str
) -> machines.HierarchicalFiniteMachine.Effect:
"""Matches a event to the machine hierarchy.
If the lowest level machine does not handle the event, then the
@@ -141,9 +149,11 @@ class HierarchicalRunner(Runner):
machine._current = None
machines.pop()
else:
return result
return result # type: ignore[return-value]
def run_iter(self, event, initialize=True):
def run_iter(
self, event: str, initialize: bool = True
) -> Generator[tuple[str | None, str | None], Any, None]:
"""Returns a iterator/generator that will run the state machine.
This will keep a stack (hierarchy) of machines active and jumps through

View File

@@ -111,7 +111,7 @@ class FSMTest(testcase.TestCase):
self.assertEqual({'up': ['jump'], 'down': ['fall']}, dict(entered))
def test_build_transitions_dct(self):
space = [
space: list[machines.StateDict] = [
{
'name': 'down',
'is_terminal': False,
@@ -383,11 +383,11 @@ class FSMTest(testcase.TestCase):
class HFSMTest(FSMTest):
@staticmethod
def _create_fsm(
def _create_fsm( # type: ignore[override]
start_state, add_start=True, hierarchical=False, add_states=None
):
if hierarchical:
m = machines.HierarchicalFiniteMachine()
m: machines.FiniteMachine = machines.HierarchicalFiniteMachine()
else:
m = machines.FiniteMachine()
if add_start:

View File

@@ -24,6 +24,12 @@ classifiers = [
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Typing :: Typed",
]
[project.optional-dependencies]
pydot = [
"pydot>=4.0",
]
[project.urls]
@@ -35,6 +41,20 @@ packages = [
"automaton"
]
[tool.mypy]
python_version = "3.10"
show_column_numbers = true
show_error_context = true
strict = true
disable_error_code = ["import-untyped"]
exclude = "(?x)(doc | releasenotes)"
[[tool.mypy.overrides]]
module = ["automaton.tests.*"]
disallow_untyped_calls = false
disallow_untyped_defs = false
disallow_subclassing_any = false
[tool.ruff]
line-length = 79

View File

@@ -1,5 +1,2 @@
# See: https://bugs.launchpad.net/pbr/+bug/1384919 for why this is here...
pbr>=2.0.0 # Apache-2.0
# For pretty formatting machines/state tables...
PrettyTable>=0.7.2 # BSD
typing-extensions>=4.0.0 # PSF-2.0

26
tox.ini
View File

@@ -15,10 +15,24 @@ deps =
commands = stestr run --slowest {posargs}
[testenv:pep8]
skip_install = true
description =
Run style checks.
deps =
pre-commit>=2.6.0 # MIT
commands = pre-commit run -a
pre-commit
{[testenv:mypy]deps}
commands =
pre-commit run -a
{[testenv:mypy]commands}
[testenv:mypy]
description =
Run type checks.
deps =
{[testenv]deps}
mypy
pydot
commands =
mypy --cache-dir="{envdir}/mypy_cache" {posargs:automaton}
[testenv:venv]
commands = {posargs}
@@ -64,3 +78,9 @@ usedevelop = False
select = H
show-source = True
exclude = .venv,.git,.tox,dist,doc,*lib/python*,*egg,build
[hacking]
import_exceptions =
collections.abc
typing
typing_extensions