Cover more task related things with type hints

Signed-off-by: Andriy Kurilin <andr.kurilin@gmail.com>
Change-Id: Ib806217076ae85d1744301b381b17f1c7556e5ec
This commit is contained in:
Andriy Kurilin
2025-07-30 09:44:18 +02:00
parent 3ca8e86288
commit 9835218073
32 changed files with 610 additions and 337 deletions

View File

@@ -24,6 +24,7 @@ Added
~~~~~
* CI jobs for checking compatibility with python 3.12
* Python type hints for task-related plugins and base classes
Changed
~~~~~~~

View File

@@ -307,14 +307,6 @@ disable_error_code = ["attr-defined", "index", "no-untyped-def", "var-annotated"
module = "rally.env.platform"
disable_error_code = ["no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.plugins"
disable_error_code = ["no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.plugins.common.validators"
disable_error_code = ["arg-type", "assignment", "no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.plugins.task.exporters.elastic.client"
disable_error_code = ["no-untyped-def"]
@@ -347,58 +339,6 @@ disable_error_code = ["no-untyped-def"]
module = "rally.plugins.task.exporters.trends"
disable_error_code = ["no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.plugins.task.hook_triggers.event"
disable_error_code = ["dict-item", "no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.plugins.task.hook_triggers.periodic"
disable_error_code = ["dict-item", "no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.plugins.task.hooks.sys_call"
disable_error_code = ["no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.plugins.task.scenarios.dummy.dummy"
disable_error_code = ["no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.plugins.task.scenarios.requests.http_requests"
disable_error_code = ["no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.plugins.task.scenarios.requests.utils"
disable_error_code = ["no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.plugins.task.sla.failure_rate"
disable_error_code = ["dict-item", "no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.plugins.task.sla.iteration_time"
disable_error_code = ["dict-item", "no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.plugins.task.sla.max_average_duration"
disable_error_code = ["dict-item", "no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.plugins.task.sla.max_average_duration_per_atomic"
disable_error_code = ["assignment", "dict-item", "no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.plugins.task.sla.outliers"
disable_error_code = ["dict-item", "no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.plugins.task.sla.performance_degradation"
disable_error_code = ["dict-item", "no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.plugins.task.types"
disable_error_code = ["no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.plugins.verification.reporters"
disable_error_code = ["no-untyped-def"]
@@ -419,10 +359,6 @@ disable_error_code = ["no-untyped-def"]
module = "rally.task.functional"
disable_error_code = ["no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.task.hook"
disable_error_code = ["no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.task.processing.charts"
disable_error_code = ["index", "no-untyped-def", "override", "var-annotated"]
@@ -439,18 +375,10 @@ disable_error_code = ["no-untyped-def"]
module = "rally.task.service"
disable_error_code = ["call-overload", "no-untyped-def", "var-annotated"]
[[tool.mypy.overrides]]
module = "rally.task.sla"
disable_error_code = ["no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.task.task_cfg"
disable_error_code = ["assignment", "attr-defined", "no-untyped-def"]
[[tool.mypy.overrides]]
module = "rally.task.types"
disable_error_code = ["no-untyped-def", "var-annotated"]
[[tool.mypy.overrides]]
module = "rally.task.utils"
disable_error_code = ["index", "no-untyped-def"]

View File

@@ -13,16 +13,24 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import functools
import os
import typing as t
import typing_extensions as te
from rally.common.plugin import discover
if t.TYPE_CHECKING:
P = te.ParamSpec("P")
R = t.TypeVar("R")
PLUGINS_LOADED = False
def load():
def load() -> None:
global PLUGINS_LOADED
if not PLUGINS_LOADED:
@@ -52,9 +60,9 @@ def load():
PLUGINS_LOADED = True
def ensure_plugins_are_loaded(func):
def ensure_plugins_are_loaded(func: t.Callable[P, R]) -> t.Callable[P, R]:
@functools.wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: t.Any, **kwargs: t.Any) -> R:
load()
return func(*args, **kwargs)
return wrapper

View File

@@ -12,8 +12,11 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import inspect
import os
import typing as t
import jsonschema
@@ -22,6 +25,12 @@ from rally.common import validation
from rally import exceptions
from rally.task import context as context_lib
if t.TYPE_CHECKING: # pragma: no cover
from rally.common.plugin import plugin
from rally.task import scenario
import jsonschema.protocols
LOG = logging.getLogger(__name__)
@@ -29,13 +38,26 @@ LOG = logging.getLogger(__name__)
class JsonSchemaValidator(validation.Validator):
"""JSON schema validator"""
def validate(self, context, config, plugin_cls, plugin_cfg):
def validate(
self,
context: dict[str, t.Any],
config: dict[str, t.Any] | None,
plugin_cls: type[plugin.Plugin],
plugin_cfg: dict[str, t.Any] | None
) -> None:
schema = getattr(plugin_cls, "CONFIG_SCHEMA", {"type": "null"})
validator = jsonschema.validators.validator_for(
plugin_cls.CONFIG_SCHEMA, default=jsonschema.Draft7Validator
schema,
default=t.cast(
type[jsonschema.protocols.Validator],
jsonschema.Draft7Validator
)
)
try:
jsonschema.validate(
plugin_cfg, plugin_cls.CONFIG_SCHEMA, cls=validator
plugin_cfg, schema,
cls=validator # type: ignore[arg-type]
)
except jsonschema.ValidationError as err:
self.fail(str(err))
@@ -45,12 +67,18 @@ class JsonSchemaValidator(validation.Validator):
class ArgsValidator(validation.Validator):
"""Scenario arguments validator"""
def validate(self, context, config, plugin_cls, plugin_cfg):
scenario = plugin_cls
name = scenario.get_name()
platform = scenario.get_platform()
def validate(
self,
context: dict[str, t.Any],
config: dict[str, t.Any] | None,
plugin_cls: type[scenario.Scenario], # type: ignore[override]
plugin_cfg: dict[str, t.Any] | None
) -> None:
scenario_cls = plugin_cls
name = scenario_cls.get_name()
platform = scenario_cls.get_platform()
args_spec = inspect.signature(scenario.run).parameters
args_spec = inspect.signature(scenario_cls.run).parameters
missed_args = [
p.name
for i, p in enumerate(args_spec.values())
@@ -62,8 +90,8 @@ class ArgsValidator(validation.Validator):
hint_msg = (" Use `rally plugin show --name %s --platform %s` "
"to display scenario description." % (name, platform))
if "args" in config:
missed_args = set(missed_args) - set(config["args"])
if config is not None and "args" in config:
missed_args = sorted(set(missed_args) - set(config["args"]))
if missed_args:
msg = ("Argument(s) '%(args)s' should be specified in task config."
"%(hint)s" % {"args": "', '".join(missed_args),
@@ -75,7 +103,7 @@ class ArgsValidator(validation.Validator):
if p.kind == inspect.Parameter.VAR_KEYWORD
)
if not support_kwargs and "args" in config:
if not support_kwargs and config is not None and "args" in config:
redundant_args = [p for p in config["args"] if p not in args_spec]
if redundant_args:
msg = ("Unexpected argument(s) found ['%(args)s'].%(hint)s" %
@@ -92,17 +120,29 @@ class RequiredParameterValidator(validation.Validator):
:param subdict: sub-dict of "config" to search. If
not defined, the search is performed in "config"
:param params: list of required parameters
:param params: list of required parameters. If item is list/tuple,
the nested items will be treated as oneOf options.
"""
def __init__(self, params=None, subdict=None):
def __init__(
self,
params: list[str | tuple[str, ...] | list[str]] | None = None,
subdict: str | None = None
) -> None:
super(RequiredParameterValidator, self).__init__()
self.subdict = subdict
self.params = params
self.params = params or []
def validate(
self,
context: dict[str, t.Any],
config: dict[str, t.Any] | None,
plugin_cls: type[plugin.Plugin],
plugin_cfg: dict[str, t.Any] | None
) -> None:
missing: list[str] = []
args: dict[str, t.Any] = config.get("args", {}) if config else {}
def validate(self, context, config, plugin_cls, plugin_cfg):
missing = []
args = config.get("args", {})
if self.subdict:
args = args.get(self.subdict, {})
for arg in self.params:
@@ -111,9 +151,9 @@ class RequiredParameterValidator(validation.Validator):
if case in args:
break
else:
arg = "'/'".join(arg)
arg_str = "'/'".join(arg)
missing.append("'%s' (at least one parameter should be "
"specified)" % arg)
"specified)" % arg_str)
else:
if arg not in args:
missing.append("'%s'" % arg)
@@ -138,25 +178,38 @@ class NumberValidator(validation.Validator):
:param integer_only: Only accept integers
"""
def __init__(self, param_name, minval=None, maxval=None, nullable=False,
integer_only=False):
def __init__(
self,
param_name: str,
minval: int | float | None = None,
maxval: int | float | None = None,
nullable: bool = False,
integer_only: bool = False
) -> None:
self.param_name = param_name
self.minval = minval
self.maxval = maxval
self.nullable = nullable
self.integer_only = integer_only
def validate(self, context, config, plugin_cls, plugin_cfg):
def validate(
self,
context: dict[str, t.Any],
config: dict[str, t.Any] | None,
plugin_cls: type[plugin.Plugin],
plugin_cfg: dict[str, t.Any] | None
) -> None:
value: t.Any = None
if config is not None:
value = config.get("args", {}).get(self.param_name)
value = config.get("args", {}).get(self.param_name)
num_func = float
num_func: type[int] | type[float] = float
if self.integer_only:
# NOTE(boris-42): Force check that passed value is not float, this
# is important cause int(float_numb) won't raise exception
if isinstance(value, float):
return self.fail("%(name)s is %(val)s which hasn't int type"
% {"name": self.param_name, "val": value})
self.fail("%(name)s is %(val)s which hasn't int type"
% {"name": self.param_name, "val": value})
num_func = int
# None may be valid if the scenario sets a sensible default.
@@ -194,13 +247,18 @@ class EnumValidator(validation.Validator):
:param case_insensitive: Ignore case in enum values
"""
def __init__(self, param_name, values, missed=False,
case_insensitive=False):
def __init__(
self,
param_name: str,
values: list[t.Any],
missed: bool = False,
case_insensitive: bool = False
) -> None:
self.param_name = param_name
self.missed = missed
self.case_insensitive = case_insensitive
if self.case_insensitive:
self.values = []
self.values: list[t.Any] = []
for value in values:
if isinstance(value, str):
value = value.lower()
@@ -208,8 +266,16 @@ class EnumValidator(validation.Validator):
else:
self.values = values
def validate(self, context, config, plugin_cls, plugin_cfg):
value = config.get("args", {}).get(self.param_name)
def validate(
self,
context: dict[str, t.Any],
config: dict[str, t.Any] | None,
plugin_cls: type[plugin.Plugin],
plugin_cfg: dict[str, t.Any] | None
) -> None:
value = None
if config is not None:
value = config.get("args", {}).get(self.param_name)
if value:
if self.case_insensitive:
if isinstance(value, str):
@@ -237,8 +303,15 @@ class MapKeysParameterValidator(validation.Validator):
keys are specified, defaults to False, otherwise defaults to True
:param missed: Allow to accept optional parameter
"""
def __init__(self, param_name, required=None, allowed=None,
additional=True, missed=False):
def __init__(
self,
param_name: str,
required: list[str] | None = None,
allowed: list[str] | None = None,
additional: bool = True,
missed: bool = False
) -> None:
super(MapKeysParameterValidator, self).__init__()
self.param_name = param_name
self.required = required or []
@@ -246,8 +319,16 @@ class MapKeysParameterValidator(validation.Validator):
self.additional = additional
self.missed = missed
def validate(self, context, config, plugin_cls, plugin_cfg):
parameter = config.get("args", {}).get(self.param_name)
def validate(
self,
context: dict[str, t.Any],
config: dict[str, t.Any] | None,
plugin_cls: type[plugin.Plugin],
plugin_cfg: dict[str, t.Any] | None
) -> None:
parameter = None
if config is not None:
parameter = config.get("args", {}).get(self.param_name)
if parameter:
required_diff = set(self.required) - set(parameter.keys())
@@ -284,7 +365,10 @@ class MapKeysParameterValidator(validation.Validator):
@validation.configure(name="restricted_parameters")
class RestrictedParametersValidator(validation.Validator):
def __init__(self, param_names, subdict=None):
def __init__(
self, param_names: str | list[str] | tuple[str, ...],
subdict: str | None = None
) -> None:
"""Validates that parameters is not set.
:param param_names: parameter or parameters list to be validated.
@@ -298,12 +382,17 @@ class RestrictedParametersValidator(validation.Validator):
self.params = [param_names]
self.subdict = subdict
def validate(self, context, config, plugin_cls, plugin_cfg):
restricted_params = []
def validate(
self,
context: dict[str, t.Any],
config: dict[str, t.Any] | None,
plugin_cls: type[plugin.Plugin],
plugin_cfg: dict[str, t.Any] | None
) -> None:
restricted_params: list[str] = []
args: dict[str, t.Any] = config.get("args", {}) if config else {}
for param_name in self.params:
source = config.get("args", {})
if self.subdict:
source = source.get(self.subdict) or {}
source = (args.get(self.subdict) or {}) if self.subdict else args
if param_name in source:
restricted_params.append(param_name)
if restricted_params:
@@ -315,7 +404,10 @@ class RestrictedParametersValidator(validation.Validator):
@validation.configure(name="required_contexts")
class RequiredContextsValidator(validation.Validator):
def __init__(self, *args, contexts=None):
def __init__(
self, *args: str,
contexts: t.Iterable[str | tuple[str, ...]] | None = None
) -> None:
"""Validator checks if required contexts are specified.
:param contexts: list of strings and tuples with context names that
@@ -326,7 +418,7 @@ class RequiredContextsValidator(validation.Validator):
if isinstance(contexts, (list, tuple)):
# services argument is a list, so it is a new way of validators
# usage, args in this case should not be provided
self.contexts = contexts
self.contexts: list[str | tuple[str, ...]] = list(contexts)
if args:
LOG.warning("Positional argument is not what "
"'required_context' decorator expects. "
@@ -335,11 +427,13 @@ class RequiredContextsValidator(validation.Validator):
# it is an old way validator
self.contexts = []
if contexts:
self.contexts.append(contexts)
self.contexts.append(t.cast(tuple[str, ...], contexts))
self.contexts.extend(args)
@staticmethod
def _match(requested_ctx_name, input_contexts):
def _match(
requested_ctx_name: str, input_contexts: dict[str, t.Any]
) -> bool:
requested_ctx_name_extended = f"{requested_ctx_name}@"
for input_ctx_name in input_contexts:
if (requested_ctx_name == input_ctx_name
@@ -358,9 +452,17 @@ class RequiredContextsValidator(validation.Validator):
return False
def validate(self, context, config, plugin_cls, plugin_cfg):
missing_contexts = []
input_contexts = config.get("contexts", {})
def validate(
self,
context: dict[str, t.Any],
config: dict[str, t.Any] | None,
plugin_cls: type[plugin.Plugin],
plugin_cfg: dict[str, t.Any] | None
) -> None:
missing_contexts: list[str] = []
input_contexts: dict[str, t.Any] = {}
if config is not None:
input_contexts = config.get("contexts", {})
for required_ctx in self.contexts:
if isinstance(required_ctx, tuple):
@@ -381,7 +483,7 @@ class RequiredContextsValidator(validation.Validator):
@validation.configure(name="required_param_or_context")
class RequiredParamOrContextValidator(validation.Validator):
def __init__(self, param_name, ctx_name):
def __init__(self, param_name: str, ctx_name: str) -> None:
"""Validator checks if required image is specified.
:param param_name: name of parameter
@@ -391,21 +493,30 @@ class RequiredParamOrContextValidator(validation.Validator):
self.param_name = param_name
self.ctx_name = ctx_name
def validate(self, context, config, plugin_cls, plugin_cfg):
def validate(
self,
context: dict[str, t.Any],
config: dict[str, t.Any] | None,
plugin_cls: type[plugin.Plugin],
plugin_cfg: dict[str, t.Any] | None
) -> None:
msg = ("You should specify either scenario argument %s or"
" use context %s." % (self.param_name, self.ctx_name))
if self.ctx_name in config.get("contexts", {}):
return
if self.param_name in config.get("args", {}):
return
if config is not None:
if self.ctx_name in config.get("contexts", {}):
return
if self.param_name in config.get("args", {}):
return
self.fail(msg)
@validation.configure(name="file_exists")
class FileExistsValidator(validation.Validator):
def __init__(self, param_name, mode=os.R_OK, required=True):
def __init__(
self, param_name: str, mode: int = os.R_OK, required: bool = True
) -> None:
"""Validator checks parameter is proper path to file with proper mode.
Ensure a file exists and can be accessed with the specified mode.
@@ -428,7 +539,10 @@ class FileExistsValidator(validation.Validator):
self.mode = mode
self.required = required
def _file_access_ok(self, filename, mode, param_name, required=True):
def _file_access_ok(
self, filename: str | None, mode: int, param_name: str,
required: bool = True
) -> None:
if not filename:
if not required:
return
@@ -439,7 +553,16 @@ class FileExistsValidator(validation.Validator):
"mode": mode,
"param_name": param_name})
def validate(self, context, config, plugin_cls, plugin_cfg):
def validate(
self,
context: dict[str, t.Any],
config: dict[str, t.Any] | None,
plugin_cls: type[plugin.Plugin],
plugin_cfg: dict[str, t.Any] | None
) -> None:
filename = None
if config is not None:
filename = config.get("args", {}).get(self.param_name)
self._file_access_ok(config.get("args", {}).get(self.param_name),
self.mode, self.param_name, self.required)
self._file_access_ok(filename, self.mode, self.param_name,
self.required)

View File

@@ -13,6 +13,10 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import typing as t
from rally import consts
from rally.task import hook
@@ -63,12 +67,12 @@ class EventTrigger(hook.HookTrigger):
]
}
def get_listening_event(self):
def get_listening_event(self) -> str:
return self.config["unit"]
def on_event(self, event_type, value=None):
def on_event(self, event_type: str, value: t.Any = None) -> bool:
if not (event_type == self.get_listening_event()
and value in self.config["at"]):
# do nothing
return
super(EventTrigger, self).on_event(event_type, value)
return False
return super(EventTrigger, self).on_event(event_type, value)

View File

@@ -13,9 +13,16 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import typing as t
from rally import consts
from rally.task import hook
if t.TYPE_CHECKING: # pragma: no cover
from rally.common import objects
@hook.configure(name="periodic")
class PeriodicTrigger(hook.HookTrigger):
@@ -51,19 +58,24 @@ class PeriodicTrigger(hook.HookTrigger):
]
}
def __init__(self, context, task, hook_cls):
super(PeriodicTrigger, self).__init__(context, task, hook_cls)
def __init__(
self,
hook_cfg: dict[str, t.Any],
task: objects.Task,
hook_cls: type[hook.HookAction]
) -> None:
super(PeriodicTrigger, self).__init__(hook_cfg, task, hook_cls)
self.config.setdefault(
"start", 0 if self.config["unit"] == "time" else 1)
self.config.setdefault("end", float("Inf"))
def get_listening_event(self):
def get_listening_event(self) -> str:
return self.config["unit"]
def on_event(self, event_type, value=None):
def on_event(self, event_type: str, value: t.Any = None) -> bool:
if not (event_type == self.get_listening_event()
and self.config["start"] <= value <= self.config["end"]
and (value - self.config["start"]) % self.config["step"] == 0):
# do nothing
return
super(PeriodicTrigger, self).on_event(event_type, value)
return False
return super(PeriodicTrigger, self).on_event(event_type, value)

View File

@@ -36,7 +36,7 @@ class SysCallHook(hook.HookAction):
"description": "Command to execute."
}
def run(self):
def run(self) -> None:
LOG.debug("sys_call hook: Running command %s" % self.config)
proc = subprocess.Popen(shlex.split(self.config),
stdout=subprocess.PIPE,

View File

@@ -37,7 +37,7 @@ def _worker_process(
duration: float | None,
context: dict[str, t.Any],
cls: type[runner.scenario.Scenario],
method_name: str,
method_name: t.Literal["run"],
args: dict[str, t.Any],
event_queue: multiprocessing.Queue[dict[str, t.Any]],
aborted: multiprocessing.synchronize.Event,
@@ -225,7 +225,7 @@ class ConstantScenarioRunner(runner.ScenarioRunner):
def _run_scenario(
self,
cls: type[runner.scenario.Scenario],
method_name: str,
method_name: t.Literal["run"],
context: dict[str, t.Any],
args: dict[str, t.Any]
) -> None:
@@ -330,7 +330,7 @@ class ConstantForDurationScenarioRunner(runner.ScenarioRunner):
def _run_scenario(
self,
cls: type[runner.scenario.Scenario],
method_name: str,
method_name: t.Literal["run"],
context: dict[str, t.Any],
args: dict[str, t.Any]
) -> None:

View File

@@ -254,7 +254,7 @@ class RPSScenarioRunner(runner.ScenarioRunner):
def _run_scenario(
self,
cls: type[scenario.Scenario],
method_name: str,
method_name: t.Literal["run"],
context: dict[str, t.Any],
args: dict[str, t.Any]
) -> None:

View File

@@ -53,7 +53,7 @@ class SerialScenarioRunner(runner.ScenarioRunner):
def _run_scenario(
self,
cls: type[scenario.Scenario],
method_name: str,
method_name: t.Literal["run"],
context: dict[str, t.Any],
args: dict[str, t.Any]
) -> None:

View File

@@ -10,7 +10,10 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import random
import typing as t
from rally.common import utils
from rally.common import validation
@@ -30,7 +33,13 @@ class DummyScenarioException(exceptions.RallyException):
@scenario.configure(name="Dummy.failure")
class DummyFailure(scenario.Scenario):
def run(self, sleep=0.1, from_iteration=0, to_iteration=0, each=1):
def run(
self,
sleep: float = 0.1,
from_iteration: int = 0,
to_iteration: int = 0,
each: int = 1,
) -> None:
"""Raise errors in some iterations.
:param sleep: float iteration sleep time in seconds
@@ -52,14 +61,14 @@ class DummyFailure(scenario.Scenario):
class Dummy(scenario.Scenario):
@atomic.action_timer("bar")
def bar(self, sleep):
def bar(self, sleep: float) -> None:
utils.interruptable_sleep(sleep)
@atomic.action_timer("foo")
def foo(self, sleep):
def foo(self, sleep: float) -> None:
self.bar(sleep)
def run(self, sleep=0):
def run(self, sleep: float = 0, **kwargs: t.Any) -> None:
"""Do nothing and sleep for the given number of seconds (0 by default).
Dummy.dummy can be used for testing performance of different
@@ -76,7 +85,12 @@ class Dummy(scenario.Scenario):
@scenario.configure(name="Dummy.dummy_exception")
class DummyException(scenario.Scenario):
def run(self, size_of_message=1, sleep=1, message=""):
def run(
self,
size_of_message: int = 1,
sleep: float = 1,
message: str = ""
) -> None:
"""Throws an exception.
Dummy.dummy_exception used for testing if exceptions are processed
@@ -99,7 +113,7 @@ class DummyException(scenario.Scenario):
@scenario.configure(name="Dummy.dummy_exception_probability")
class DummyExceptionProbability(scenario.Scenario):
def run(self, exception_probability=0.5):
def run(self, exception_probability: float = 0.5) -> None:
"""Throws an exception with given probability.
Dummy.dummy_exception_probability used for testing if exceptions are
@@ -119,7 +133,7 @@ class DummyExceptionProbability(scenario.Scenario):
@scenario.configure(name="Dummy.dummy_output")
class DummyOutput(scenario.Scenario):
def run(self, random_range=25):
def run(self, random_range: int = 25) -> None:
"""Generate dummy output.
This scenario generates example of output data.
@@ -198,7 +212,7 @@ class DummyOutput(scenario.Scenario):
class DummyRandomFailInAtomic(scenario.Scenario):
"""Randomly throw exceptions in atomic actions."""
def _play_roulette(self, exception_probability):
def _play_roulette(self, exception_probability: float) -> None:
"""Throw an exception with given probability.
:raises KeyError: when exception_probability is bigger
@@ -206,7 +220,7 @@ class DummyRandomFailInAtomic(scenario.Scenario):
if random.random() < exception_probability:
raise KeyError("Dummy test exception")
def run(self, exception_probability=0.5):
def run(self, exception_probability: float = 0.5) -> None:
"""Dummy.dummy_random_fail_in_atomic in dummy actions.
Can be used to test atomic actions
@@ -232,7 +246,12 @@ class DummyRandomFailInAtomic(scenario.Scenario):
@scenario.configure(name="Dummy.dummy_random_action")
class DummyRandomAction(scenario.Scenario):
def run(self, actions_num=5, sleep_min=0, sleep_max=0):
def run(
self,
actions_num: int = 5,
sleep_min: float = 0,
sleep_max: float = 0
) -> None:
"""Sleep random time in dummy actions.
:param actions_num: int number of actions to generate
@@ -248,7 +267,11 @@ class DummyRandomAction(scenario.Scenario):
@scenario.configure(name="Dummy.dummy_timed_atomic_actions")
class DummyTimedAtomicAction(scenario.Scenario):
def run(self, number_of_actions=5, sleep_factor=1):
def run(
self,
number_of_actions: int = 5,
sleep_factor: float = 1
) -> None:
"""Run some sleepy atomic actions for SLA atomic action tests.
:param number_of_actions: int number of atomic actions to create

View File

@@ -11,6 +11,7 @@
# under the License.
import random
import typing as t
from rally.plugins.task.scenarios.requests import utils
from rally.task import scenario
@@ -22,7 +23,9 @@ from rally.task import scenario
@scenario.configure(name="HttpRequests.check_request")
class HttpRequestsCheckRequest(utils.RequestScenario):
def run(self, url, method, status_code, **kwargs):
def run(
self, url: str, method: str, status_code: int, **kwargs: t.Any
) -> None:
"""Standard way for testing web services using HTTP requests.
This scenario is used to make request and check it with expected
@@ -40,7 +43,11 @@ class HttpRequestsCheckRequest(utils.RequestScenario):
@scenario.configure(name="HttpRequests.check_random_request")
class HttpRequestsCheckRandomRequest(utils.RequestScenario):
def run(self, requests, status_code):
def run(
self,
requests: list[dict[str, t.Any]],
status_code: int,
) -> None:
"""Executes random HTTP requests from provided list.
This scenario takes random url from list of requests, and raises
@@ -48,7 +55,7 @@ class HttpRequestsCheckRandomRequest(utils.RequestScenario):
:param requests: List of request dicts
:param status_code: Expected Response Code; it will
be used only if it is not specified in the request itself
be used only if it is not specified in the request itself
"""
request = random.choice(requests)

View File

@@ -10,6 +10,8 @@
# License for the specific language governing permissions and limitations
# under the License.
import typing as t
import requests
from rally.task import atomic
@@ -20,7 +22,9 @@ class RequestScenario(scenario.Scenario):
"""Base class for Request scenarios with basic atomic actions."""
@atomic.action_timer("requests.check_request")
def _check_request(self, url, method, status_code, **kwargs):
def _check_request(
self, url: str, method: str, status_code: int, **kwargs: t.Any
) -> None:
"""Compare request status code with specified code
:param status_code: Expected status code of request

View File

@@ -19,9 +19,16 @@ SLA (Service-level agreement) is set of details for determining compliance
with contracted values such as maximum error rate or minimum response time.
"""
from __future__ import annotations
import typing as t
from rally import consts
from rally.task import sla
if t.TYPE_CHECKING: # pragma: no cover
from rally.task import runner
@sla.configure(name="failure_rate")
class FailureRate(sla.SLA):
@@ -37,7 +44,7 @@ class FailureRate(sla.SLA):
"additionalProperties": False,
}
def __init__(self, criterion_value):
def __init__(self, criterion_value: dict[str, float]) -> None:
super(FailureRate, self).__init__(criterion_value)
self.min_percent = self.criterion_value.get("min", 0)
self.max_percent = self.criterion_value.get("max", 100)
@@ -45,7 +52,7 @@ class FailureRate(sla.SLA):
self.total = 0
self.error_rate = 0.0
def add_iteration(self, iteration):
def add_iteration(self, iteration: runner.ScenarioRunnerResult) -> bool:
self.total += 1
if iteration["error"]:
self.errors += 1
@@ -53,7 +60,7 @@ class FailureRate(sla.SLA):
self.success = self.min_percent <= self.error_rate <= self.max_percent
return self.success
def merge(self, other):
def merge(self, other: FailureRate) -> bool:
self.total += other.total
self.errors += other.errors
if self.total:
@@ -61,7 +68,7 @@ class FailureRate(sla.SLA):
self.success = self.min_percent <= self.error_rate <= self.max_percent
return self.success
def details(self):
def details(self) -> str:
return ("Failure rate criteria %.2f%% <= %.2f%% <= %.2f%% - %s" %
(self.min_percent, self.error_rate,
self.max_percent, self.status()))

View File

@@ -19,9 +19,16 @@ SLA (Service-level agreement) is set of details for determining compliance
with contracted values such as maximum error rate or minimum response time.
"""
from __future__ import annotations
import typing as t
from rally import consts
from rally.task import sla
if t.TYPE_CHECKING: # pragma: no cover
from rally.task import runner
@sla.configure(name="max_seconds_per_iteration")
class IterationTime(sla.SLA):
@@ -32,22 +39,22 @@ class IterationTime(sla.SLA):
"minimum": 0.0,
"exclusiveMinimum": 0.0}
def __init__(self, criterion_value):
def __init__(self, criterion_value: float) -> None:
super(IterationTime, self).__init__(criterion_value)
self.max_iteration_time = 0.0
def add_iteration(self, iteration):
def add_iteration(self, iteration: runner.ScenarioRunnerResult) -> bool:
if iteration["duration"] > self.max_iteration_time:
self.max_iteration_time = iteration["duration"]
self.success = self.max_iteration_time <= self.criterion_value
return self.success
def merge(self, other):
def merge(self, other: IterationTime) -> bool:
if other.max_iteration_time > self.max_iteration_time:
self.max_iteration_time = other.max_iteration_time
self.success = self.max_iteration_time <= self.criterion_value
return self.success
def details(self):
def details(self) -> str:
return ("Maximum seconds per iteration %.2fs <= %.2fs - %s" %
(self.max_iteration_time, self.criterion_value, self.status()))

View File

@@ -19,10 +19,17 @@ SLA (Service-level agreement) is set of details for determining compliance
with contracted values such as maximum error rate or minimum response time.
"""
from __future__ import annotations
import typing as t
from rally.common import streaming_algorithms
from rally import consts
from rally.task import sla
if t.TYPE_CHECKING: # pragma: no cover
from rally.task import runner
@sla.configure(name="max_avg_duration")
class MaxAverageDuration(sla.SLA):
@@ -33,24 +40,24 @@ class MaxAverageDuration(sla.SLA):
"exclusiveMinimum": 0.0
}
def __init__(self, criterion_value):
def __init__(self, criterion_value: float) -> None:
super(MaxAverageDuration, self).__init__(criterion_value)
self.avg = 0.0
self.avg_comp = streaming_algorithms.MeanComputation()
def add_iteration(self, iteration):
def add_iteration(self, iteration: runner.ScenarioRunnerResult) -> bool:
if not iteration.get("error"):
self.avg_comp.add(iteration["duration"])
self.avg = self.avg_comp.result()
self.success = self.avg <= self.criterion_value
return self.success
def merge(self, other):
def merge(self, other: MaxAverageDuration) -> bool:
self.avg_comp.merge(other.avg_comp)
self.avg = self.avg_comp.result() or 0.0
self.success = self.avg <= self.criterion_value
return self.success
def details(self):
def details(self) -> str:
return ("Average duration of one iteration %.2fs <= %.2fs - %s" %
(self.avg, self.criterion_value, self.status()))

View File

@@ -19,34 +19,49 @@ SLA (Service-level agreement) is set of details for determining compliance
with contracted values such as maximum error rate or minimum response time.
"""
from __future__ import annotations
import collections
import typing as t
from rally.common import streaming_algorithms
from rally import consts
from rally.task import sla
if t.TYPE_CHECKING: # pragma: no cover
from rally.task import runner
@sla.configure(name="max_avg_duration_per_atomic")
class MaxAverageDurationPerAtomic(sla.SLA):
"""Maximum average duration of one iterations atomic actions in seconds."""
CONFIG_SCHEMA = {"type": "object", "$schema": consts.JSON_SCHEMA,
"patternProperties": {".*": {
"type": "number",
"description": "The name of atomic action."}},
"minProperties": 1,
"additionalProperties": False}
CONFIG_SCHEMA = {
"type": "object",
"$schema": consts.JSON_SCHEMA,
"patternProperties": {
".*": {
"type": "number",
"description": "The name of atomic action."
}
},
"minProperties": 1,
"additionalProperties": False
}
def __init__(self, criterion_value):
def __init__(self, criterion_value: dict[str, float]) -> None:
super(MaxAverageDurationPerAtomic, self).__init__(criterion_value)
self.avg_by_action = collections.defaultdict(float)
self.avg_comp_by_action = collections.defaultdict(
streaming_algorithms.MeanComputation)
self.avg_by_action: dict[str, float] = collections.defaultdict(float)
self.avg_comp_by_action: collections.defaultdict[
str, streaming_algorithms.MeanComputation
] = collections.defaultdict(streaming_algorithms.MeanComputation)
self.criterion_items = self.criterion_value.items()
def add_iteration(self, iteration):
def add_iteration(self, iteration: runner.ScenarioRunnerResult) -> bool:
if not iteration.get("error"):
for action in iteration["atomic_actions"]:
duration = action["finished_at"] - action["started_at"]
started_at = action["started_at"] or 0.0
finished_at = action["finished_at"] or started_at
duration = finished_at - started_at
self.avg_comp_by_action[action["name"]].add(duration)
result = self.avg_comp_by_action[action["name"]].result()
self.avg_by_action[action["name"]] = result
@@ -54,7 +69,7 @@ class MaxAverageDurationPerAtomic(sla.SLA):
for atom, val in self.criterion_items)
return self.success
def merge(self, other):
def merge(self, other: MaxAverageDurationPerAtomic) -> bool:
for atom, comp in self.avg_comp_by_action.items():
if atom in other.avg_comp_by_action:
comp.merge(other.avg_comp_by_action[atom])
@@ -64,7 +79,7 @@ class MaxAverageDurationPerAtomic(sla.SLA):
for atom, val in self.criterion_items)
return self.success
def details(self):
def details(self) -> str:
strs = ["Action: '%s'. %.2fs <= %.2fs" %
(atom, self.avg_by_action[atom], val)
for atom, val in self.criterion_items]

View File

@@ -19,10 +19,17 @@ SLA (Service-level agreement) is set of details for determining compliance
with contracted values such as maximum error rate or minimum response time.
"""
from __future__ import annotations
import typing as t
from rally.common import streaming_algorithms
from rally import consts
from rally.task import sla
if t.TYPE_CHECKING: # pragma: no cover
from rally.task import runner
@sla.configure(name="outliers")
class Outliers(sla.SLA):
@@ -43,7 +50,7 @@ class Outliers(sla.SLA):
"additionalProperties": False,
}
def __init__(self, criterion_value):
def __init__(self, criterion_value: dict[str, t.Any]) -> None:
super(Outliers, self).__init__(criterion_value)
self.max_outliers = self.criterion_value.get("max", 0)
# NOTE(msdubov): Having 3 as default is reasonable (need enough data).
@@ -51,11 +58,11 @@ class Outliers(sla.SLA):
self.sigmas = self.criterion_value.get("sigmas", 3.0)
self.iterations = 0
self.outliers = 0
self.threshold = None
self.threshold: float | None = None
self.mean_comp = streaming_algorithms.MeanComputation()
self.std_comp = streaming_algorithms.StdDevComputation()
def add_iteration(self, iteration):
def add_iteration(self, iteration: runner.ScenarioRunnerResult) -> bool:
# NOTE(ikhudoshyn): This method can not be implemented properly.
# After adding a new iteration, both mean and standard deviation
# may change. Hence threshold will change as well. In this case we
@@ -84,7 +91,7 @@ class Outliers(sla.SLA):
self.success = self.outliers <= self.max_outliers
return self.success
def merge(self, other):
def merge(self, other: Outliers) -> bool:
# NOTE(ikhudoshyn): This method can not be implemented properly.
# After merge, both mean and standard deviation may change.
# Hence threshold will change as well. In this case we
@@ -106,6 +113,6 @@ class Outliers(sla.SLA):
self.success = self.outliers <= self.max_outliers
return self.success
def details(self):
def details(self) -> str:
return ("Maximum number of outliers %i <= %i - %s" %
(self.outliers, self.max_outliers, self.status()))

View File

@@ -19,11 +19,18 @@ SLA (Service-level agreement) is set of details for determining compliance
with contracted values such as maximum error rate or minimum response time.
"""
from __future__ import annotations
import typing as t
from rally.common import streaming_algorithms
from rally import consts
from rally.task import sla
from rally.utils import strutils
if t.TYPE_CHECKING: # pragma: no cover
from rally.task import runner
@sla.configure(name="performance_degradation")
class PerformanceDegradation(sla.SLA):
@@ -49,22 +56,22 @@ class PerformanceDegradation(sla.SLA):
"additionalProperties": False,
}
def __init__(self, criterion_value):
def __init__(self, criterion_value: dict[str, float]) -> None:
super(PerformanceDegradation, self).__init__(criterion_value)
self.max_degradation = self.criterion_value["max_degradation"]
self.degradation = streaming_algorithms.DegradationComputation()
def add_iteration(self, iteration):
def add_iteration(self, iteration: runner.ScenarioRunnerResult) -> bool:
if not iteration.get("error"):
self.degradation.add(iteration["duration"])
self.success = self.degradation.result() <= self.max_degradation
return self.success
def merge(self, other):
def merge(self, other: PerformanceDegradation) -> bool:
self.degradation.merge(other.degradation)
self.success = self.degradation.result() <= self.max_degradation
return self.success
def details(self):
def details(self) -> str:
res = strutils.format_float_to_str(self.degradation.result() or 0.0)
return "Current degradation: %s%% - %s" % (res, self.status())

View File

@@ -13,6 +13,7 @@
# under the License.
import os
import typing as t
import requests
@@ -25,7 +26,7 @@ from rally.task import types
class PathOrUrl(types.ResourceType):
"""Check whether file exists or url available."""
def pre_process(self, resource_spec, config):
def pre_process(self, resource_spec: str, config: dict[str, t.Any]) -> str:
path = os.path.expanduser(resource_spec)
if os.path.isfile(path):
return path
@@ -44,7 +45,7 @@ class PathOrUrl(types.ResourceType):
class FileType(types.ResourceType):
"""Return content of the file by its path."""
def pre_process(self, resource_spec, config):
def pre_process(self, resource_spec: str, config: dict[str, t.Any]) -> str:
with open(os.path.expanduser(resource_spec), "r") as f:
return f.read()
@@ -53,7 +54,7 @@ class FileType(types.ResourceType):
class ExpandUserPath(types.ResourceType):
"""Expands user path."""
def pre_process(self, resource_spec, config):
def pre_process(self, resource_spec: str, config: dict[str, t.Any]) -> str:
return os.path.expanduser(resource_spec)
@@ -61,8 +62,10 @@ class ExpandUserPath(types.ResourceType):
class FileTypeDict(types.ResourceType):
"""Return the dictionary of items with file path and file content."""
def pre_process(self, resource_spec, config):
file_type_dict = {}
def pre_process(
self, resource_spec: list[str], config: dict[str, t.Any]
) -> dict[str, str]:
file_type_dict: dict[str, str] = {}
for file_path in resource_spec:
file_path = os.path.expanduser(file_path)
with open(file_path, "r") as f:

0
rally/py.typed Normal file
View File

View File

@@ -13,9 +13,13 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import abc
import collections
import threading
import typing as t
import typing_extensions as te
from rally.common import logging
from rally.common.plugin import plugin
@@ -24,8 +28,15 @@ from rally.common import validation
from rally import consts
from rally import exceptions
from rally.task.processing import charts
from rally.task import scenario
from rally.task import utils
if t.TYPE_CHECKING: # pragma: no cover
from rally.common import objects
A = t.TypeVar("A", bound="HookAction")
T = t.TypeVar("T", bound="HookTrigger")
LOG = logging.getLogger(__name__)
@@ -33,14 +44,32 @@ LOG = logging.getLogger(__name__)
configure = plugin.configure
class HookExecutor(object):
class HookResult(t.TypedDict):
"""Structure for hook execution result."""
status: str
started_at: float
finished_at: float
triggered_by: dict[str, t.Any]
error: te.NotRequired[dict[str, str]]
output: te.NotRequired[scenario._Output]
class TriggerResults(t.TypedDict):
"""Structure for trigger results collection."""
config: dict[str, t.Any]
results: list[HookResult]
summary: dict[str, int]
class HookExecutor:
"""Runs hooks and collects results from them."""
def __init__(self, config, task):
def __init__(self, config: dict[str, t.Any], task: objects.Task) -> None:
self.config = config
self.task = task
self.triggers = collections.defaultdict(list)
self.triggers: collections.defaultdict[str, list[HookTrigger]] = (
collections.defaultdict(list))
for hook_cfg in config.get("hooks", []):
action_name = hook_cfg["action"][0]
trigger_name = hook_cfg["trigger"][0]
@@ -54,7 +83,7 @@ class HookExecutor(object):
self._timer_thread = threading.Thread(target=self._timer_method)
self._timer_stop_event = threading.Event()
def _timer_method(self):
def _timer_method(self) -> None:
"""Timer thread method.
It generates events with type "time" to inform HookExecutor
@@ -68,15 +97,15 @@ class HookExecutor(object):
seconds_since_start += 1
stopwatch.sleep(seconds_since_start)
def _start_timer(self):
def _start_timer(self) -> None:
self._timer_thread.start()
def _stop_timer(self):
def _stop_timer(self) -> None:
self._timer_stop_event.set()
if self._timer_thread.ident is not None:
self._timer_thread.join()
def on_event(self, event_type, value):
def on_event(self, event_type: str, value: t.Any) -> None:
"""Notify about event.
This method should be called to inform HookExecutor that
@@ -95,7 +124,7 @@ class HookExecutor(object):
% (trigger_obj.hook_cls.__name__, self.task["uuid"],
event_type, value))
def results(self):
def results(self) -> list[TriggerResults]:
"""Returns list of dicts with hook results."""
if "time" in self.triggers:
self._stop_timer()
@@ -114,25 +143,32 @@ class HookAction(plugin.Plugin, validation.ValidatablePluginMixin,
CONFIG_SCHEMA = {"type": "null"}
def __init__(self, task, config, triggered_by):
def __init__(
self,
task: objects.Task,
config: t.Any,
triggered_by: dict[str, t.Any]
) -> None:
self.task = task
self.config = config
self._triggered_by = triggered_by
self._thread = threading.Thread(target=self._thread_method)
self._started_at = 0.0
self._finished_at = 0.0
self._result = {
self._result: HookResult = {
"status": consts.HookStatus.SUCCESS,
"started_at": self._started_at,
"finished_at": self._finished_at,
"triggered_by": self._triggered_by,
}
def _thread_method(self):
def _thread_method(self) -> None:
# Run hook synchronously
self.run_sync()
def set_error(self, exception_name, description, details):
def set_error(
self, exception_name: str, description: str, details: str
) -> None:
"""Set error related information to result.
:param exception_name: name of exception as string
@@ -143,11 +179,15 @@ class HookAction(plugin.Plugin, validation.ValidatablePluginMixin,
self._result["error"] = {"etype": exception_name,
"msg": description, "details": details}
def set_status(self, status):
def set_status(self, status: str) -> None:
"""Set status to result."""
self._result["status"] = status
def add_output(self, additive=None, complete=None):
def add_output(
self,
additive: dict[str, t.Any] | None = None,
complete: dict[str, t.Any] | None = None
) -> None:
"""Save custom output.
:param additive: dict with additive output
@@ -161,13 +201,15 @@ class HookAction(plugin.Plugin, validation.ValidatablePluginMixin,
message = charts.validate_output(key, value)
if message:
raise exceptions.RallyException(message)
self._result["output"][key].append(value)
self._result["output"][
key # type: ignore[literal-required]
].append(value)
def run_async(self):
def run_async(self) -> None:
"""Run hook asynchronously."""
self._thread.start()
def run_sync(self):
def run_sync(self) -> None:
"""Run hook synchronously."""
try:
with rutils.Timer() as timer:
@@ -182,7 +224,7 @@ class HookAction(plugin.Plugin, validation.ValidatablePluginMixin,
self._result["finished_at"] = self._finished_at
@abc.abstractmethod
def run(self):
def run(self) -> None:
"""Run method.
This method should be implemented in plugin.
@@ -195,12 +237,12 @@ class HookAction(plugin.Plugin, validation.ValidatablePluginMixin,
add_output - provide data for report
"""
def result(self):
def result(self) -> HookResult:
"""Wait and return result of hook."""
if self._thread.ident is not None:
# hook is still running, wait for result
self._thread.join()
return self._result
return t.cast(HookResult, self._result)
@validation.add_default("jsonschema")
@@ -209,20 +251,25 @@ class HookTrigger(plugin.Plugin, validation.ValidatablePluginMixin,
metaclass=abc.ABCMeta):
"""Factory for hook trigger classes."""
CONFIG_SCHEMA = {"type": "null"}
CONFIG_SCHEMA: dict = {"type": "null"}
def __init__(self, hook_cfg, task, hook_cls):
def __init__(
self,
hook_cfg: dict[str, t.Any],
task: objects.Task,
hook_cls: type[HookAction]
) -> None:
self.hook_cfg = hook_cfg
self.config = self.hook_cfg["trigger"][1]
self.task = task
self.hook_cls = hook_cls
self._runs = []
self._runs: list[HookAction] = []
@abc.abstractmethod
def get_listening_event(self):
def get_listening_event(self) -> str:
"""Returns event type to listen."""
def on_event(self, event_type, value=None):
def on_event(self, event_type: str, value: t.Any = None) -> bool:
"""Launch hook on specified event."""
LOG.info("Hook action %s is triggered for Task %s by %s=%s"
% (self.hook_cls.get_name(), self.task["uuid"],
@@ -232,11 +279,14 @@ class HookTrigger(plugin.Plugin, validation.ValidatablePluginMixin,
{"event_type": event_type, "value": value})
action.run_async()
self._runs.append(action)
return True
def get_results(self):
results = {"config": self.hook_cfg,
"results": [],
"summary": {}}
def get_results(self) -> TriggerResults:
results: TriggerResults = {
"config": self.hook_cfg,
"results": [],
"summary": {}
}
for action in self._runs:
action_result = action.result()
results["results"].append(action_result)

View File

@@ -75,7 +75,7 @@ def _get_scenario_context(
def _run_scenario_once(
cls: type[scenario.Scenario],
method_name: str,
method_name: t.Literal["run"],
context_obj: dict[str, t.Any],
scenario_kwargs: dict[str, t.Any],
event_queue: multiprocessing.Queue[dict[str, t.Any]] | DequeAsQueue
@@ -118,7 +118,7 @@ def _run_scenario_once(
def _worker_thread(
queue: multiprocessing.Queue[ScenarioRunnerResult],
cls: type[scenario.Scenario],
method_name: str,
method_name: t.Literal["run"],
context_obj: dict[str, t.Any],
scenario_kwargs: dict[str, t.Any],
event_queue: multiprocessing.Queue[dict[str, t.Any]]
@@ -186,7 +186,7 @@ class ScenarioRunner(plugin.Plugin, validation.ValidatablePluginMixin,
def _run_scenario(
self,
cls: type[scenario.Scenario],
method_name: str,
method_name: t.Literal["run"],
context: dict[str, t.Any],
args: dict[str, t.Any]
) -> None:

View File

@@ -210,6 +210,20 @@ class Scenario(plugin.Plugin,
key # type: ignore[literal-required]
].append(value)
if not t.TYPE_CHECKING:
def run(self, **kwargs: t.Any) -> None:
"""Execute the scenario's workload.
This method must be implemented by all scenario plugins.
It defines the actual workload that the scenario will execute.
:param kwargs: Scenario-specific arguments from task configuration
"""
raise NotImplementedError()
else:
run: t.Callable
@classmethod
def _get_doc(cls) -> str | None:
return getattr(cls, "run", None).__doc__ or ""
def _get_doc(cls) -> str:
"""Get scenario documentation from run method."""
return cls.run.__doc__ or ""

View File

@@ -19,36 +19,56 @@ SLA (Service-level agreement) is set of details for determining compliance
with contracted values such as maximum error rate or minimum response time.
"""
from __future__ import annotations
import abc
import itertools
import typing as t
from rally.common.plugin import plugin
from rally.common import validation
if t.TYPE_CHECKING: # pragma: no cover
from rally.task import runner
S = t.TypeVar("S", bound="SLA")
configure = plugin.configure
def _format_result(criterion_name, success, detail):
class SLAResult(t.TypedDict):
"""Structure for SLA result data."""
criterion: str
success: bool
detail: str
def _format_result(
criterion_name: str, success: bool, detail: str
) -> SLAResult:
"""Returns the SLA result dict corresponding to the current state."""
return {"criterion": criterion_name,
"success": success,
"detail": detail}
class SLAChecker(object):
class SLAChecker:
"""Base SLA checker class."""
def __init__(self, config):
def __init__(self, config: dict[str, t.Any]) -> None:
self.config = config
self.unexpected_failure = None
self.unexpected_failure: Exception | None = None
self.aborted_on_sla = False
self.aborted_manually = False
self.sla_criteria = [SLA.get(name)(criterion_value)
for name, criterion_value
in config.get("sla", {}).items()]
self.sla_criteria: list[SLA] = [
SLA.get(name)(criterion_value)
for name, criterion_value
in config.get("sla", {}).items()
]
def add_iteration(self, iteration):
def add_iteration(self, iteration: runner.ScenarioRunnerResult) -> bool:
"""Process the result of a single iteration.
The call to add_iteration() will return True if all the SLA checks
@@ -58,7 +78,7 @@ class SLAChecker(object):
"""
return all([sla.add_iteration(iteration) for sla in self.sla_criteria])
def merge(self, other):
def merge(self, other: SLAChecker) -> bool:
self._validate_config(other)
self._validate_sla_types(other)
@@ -66,12 +86,12 @@ class SLAChecker(object):
for self_sla, other_sla
in zip(self.sla_criteria, other.sla_criteria)])
def _validate_sla_types(self, other):
def _validate_sla_types(self, other: SLAChecker) -> None:
for self_sla, other_sla in itertools.zip_longest(
self.sla_criteria, other.sla_criteria):
self_sla.validate_type(other_sla)
def _validate_config(self, other):
def _validate_config(self, other: SLAChecker) -> None:
self_config = self.config.get("sla", {})
other_config = other.config.get("sla", {})
if self_config != other_config:
@@ -80,7 +100,7 @@ class SLAChecker(object):
"Only SLACheckers with the same config could be merged."
% (self_config, other_config))
def results(self):
def results(self) -> list[SLAResult]:
results = [sla.result() for sla in self.sla_criteria]
if self.aborted_on_sla:
results.append(_format_result(
@@ -99,13 +119,13 @@ class SLAChecker(object):
return results
def set_aborted_on_sla(self):
def set_aborted_on_sla(self) -> None:
self.aborted_on_sla = True
def set_aborted_manually(self):
def set_aborted_manually(self) -> None:
self.aborted_manually = True
def set_unexpected_failure(self, exc):
def set_unexpected_failure(self, exc: Exception) -> None:
self.unexpected_failure = exc
@@ -115,14 +135,14 @@ class SLA(plugin.Plugin, validation.ValidatablePluginMixin,
metaclass=abc.ABCMeta):
"""Factory for criteria classes."""
CONFIG_SCHEMA = {"type": "null"}
CONFIG_SCHEMA: dict = {"type": "null"}
def __init__(self, criterion_value):
def __init__(self, criterion_value: t.Any) -> None:
self.criterion_value = criterion_value
self.success = True
@abc.abstractmethod
def add_iteration(self, iteration):
def add_iteration(self, iteration: runner.ScenarioRunnerResult) -> bool:
"""Process the result of a single iteration and perform a SLA check.
The call to add_iteration() will return True if the SLA check passed,
@@ -132,20 +152,20 @@ class SLA(plugin.Plugin, validation.ValidatablePluginMixin,
:returns: True if the SLA check passed, False otherwise
"""
def result(self):
def result(self) -> SLAResult:
"""Returns the SLA result dict corresponding to the current state."""
return _format_result(self.get_name(), self.success, self.details())
@abc.abstractmethod
def details(self):
def details(self) -> str:
"""Returns the string describing the current results of the SLA."""
def status(self):
def status(self) -> str:
"""Return "Passed" or "Failed" depending on the current SLA status."""
return "Passed" if self.success else "Failed"
@abc.abstractmethod
def merge(self, other):
def merge(self: S, other: S) -> bool:
"""Merge aggregated data from another SLA instance into self.
Process the results of several iterations aggregated in another
@@ -177,7 +197,7 @@ class SLA(plugin.Plugin, validation.ValidatablePluginMixin,
:returns: True if the SLA check passed, False otherwise
"""
def validate_type(self, other):
def validate_type(self: S, other: S) -> None:
if type(self) is not type(other):
raise TypeError(
"Error merging SLAs of types %s, %s. Only SLAs of the same "

View File

@@ -13,21 +13,30 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import abc
import copy
import operator
import re
import typing as t
from rally.common import logging
from rally.common.plugin import plugin
from rally import exceptions
from rally.task import scenario
if t.TYPE_CHECKING: # pragma: no cover
RT = t.TypeVar("RT", bound="ResourceType")
S = t.TypeVar("S", bound=scenario.Scenario)
LOG = logging.getLogger(__name__)
def convert(**kwargs):
def convert(
**kwargs: dict[str, t.Any]
) -> t.Callable[[type[S]], type[S]]:
"""Decorator to define resource transformation(s) on scenario parameters.
The ``kwargs`` passed as arguments are used to map a key in the
@@ -42,7 +51,7 @@ def convert(**kwargs):
may be added in the future.
"""
def wrapper(cls):
def wrapper(cls: type[S]) -> type[S]:
for k, v in kwargs.items():
if "type" not in v:
LOG.warning(
@@ -56,7 +65,11 @@ def convert(**kwargs):
return wrapper
def preprocess(name, context, args):
def preprocess(
name: str,
context: dict[str, t.Any],
args: dict[str, t.Any]
) -> dict[str, t.Any]:
"""Run preprocessor on scenario arguments.
:param name: Scenario plugin name
@@ -72,8 +85,8 @@ def preprocess(name, context, args):
processed_args = copy.deepcopy(args)
cache = {}
resource_types = {}
cache: dict[str, t.Any] = {}
resource_types: dict[str, ResourceType] = {}
for src, type_cfg in preprocessors.items():
if type_cfg["type"] not in resource_types:
resource_cls = ResourceType.get(type_cfg["type"])
@@ -92,14 +105,18 @@ def preprocess(name, context, args):
class ResourceType(plugin.Plugin, metaclass=abc.ABCMeta):
"""A helper plugin for pre-processing input data of resources."""
def __init__(self, context, cache=None):
def __init__(
self, context: dict[str, t.Any], cache: dict[str, t.Any] | None = None
) -> None:
self._context = context
self._global_cache = cache if cache is not None else {}
self._global_cache.setdefault(self.get_name(), {})
self._cache = self._global_cache[self.get_name()]
@abc.abstractmethod
def pre_process(self, resource_spec, config):
def pre_process(
self, resource_spec: t.Any, config: dict[str, t.Any]
) -> t.Any:
"""Pre-process resource.
:param resource_spec: A specification of the resource from the task
@@ -107,7 +124,11 @@ class ResourceType(plugin.Plugin, metaclass=abc.ABCMeta):
"""
def obj_from_name(resource_config, resources, typename):
def obj_from_name(
resource_config: dict[str, t.Any],
resources: t.Iterable[t.Any],
typename: str
) -> t.Any:
"""Return the resource whose name matches the pattern.
resource_config has to contain `name`, as it is used to lookup a resource.
@@ -164,7 +185,11 @@ def obj_from_name(resource_config, resources, typename):
return matching[0]
def obj_from_id(resource_config, resources, typename):
def obj_from_id(
resource_config: dict[str, t.Any],
resources: t.Iterable[t.Any],
typename: str
) -> t.Any:
"""Return the resource whose name matches the id.
resource_config has to contain `id`, as it is used to lookup a resource.
@@ -190,7 +215,12 @@ def obj_from_id(resource_config, resources, typename):
typename=typename.title(), resource_config=resource_config))
def _id_from_name(resource_config, resources, typename, id_attr="id"):
def _id_from_name(
resource_config: dict[str, t.Any],
resources: t.Iterable[t.Any],
typename: str,
id_attr: str = "id"
) -> t.Any:
"""Return the id of the resource whose name matches the pattern.
resource_config has to contain `name`, as it is used to lookup an id.
@@ -215,7 +245,11 @@ def _id_from_name(resource_config, resources, typename, id_attr="id"):
attr=id_attr, type=typename))
def _name_from_id(resource_config, resources, typename):
def _name_from_id(
resource_config: dict[str, t.Any],
resources: t.Iterable[t.Any],
typename: str
) -> str:
"""Return the name of the resource which has the id.
resource_config has to contain `id`, as it is used to lookup a name.

View File

@@ -24,43 +24,31 @@ from rally.task import context
from rally.task import scenario
class FakeScenario(scenario.Scenario):
def idle_time(self):
return 0
def do_it(self, **kwargs):
pass
def with_output(self, **kwargs):
return {"data": {"a": 1}, "error": None}
def with_add_output(self):
self.add_output(additive={"title": "Additive",
"description": "Additive description",
"data": [["a", 1]],
"chart_plugin": "FooPlugin"},
complete={"title": "Complete",
"description": "Complete description",
"data": [["a", [[1, 2], [2, 3]]]],
"chart_plugin": "BarPlugin"})
def too_long(self, **kwargs):
pass
def something_went_wrong(self, **kwargs):
raise Exception("Something went wrong")
def raise_timeout(self, **kwargs):
raise multiprocessing.TimeoutError()
@scenario.configure(name="classbased.fooscenario")
class FakeClassBasedScenario(FakeScenario):
class FakeScenario(scenario.Scenario):
"""Fake class-based scenario."""
def run(self, *args, **kwargs):
pass
def run(
self,
*args,
raise_exc: bool = False,
raise_timeout_err: bool = False,
with_add_output: bool = False,
**kwargs
) -> None:
if raise_exc:
raise Exception("Something went wrong")
if raise_timeout_err:
raise multiprocessing.TimeoutError()
if with_add_output:
self.add_output(additive={"title": "Additive",
"description": "Additive description",
"data": [["a", 1]],
"chart_plugin": "FooPlugin"},
complete={"title": "Complete",
"description": "Complete description",
"data": [["a", [[1, 2], [2, 3]]]],
"chart_plugin": "BarPlugin"})
class FakeTimer(rally_utils.Timer):

View File

@@ -122,7 +122,7 @@ class ConstantScenarioRunnerTestCase(test.TestCase):
runner_obj = constant.ConstantScenarioRunner(self.task, self.config)
runner_obj._run_scenario(
fakes.FakeScenario, "do_it", self.context, self.args)
fakes.FakeScenario, "run", self.context, self.args)
self.assertEqual(self.config["times"], len(runner_obj.result_queue))
for result_batch in runner_obj.result_queue:
for result in result_batch:
@@ -131,8 +131,10 @@ class ConstantScenarioRunnerTestCase(test.TestCase):
def test__run_scenario_exception(self):
runner_obj = constant.ConstantScenarioRunner(self.task, self.config)
runner_obj._run_scenario(fakes.FakeScenario, "something_went_wrong",
self.context, self.args)
runner_obj._run_scenario(
fakes.FakeScenario, "run", self.context,
args=dict(raise_exc=True, **self.args)
)
self.assertEqual(self.config["times"], len(runner_obj.result_queue))
for result_batch in runner_obj.result_queue:
for result in result_batch:
@@ -143,8 +145,9 @@ class ConstantScenarioRunnerTestCase(test.TestCase):
runner_obj = constant.ConstantScenarioRunner(self.task, self.config)
runner_obj.abort()
runner_obj._run_scenario(fakes.FakeScenario, "do_it", self.context,
self.args)
runner_obj._run_scenario(
fakes.FakeScenario, "run", self.context, self.args
)
self.assertEqual(0, len(runner_obj.result_queue))
@mock.patch(RUNNERS + "constant.multiprocessing.Queue")
@@ -222,7 +225,7 @@ class ConstantScenarioRunnerTestCase(test.TestCase):
runner_obj = constant.ConstantScenarioRunner(self.task,
sample["input"])
runner_obj._run_scenario(fakes.FakeScenario, "do_it", self.context,
runner_obj._run_scenario(fakes.FakeScenario, "run", self.context,
self.args)
mock_cpu_count.assert_called_once_with()
@@ -279,7 +282,7 @@ class ConstantForDurationScenarioRunnerTestCase(test.TestCase):
runner_obj = constant.ConstantForDurationScenarioRunner(
mock.MagicMock(), self.config)
runner_obj._run_scenario(fakes.FakeScenario, "do_it",
runner_obj._run_scenario(fakes.FakeScenario, "run",
self.context, self.args)
# NOTE(mmorais/msimonin): when duration is 0, scenario executes exactly
# 1 time per unit of parrallelism
@@ -293,8 +296,10 @@ class ConstantForDurationScenarioRunnerTestCase(test.TestCase):
runner_obj = constant.ConstantForDurationScenarioRunner(
mock.MagicMock(), self.config)
runner_obj._run_scenario(fakes.FakeScenario, "something_went_wrong",
self.context, self.args)
runner_obj._run_scenario(
fakes.FakeScenario, "run", self.context,
args=dict(raise_exc=True, **self.args)
)
# NOTE(mmorais/msimonin): when duration is 0, scenario executes exactly
# 1 time per unit of parrallelism
expected_times = self.config["concurrency"]
@@ -308,8 +313,10 @@ class ConstantForDurationScenarioRunnerTestCase(test.TestCase):
runner_obj = constant.ConstantForDurationScenarioRunner(
mock.MagicMock(), self.config)
runner_obj._run_scenario(fakes.FakeScenario, "raise_timeout",
self.context, self.args)
runner_obj._run_scenario(
fakes.FakeScenario, "run", self.context,
args=dict(raise_timeout_err=True, **self.args)
)
# NOTE(mmorais/msimonin): when duration is 0, scenario executes exactly
# 1 time per unit of parrallelism
expected_times = self.config["concurrency"]
@@ -324,7 +331,7 @@ class ConstantForDurationScenarioRunnerTestCase(test.TestCase):
self.config)
runner_obj.abort()
runner_obj._run_scenario(fakes.FakeScenario, "do_it",
runner_obj._run_scenario(fakes.FakeScenario, "run",
self.context, self.args)
self.assertEqual(0, len(runner_obj.result_queue))

View File

@@ -264,7 +264,7 @@ class RPSScenarioRunnerTestCase(test.TestCase):
def test__run_scenario(self, mock_sleep, config):
runner_obj = rps.RPSScenarioRunner(self.task, config)
runner_obj._run_scenario(fakes.FakeScenario, "do_it",
runner_obj._run_scenario(fakes.FakeScenario, "run",
{"task": {"uuid": 1}}, {})
self.assertEqual(config["times"], len(runner_obj.result_queue))
@@ -278,8 +278,8 @@ class RPSScenarioRunnerTestCase(test.TestCase):
config = {"times": 4, "rps": 10}
runner_obj = rps.RPSScenarioRunner(self.task, config)
runner_obj._run_scenario(fakes.FakeScenario, "something_went_wrong",
{"task": {"uuid": 1}}, {})
runner_obj._run_scenario(fakes.FakeScenario, "run",
{"task": {"uuid": 1}}, {"raise_exc": True})
self.assertEqual(config["times"], len(runner_obj.result_queue))
for result_batch in runner_obj.result_queue:
for result in result_batch:
@@ -291,7 +291,7 @@ class RPSScenarioRunnerTestCase(test.TestCase):
runner_obj = rps.RPSScenarioRunner(self.task, config)
runner_obj.abort()
runner_obj._run_scenario(fakes.FakeScenario, "do_it",
runner_obj._run_scenario(fakes.FakeScenario, "run",
{}, {})
self.assertEqual(0, len(runner_obj.result_queue))
@@ -381,7 +381,7 @@ class RPSScenarioRunnerTestCase(test.TestCase):
runner_obj = rps.RPSScenarioRunner(self.task, sample["input"])
runner_obj._run_scenario(fakes.FakeScenario, "do_it", {}, {})
runner_obj._run_scenario(fakes.FakeScenario, "run", {}, {})
mock_cpu_count.assert_called_once_with()
mock__log_debug_info.assert_called_once_with(

View File

@@ -36,7 +36,7 @@ class SerialScenarioRunnerTestCase(test.TestCase):
runner = serial.SerialScenarioRunner(mock.MagicMock(),
{"times": times})
runner._run_scenario(fakes.FakeScenario, "do_it",
runner._run_scenario(fakes.FakeScenario, "run",
fakes.FakeContext().context, {})
self.assertEqual(times, len(runner.result_queue))
@@ -48,7 +48,7 @@ class SerialScenarioRunnerTestCase(test.TestCase):
ctxt["iteration"] = i + 1
ctxt["task"] = mock.ANY
expected_calls.append(
mock.call(fakes.FakeScenario, "do_it", ctxt, {},
mock.call(fakes.FakeScenario, "run", ctxt, {},
deque_as_queue_inst)
)
mock__run_scenario_once.assert_has_calls(expected_calls)
@@ -58,7 +58,7 @@ class SerialScenarioRunnerTestCase(test.TestCase):
runner = serial.SerialScenarioRunner(mock.MagicMock(),
{"times": 5})
runner.abort()
runner._run_scenario(fakes.FakeScenario, "do_it",
runner._run_scenario(fakes.FakeScenario, "run",
fakes.FakeContext().context, {})
self.assertEqual(0, len(runner.result_queue))

View File

@@ -76,7 +76,7 @@ class ScenarioRunnerHelpersTestCase(test.TestCase):
@mock.patch(BASE + "rutils.Timer", side_effect=fakes.FakeTimer)
def test_run_scenario_once_without_scenario_output(self, mock_timer):
result = runner._run_scenario_once(
fakes.FakeScenario, "do_it", mock.MagicMock(), {},
fakes.FakeScenario, "run", mock.MagicMock(), {},
mock.MagicMock())
expected_result = {
@@ -92,8 +92,8 @@ class ScenarioRunnerHelpersTestCase(test.TestCase):
@mock.patch(BASE + "rutils.Timer", side_effect=fakes.FakeTimer)
def test_run_scenario_once_with_added_scenario_output(self, mock_timer):
result = runner._run_scenario_once(
fakes.FakeScenario, "with_add_output", mock.MagicMock(), {},
mock.MagicMock())
fakes.FakeScenario, "run", mock.MagicMock(),
{"with_add_output": True}, mock.MagicMock())
expected_result = {
"duration": fakes.FakeTimer().duration(),
@@ -115,7 +115,7 @@ class ScenarioRunnerHelpersTestCase(test.TestCase):
@mock.patch(BASE + "rutils.Timer", side_effect=fakes.FakeTimer)
def test_run_scenario_once_exception(self, mock_timer):
result = runner._run_scenario_once(
fakes.FakeScenario, "something_went_wrong", mock.MagicMock(), {},
fakes.FakeScenario, "run", mock.MagicMock(), {"raise_exc": True},
mock.MagicMock())
expected_error = result.pop("error")
expected_result = {
@@ -139,7 +139,7 @@ class ScenarioRunnerTestCase(test.TestCase):
@mock.patch(BASE + "rutils.Timer.duration", return_value=10)
def test_run(self, mock_timer_duration):
scenario_class = fakes.FakeClassBasedScenario
scenario_class = fakes.FakeScenario
runner_obj = serial.SerialScenarioRunner(
mock.MagicMock(),
mock.MagicMock())

View File

@@ -143,10 +143,7 @@ commands = \
filterwarnings =
error
# we do not use anything inner from OptionParser, so we do not care about it's parent
ignore:The frontend.OptionParser class will be replaced by a subclass of argparse.ArgumentParser in Docutils 0.21 or later.:DeprecationWarning:
# we do not use Option directly, it is initialized by OptionParser by itself.
# as soon as docutils team get rid of frontend.Option, they will also fix OptionParser
ignore: The frontend.Option class will be removed in Docutils 0.21 or later.:DeprecationWarning:
ignore:The frontend.Option.* class will be.*:DeprecationWarning:
# python 3.10
ignore:The distutils package is deprecated and slated for removal in Python 3.12. Use setuptools or check PEP 632 for potential alternatives:DeprecationWarning:
# pytest-cov & pytest-xdist