diff --git a/rally/benchmark/scenarios/base.py b/rally/benchmark/scenarios/base.py index c98a350fde..69acf2170b 100644 --- a/rally/benchmark/scenarios/base.py +++ b/rally/benchmark/scenarios/base.py @@ -92,11 +92,11 @@ class Scenario(object): if "." in name: scenario_group, scenario_name = name.split(".", 1) scenario_cls = Scenario.get_by_name(scenario_group) - if hasattr(scenario_cls, scenario_name): + if Scenario.is_scenario(scenario_cls, scenario_name): return getattr(scenario_cls, scenario_name) else: for scenario_cls in utils.itersubclasses(Scenario): - if name in dir(scenario_cls): + if Scenario.is_scenario(scenario_cls, name): return getattr(scenario_cls, name) raise exceptions.NoSuchScenario(name=name) @@ -110,10 +110,8 @@ class Scenario(object): :returns: List of strings """ benchmark_scenarios = [ - ["%s.%s" % (scenario.__name__, method) - for method in dir(scenario) - if Scenario.meta(scenario, method_name=method, - attr_name="is_scenario", default=False)] + ["%s.%s" % (scenario.__name__, func) + for func in dir(scenario) if Scenario.is_scenario(scenario, func)] for scenario in utils.itersubclasses(Scenario) ] benchmark_scenarios_flattened = list(itertools.chain.from_iterable( @@ -169,6 +167,17 @@ class Scenario(object): method = getattr(cls, method_name) return copy.deepcopy(getattr(method, attr_name, default)) + @staticmethod + def is_scenario(cls, method_name): + """Check whether a given method in scenario class is a scenario. + + :param cls: scenario class + :param method_name: method name + :returns: True if the method is a benchmark scenario, False otherwise + """ + return (hasattr(cls, method_name) and + Scenario.meta(cls, "is_scenario", method_name, default=False)) + def context(self): """Returns the context of the current benchmark scenario.""" return self._context diff --git a/rally/cmd/commands/info.py b/rally/cmd/commands/info.py index ce72230553..43d6cd3859 100644 --- a/rally/cmd/commands/info.py +++ b/rally/cmd/commands/info.py @@ -64,6 +64,10 @@ class InfoCommands(object): print(info) else: print("Failed to find any docs for query: '%s'" % query) + substitutions = self._find_substitution(query) + if substitutions: + print("Did you mean one of these?\n\t%s" % + "\n\t".join(substitutions)) return 1 def _find_info(self, query): @@ -72,6 +76,25 @@ class InfoCommands(object): self._get_deploy_engine_info(query) or self._get_server_provider_info(query)) + def _find_substitution(self, query): + max_distance = min(3, len(query) / 4) + scenarios = scenario_base.Scenario.list_benchmark_scenarios() + scenario_groups = list(set(s.split(".")[0] for s in scenarios)) + scenario_methods = list(set(s.split(".")[1] for s in scenarios)) + deploy_engines = [cls.__name__ for cls in utils.itersubclasses( + deploy.EngineFactory)] + server_providers = [cls.__name__ for cls in utils.itersubclasses( + serverprovider.ProviderFactory)] + candidates = (scenarios + scenario_groups + scenario_methods + + deploy_engines + server_providers) + suggestions = [] + # NOTE(msdubov): Incorrect query may either have typos or be truncated. + for candidate in candidates: + if ((utils.distance(query, candidate) <= max_distance or + candidate.startswith(query))): + suggestions.append(candidate) + return suggestions + def _get_scenario_group_info(self, query): try: scenario_group = scenario_base.Scenario.get_by_name(query) diff --git a/rally/utils.py b/rally/utils.py index 33504e696c..009cbfff66 100644 --- a/rally/utils.py +++ b/rally/utils.py @@ -305,3 +305,22 @@ def parse_docstring(docstring): "params": [], "returns": None } + + +def distance(s1, s2): + """Computes the edit distance between two strings. + + The edit distance is the Levenshtein distance. The larger the return value, + the more edits are required to transform one string into the other. + + :param s1: First string to compare + :param s2: Second string to compare + :returns: Integer distance between two strings + """ + n = range(0, len(s1) + 1) + for y in range(1, len(s2) + 1): + l, n = n, [y] + for x in xrange(1, len(s1) + 1): + n.append(min(l[x] + 1, n[-1] + 1, l[x - 1] + + ((s2[y - 1] != s1[x - 1]) and 1 or 0))) + return n[-1] diff --git a/tests/functional/test_cli_info.py b/tests/functional/test_cli_info.py index f6bb5153e9..9c6f7fa07a 100644 --- a/tests/functional/test_cli_info.py +++ b/tests/functional/test_cli_info.py @@ -39,3 +39,21 @@ class InfoTestCase(unittest.TestCase): def test_find_server_provider(self): marker_string = "ExistingServers (server provider)." self.assertIn(marker_string, self.rally("info find ExistingServers")) + + def test_find_fails(self): + self.assertRaises(utils.RallyCmdError, self.rally, + ("info find NonExistingStuff")) + + def test_find_misspelling_typos(self): + marker_string = "ExistingServers" + try: + self.rally("info find ExistinfServert") + except utils.RallyCmdError as e: + self.assertIn(marker_string, e.output) + + def test_find_misspelling_truncated(self): + marker_string = "boot_and_delete_server" + try: + self.rally("info find boot_and_delete") + except utils.RallyCmdError as e: + self.assertIn(marker_string, e.output) diff --git a/tests/unit/benchmark/scenarios/test_base.py b/tests/unit/benchmark/scenarios/test_base.py index 65590c92f1..2991b6f990 100644 --- a/tests/unit/benchmark/scenarios/test_base.py +++ b/tests/unit/benchmark/scenarios/test_base.py @@ -179,31 +179,38 @@ class ScenarioTestCase(test.TestCase): MyFakeScenario.do_it.__dict__[attr_name] = preprocessors scenario = MyFakeScenario() - self.assertEqual(scenario.meta(cls=MyFakeScenario, method_name="do_it", + self.assertEqual(scenario.meta(cls=fakes.FakeScenario, + method_name="do_it", attr_name=attr_name), preprocessors) def test_meta_string_returns_empty_list(self): - - class MyFakeScenario(fakes.FakeScenario): - pass - empty_list = [] - scenario = MyFakeScenario() - self.assertEqual(scenario.meta(cls="MyFakeScenario.do_it", + scenario = fakes.FakeScenario() + self.assertEqual(scenario.meta(cls="FakeScenario.do_it", attr_name="foo", default=empty_list), empty_list) def test_meta_class_returns_empty_list(self): - - class MyFakeScenario(fakes.FakeScenario): - pass - empty_list = [] - scenario = MyFakeScenario() - self.assertEqual(scenario.meta(cls=MyFakeScenario, method_name="do_it", - attr_name="foo", default=empty_list), + scenario = fakes.FakeScenario() + self.assertEqual(scenario.meta(cls=fakes.FakeScenario, + method_name="do_it", attr_name="foo", + default=empty_list), empty_list) + def test_is_scenario_success(self): + scenario = dummy.Dummy() + self.assertTrue(base.Scenario.is_scenario(scenario, "dummy")) + + def test_is_scenario_not_scenario(self): + scenario = dummy.Dummy() + self.assertFalse(base.Scenario.is_scenario(scenario, + "_random_fail_emitter")) + + def test_is_scenario_non_existing(self): + scenario = dummy.Dummy() + self.assertFalse(base.Scenario.is_scenario(scenario, "non_existing")) + def test_sleep_between_invalid_args(self): scenario = base.Scenario() self.assertRaises(exceptions.InvalidArgumentsException, diff --git a/tests/unit/cmd/commands/test_info.py b/tests/unit/cmd/commands/test_info.py index f09d53369b..b3f9ff656c 100644 --- a/tests/unit/cmd/commands/test_info.py +++ b/tests/unit/cmd/commands/test_info.py @@ -20,13 +20,13 @@ from rally.cmd.commands import info from rally.deploy.engines import existing as existing_cloud from rally.deploy.serverprovider.providers import existing as existing_servers from rally import exceptions -from tests.unit import fakes from tests.unit import test SCENARIO = "rally.cmd.commands.info.scenario_base.Scenario" ENGINE = "rally.cmd.commands.info.deploy.EngineFactory" PROVIDER = "rally.cmd.commands.info.serverprovider.ProviderFactory" +DUMMY = "rally.benchmark.scenarios.dummy.dummy.Dummy" class InfoCommandsTestCase(test.TestCase): @@ -58,14 +58,6 @@ class InfoCommandsTestCase(test.TestCase): mock_get_scenario_by_name.assert_called_once_with(query) self.assertEqual(1, status) - @mock.patch(SCENARIO + ".get_scenario_by_name", - return_value=fakes.FakeScenario.do_it) - def test_find_scenario_with_empty_docs(self, mock_get_scenario_by_name): - query = "FakeScenario.do_it" - status = self.info.find(query) - mock_get_scenario_by_name.assert_called_once_with(query) - self.assertEqual(1, status) - @mock.patch(ENGINE + ".get_by_name", return_value=existing_cloud.ExistingCloud) def test_find_existing_cloud(self, mock_get_by_name): diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index ccadc8f9e9..655d903054 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -309,3 +309,26 @@ description. "returns": "Return value description." } self.assertEqual(dct, expected) + + +class EditDistanceTestCase(test.TestCase): + + def test_distance_empty_strings(self): + dist = utils.distance("", "") + self.assertEqual(0, dist) + + def test_distance_equal_strings(self): + dist = utils.distance("abcde", "abcde") + self.assertEqual(0, dist) + + def test_distance_replacement(self): + dist = utils.distance("abcde", "__cde") + self.assertEqual(2, dist) + + def test_distance_insertion(self): + dist = utils.distance("abcde", "ab__cde") + self.assertEqual(2, dist) + + def test_distance_deletion(self): + dist = utils.distance("abcde", "abc") + self.assertEqual(2, dist)