diff --git a/releasenotes/notes/substring-matching-1d5981b8e5b1d919.yaml b/releasenotes/notes/substring-matching-1d5981b8e5b1d919.yaml new file mode 100644 index 00000000..0b180d99 --- /dev/null +++ b/releasenotes/notes/substring-matching-1d5981b8e5b1d919.yaml @@ -0,0 +1,7 @@ +--- +fixes: + - Add regular expression matching on search values for + certain string fields of sahara objects. This + applies to list operations through the REST + API and therefore applies to the dashboard + and sahara client as well. Closes bug 1503345. diff --git a/sahara/conductor/api.py b/sahara/conductor/api.py index ba2d83f4..990ccd95 100644 --- a/sahara/conductor/api.py +++ b/sahara/conductor/api.py @@ -164,13 +164,22 @@ class LocalApi(object): _get_id(cluster_template)) @r.wrap(r.ClusterTemplateResource) - def cluster_template_get_all(self, context, **kwargs): + def cluster_template_get_all(self, context, regex_search=False, **kwargs): """Get all cluster templates filtered by **kwargs. - e.g. cluster_template_get_all(plugin_name='vanilla', - hadoop_version='1.1') + :param context: The context, and associated authentication, to use with + this operation + + :param regex_search: If True, enable regex matching for filter + values. See the user guide for more information + on how regex matching is handled. If False, + no regex matching is done. + + :param kwargs: Specifies values for named fields by which + to constrain the search """ - return self._manager.cluster_template_get_all(context, **kwargs) + return self._manager.cluster_template_get_all(context, + regex_search, **kwargs) @r.wrap(r.ClusterTemplateResource) def cluster_template_create(self, context, values): diff --git a/sahara/conductor/manager.py b/sahara/conductor/manager.py index cafa9799..121f020c 100644 --- a/sahara/conductor/manager.py +++ b/sahara/conductor/manager.py @@ -289,13 +289,22 @@ class ConductorManager(db_base.Base): """Return the cluster_template or None if it does not exist.""" return self.db.cluster_template_get(context, cluster_template) - def cluster_template_get_all(self, context, **kwargs): + def cluster_template_get_all(self, context, regex_search=False, **kwargs): """Get all cluster templates filtered by **kwargs. - e.g. cluster_template_get_all(plugin_name='vanilla', - hadoop_version='1.1') + :param context: The context, and associated authentication, to use with + this operation + + :param regex_search: If True, enable regex matching for filter + values. See the user guide for more information + on how regex matching is handled. If False, + no regex matching is done. + + :param kwargs: Specifies values for named fields by which + to constrain the search """ - return self.db.cluster_template_get_all(context, **kwargs) + return self.db.cluster_template_get_all(context, + regex_search, **kwargs) def cluster_template_create(self, context, values): """Create a cluster_template from the values dictionary.""" diff --git a/sahara/db/api.py b/sahara/db/api.py index 302830fa..e51dd295 100644 --- a/sahara/db/api.py +++ b/sahara/db/api.py @@ -200,13 +200,21 @@ def cluster_template_get(context, cluster_template): @to_dict -def cluster_template_get_all(context, **kwargs): +def cluster_template_get_all(context, regex_search=False, **kwargs): """Get all cluster templates filtered by **kwargs. - e.g. cluster_template_get_all(plugin_name='vanilla', - hadoop_version='1.1') + :param context: The context, and associated authentication, to use with + this operation + + :param regex_search: If True, enable regex matching for filter + values. See the user guide for more information + on how regex matching is handled. If False, + no regex matching is done. + + :param kwargs: Specifies values for named fields by which + to constrain the search """ - return IMPL.cluster_template_get_all(context, **kwargs) + return IMPL.cluster_template_get_all(context, regex_search, **kwargs) @to_dict diff --git a/sahara/db/sqlalchemy/api.py b/sahara/db/sqlalchemy/api.py index 519b3840..e1ea7332 100644 --- a/sahara/db/sqlalchemy/api.py +++ b/sahara/db/sqlalchemy/api.py @@ -15,6 +15,7 @@ """Implementation of SQLAlchemy backend.""" +import copy import sys import threading @@ -172,6 +173,53 @@ def like_filter(query, cls, search_opts): return query, remaining +def _get_regex_op(connection): + db = connection.split(':')[0].split('+')[0] + regexp_op_map = { + 'postgresql': '~', + 'mysql': 'REGEXP' + } + return regexp_op_map.get(db, None) + + +def regex_filter(query, cls, regex_cols, search_opts): + """Add regex filters for specified columns. + + Add a regex filter to the query for any entry in the + 'search_opts' dict where the key is the name of a column in + 'cls' and listed in 'regex_cols' and the value is a string. + + Return the modified query and any entries in search_opts + whose keys do not match columns or whose values are not + strings. + + This is only supported for mysql and postgres. For other + databases, the query is not altered. + + :param query: a non-null query object + :param cls: the database model class the filters will apply to + :param regex_cols: a list of columns for which regex is supported + :param search_opts: a dictionary whose key/value entries are interpreted as + column names and search patterns + :returns: a tuple containing the modified query and a dictionary of + unused search_opts + """ + + regex_op = _get_regex_op(CONF.database.connection) + if not regex_op: + return query, copy.copy(search_opts) + + remaining = {} + for k, v in six.iteritems(search_opts): + if isinstance(v, six.string_types) and ( + k in cls.__table__.columns and k in regex_cols): + col = cls.__table__.columns[k] + query = query.filter(col.op(regex_op)(v)) + else: + remaining[k] = v + return query, remaining + + def setup_db(): try: engine = get_engine() @@ -398,8 +446,14 @@ def cluster_template_get(context, cluster_template_id): return _cluster_template_get(context, get_session(), cluster_template_id) -def cluster_template_get_all(context, **kwargs): +def cluster_template_get_all(context, regex_search=False, **kwargs): + + regex_cols = ['name', 'description', 'plugin_name'] + query = model_query(m.ClusterTemplate, context) + if regex_search: + query, kwargs = regex_filter(query, + m.ClusterTemplate, regex_cols, kwargs) return query.filter_by(**kwargs).all() diff --git a/sahara/service/api.py b/sahara/service/api.py index 8c52081f..12265196 100644 --- a/sahara/service/api.py +++ b/sahara/service/api.py @@ -176,7 +176,8 @@ def update_cluster(id, values): # ClusterTemplate ops def get_cluster_templates(**kwargs): - return conductor.cluster_template_get_all(context.ctx(), **kwargs) + return conductor.cluster_template_get_all(context.ctx(), + regex_search=True, **kwargs) def get_cluster_template(id): diff --git a/sahara/tests/unit/conductor/manager/test_templates.py b/sahara/tests/unit/conductor/manager/test_templates.py index dd8b225b..e96a3e12 100644 --- a/sahara/tests/unit/conductor/manager/test_templates.py +++ b/sahara/tests/unit/conductor/manager/test_templates.py @@ -16,12 +16,14 @@ import copy import uuid +import mock import six from sqlalchemy import exc as sa_ex import testtools from sahara.conductor import manager from sahara import context +from sahara.db.sqlalchemy import models as m from sahara import exceptions as ex from sahara.service.validations import cluster_template_schema as cl_schema from sahara.service.validations import node_group_template_schema as ngt_schema @@ -399,18 +401,28 @@ class ClusterTemplates(test_base.ConductorManagerTestCase): def test_clt_search(self): ctx = context.ctx() - self.api.cluster_template_create(ctx, SAMPLE_CLT) + clt = copy.deepcopy(SAMPLE_CLT) + clt["name"] = "frederica" + clt["plugin_name"] = "test_plugin" + self.api.cluster_template_create(ctx, clt) lst = self.api.cluster_template_get_all(ctx) self.assertEqual(1, len(lst)) - kwargs = {'name': SAMPLE_CLT['name'], - 'plugin_name': SAMPLE_CLT['plugin_name']} + # Exact match + kwargs = {'name': clt['name'], + 'plugin_name': clt['plugin_name']} lst = self.api.cluster_template_get_all(ctx, **kwargs) self.assertEqual(1, len(lst)) # Valid field but no matching value - kwargs = {'name': SAMPLE_CLT['name']+"foo"} + kwargs = {'name': clt['name']+"foo"} + lst = self.api.cluster_template_get_all(ctx, **kwargs) + self.assertEqual(0, len(lst)) + + # Valid field with substrings + kwargs = {'name': "red", + 'plugin_name': "test"} lst = self.api.cluster_template_get_all(ctx, **kwargs) self.assertEqual(0, len(lst)) @@ -419,6 +431,26 @@ class ClusterTemplates(test_base.ConductorManagerTestCase): self.api.cluster_template_get_all, ctx, **{'badfield': 'junk'}) + @mock.patch('sahara.db.sqlalchemy.api.regex_filter') + def test_clt_search_regex(self, regex_filter): + + # do this so we can return the correct value + def _regex_filter(query, cls, regex_cols, search_opts): + return query, search_opts + + regex_filter.side_effect = _regex_filter + + ctx = context.ctx() + self.api.cluster_template_get_all(ctx) + self.assertEqual(0, regex_filter.call_count) + + self.api.cluster_template_get_all(ctx, regex_search=True, name="fox") + self.assertEqual(1, regex_filter.call_count) + args, kwargs = regex_filter.call_args + self.assertTrue(type(args[1] is m.ClusterTemplate)) + self.assertEqual(args[2], ["name", "description", "plugin_name"]) + self.assertEqual(args[3], {"name": "fox"}) + def test_clt_update(self): ctx = context.ctx() clt = self.api.cluster_template_create(ctx, SAMPLE_CLT) diff --git a/sahara/tests/unit/db/test_utils.py b/sahara/tests/unit/db/test_utils.py new file mode 100644 index 00000000..c28f4d61 --- /dev/null +++ b/sahara/tests/unit/db/test_utils.py @@ -0,0 +1,87 @@ +# Copyright (c) 2016 Red Hat Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock +import testtools + +from sahara import context +from sahara.db.sqlalchemy import api +from sahara.db.sqlalchemy import models as m +import sahara.tests.unit.base as base + + +class TestRegex(testtools.TestCase): + + def test_get_regex_op(self): + regex_op = api._get_regex_op("mysql://user:passw@localhost/sahara") + self.assertEqual("REGEXP", regex_op) + + regex_op = api._get_regex_op("postgresql://localhost/sahara") + self.assertEqual("~", regex_op) + + regex_op = api._get_regex_op("sqlite://user:passw@localhost/sahara") + self.assertIsNone(regex_op) + + +class TestRegexFilter(base.SaharaWithDbTestCase): + + @mock.patch("sahara.db.sqlalchemy.api._get_regex_op") + def test_regex_filter(self, get_regex_op): + query = api.model_query(m.ClusterTemplate, context.ctx()) + + regex_cols = ["name", "description", "plugin_name"] + search_opts = {"name": "fred", + "hadoop_version": "2", + "bogus": "jack", + "plugin_name": "vanilla"} + + # Since regex_op is None remaining_opts should be a copy of search_opts + get_regex_op.return_value = None + query, remaining_opts = api.regex_filter( + query, m.ClusterTemplate, regex_cols, search_opts) + self.assertEqual(search_opts, remaining_opts) + self.assertIsNot(search_opts, remaining_opts) + + # Since regex_cols is [] remaining_opts should be a copy of search_opts + get_regex_op.return_value = "REGEXP" + query, remaining_opts = api.regex_filter( + query, m.ClusterTemplate, [], search_opts) + self.assertEqual(search_opts, remaining_opts) + self.assertIsNot(search_opts, remaining_opts) + + # Remaining should be search_opts with name and plugin_name removed + # These are the only fields that are in regex_cols and also in + # the model. + get_regex_op.return_value = "REGEXP" + query, remaining_opts = api.regex_filter( + query, m.ClusterTemplate, regex_cols, search_opts) + self.assertEqual({"hadoop_version": "2", + "bogus": "jack"}, remaining_opts) + + # bogus is not in the model so it should be left in remaining + # even though regex_cols lists it + regex_cols.append("bogus") + query, remaining_opts = api.regex_filter( + query, m.ClusterTemplate, regex_cols, search_opts) + self.assertEqual({"hadoop_version": "2", + "bogus": "jack"}, remaining_opts) + + # name will not be removed because the value is not a string + search_opts["name"] = 5 + query, remaining_opts = api.regex_filter( + query, m.ClusterTemplate, regex_cols, search_opts) + self.assertEqual({"hadoop_version": "2", + "bogus": "jack", + "name": 5}, remaining_opts)