diff --git a/mistral/tests/unit/test_expressions.py b/mistral/tests/unit/test_expressions.py index 2f008ce2..fb80d3bb 100644 --- a/mistral/tests/unit/test_expressions.py +++ b/mistral/tests/unit/test_expressions.py @@ -72,6 +72,7 @@ class YaqlEvaluatorTest(base.BaseTest): self.assertEqual(item, {'name': 'ubuntu'}) def test_function_length(self): + # Lists. self.assertEqual(3, expr.evaluate('$.length()', [1, 2, 3])) self.assertEqual(2, expr.evaluate('$.length()', ['one', 'two'])) self.assertEqual(4, expr.evaluate( @@ -79,6 +80,27 @@ class YaqlEvaluatorTest(base.BaseTest): {'array': ['1', '2', '3', '4']}) ) + # Strings. + self.assertEqual(3, expr.evaluate('$.length()', '123')) + self.assertEqual(2, expr.evaluate('$.length()', '12')) + self.assertEqual( + 4, + expr.evaluate('$.string.length()', {'string': '1234'}) + ) + + # Generators. + self.assertEqual( + 2, + expr.evaluate( + "$[$.state = 'active'].length()", + [ + {'state': 'active'}, + {'state': 'active'}, + {'state': 'passive'} + ] + ) + ) + class InlineYAQLEvaluatorTest(base.BaseTest): def setUp(self): diff --git a/mistral/yaql_utils.py b/mistral/yaql_utils.py index cd814640..804aeb2b 100644 --- a/mistral/yaql_utils.py +++ b/mistral/yaql_utils.py @@ -14,11 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections +import six +import types import yaql +from yaql import context as ctx def create_yaql_context(): - ctx = yaql.create_context() _register_functions(ctx) @@ -27,11 +30,17 @@ def create_yaql_context(): def _register_functions(yaql_ctx): - yaql_ctx.register_function(length, 'length') - yaql_ctx.register_function(length, 'size') + yaql_ctx.register_function(_generator_length, 'length') + yaql_ctx.register_function(_string_and_iterable_length, 'length') # Additional convenience YAQL functions. -def length(a): - return len(a()) +@ctx.EvalArg('a', arg_type=(six.string_types, collections.Iterable)) +def _string_and_iterable_length(a): + return len(a) + + +@ctx.EvalArg('a', arg_type=types.GeneratorType) +def _generator_length(a): + return sum(1 for i in a)