diff --git a/nova/compute/api.py b/nova/compute/api.py index 900ab706fd0c..c045a86c842a 100644 --- a/nova/compute/api.py +++ b/nova/compute/api.py @@ -1076,7 +1076,6 @@ class API(base.Base): filter_mapping = { 'image': 'image_ref', 'name': 'display_name', - 'instance_name': 'name', 'tenant_id': 'project_id', 'flavor': _remap_flavor_filter, 'fixed_ip': _remap_fixed_ip_filter} diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 2993aaaf3385..bbd5f2e4556f 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -22,7 +22,6 @@ import copy import datetime import functools -import re import warnings from nova import block_device @@ -43,7 +42,6 @@ from sqlalchemy.orm import joinedload_all from sqlalchemy.sql.expression import asc from sqlalchemy.sql.expression import desc from sqlalchemy.sql.expression import literal_column -from sqlalchemy.sql.expression import or_ from sqlalchemy.sql import func FLAGS = flags.FLAGS @@ -246,7 +244,19 @@ def exact_filter(query, model, filters, legal_keys): # OK, filtering on this key; what value do we search for? value = filters.pop(key) - if isinstance(value, (list, tuple, set, frozenset)): + if key == 'metadata': + column_attr = getattr(model, key) + if isinstance(value, list): + for item in value: + for k, v in item.iteritems(): + query = query.filter(column_attr.any(key=k)) + query = query.filter(column_attr.any(value=v)) + + else: + for k, v in value.iteritems(): + query = query.filter(column_attr.any(key=k)) + query = query.filter(column_attr.any(value=v)) + elif isinstance(value, (list, tuple, set, frozenset)): # Looking for values in a list; apply to query directly column_attr = getattr(model, key) query = query.filter(column_attr.in_(value)) @@ -1516,28 +1526,6 @@ def instance_get_all_by_filters(context, filters, sort_key, sort_dir): will be returned by default, unless there's a filter that says otherwise""" - def _regexp_filter_by_metadata(instance, meta): - inst_metadata = [{node['key']: node['value']} - for node in instance['metadata']] - if isinstance(meta, list): - for node in meta: - if node not in inst_metadata: - return False - elif isinstance(meta, dict): - for k, v in meta.iteritems(): - if {k: v} not in inst_metadata: - return False - return True - - def _regexp_filter_by_column(instance, filter_name, filter_re): - try: - v = getattr(instance, filter_name) - except AttributeError: - return True - if v and filter_re.match(unicode(v)): - return True - return False - sort_fn = {'desc': desc, 'asc': asc} session = get_session() @@ -1579,39 +1567,48 @@ def instance_get_all_by_filters(context, filters, sort_key, sort_dir): # Filters for exact matches that we can do along with the SQL query... # For other filters that don't match this, we will do regexp matching exact_match_filter_names = ['project_id', 'user_id', 'image_ref', - 'vm_state', 'instance_type_id', 'uuid'] + 'vm_state', 'instance_type_id', 'uuid', + 'metadata'] # Filter the query query_prefix = exact_filter(query_prefix, models.Instance, filters, exact_match_filter_names) + query_prefix = regex_filter(query_prefix, models.Instance, filters) instances = query_prefix.all() - if not instances: - return [] - - # Now filter on everything else for regexp matching.. - # For filters not in the list, we'll attempt to use the filter_name - # as a column name in Instance.. - regexp_filter_funcs = {} - - for filter_name in filters.iterkeys(): - filter_func = regexp_filter_funcs.get(filter_name, None) - filter_re = re.compile(str(filters[filter_name])) - if filter_func: - filter_l = lambda instance: filter_func(instance, filter_re) - elif filter_name == 'metadata': - filter_l = lambda instance: _regexp_filter_by_metadata(instance, - filters[filter_name]) - else: - filter_l = lambda instance: _regexp_filter_by_column(instance, - filter_name, filter_re) - instances = filter(filter_l, instances) - if not instances: - break - return instances +def regex_filter(query, model, filters): + """Applies regular expression filtering to a query. + + Returns the updated query. + + :param query: query to apply filters to + :param model: model object the query applies to + :param filters: dictionary of filters with regex values + """ + + regexp_op_map = { + 'postgresql': '~', + 'mysql': 'REGEXP', + 'oracle': 'REGEXP_LIKE', + 'sqlite': 'REGEXP' + } + db_string = FLAGS.sql_connection.split(':')[0].split('+')[0] + db_regexp_op = regexp_op_map.get(db_string, 'LIKE') + for filter_name in filters.iterkeys(): + try: + column_attr = getattr(model, filter_name) + except AttributeError: + continue + if 'property' == type(column_attr).__name__: + continue + query = query.filter(column_attr.op(db_regexp_op)( + str(filters[filter_name]))) + return query + + @require_context def instance_get_active_by_window(context, begin, end=None, project_id=None, host=None): diff --git a/nova/db/sqlalchemy/session.py b/nova/db/sqlalchemy/session.py index 6aa5050e4c8b..cada9d79ae33 100644 --- a/nova/db/sqlalchemy/session.py +++ b/nova/db/sqlalchemy/session.py @@ -18,6 +18,7 @@ """Session Handling for SQLAlchemy backend.""" +import re import time from sqlalchemy.exc import DisconnectionError, OperationalError @@ -85,6 +86,16 @@ def is_db_connection_error(args): return False +def regexp(expr, item): + reg = re.compile(expr) + return reg.search(unicode(item)) is not None + + +class AddRegexFactory(sqlalchemy.interfaces.PoolListener): + def connect(delf, dbapi_con, con_record): + dbapi_con.create_function('REGEXP', 2, regexp) + + def get_engine(): """Return a SQLAlchemy engine.""" global _ENGINE @@ -109,6 +120,7 @@ def get_engine(): if FLAGS.sql_connection == "sqlite://": engine_args["poolclass"] = StaticPool engine_args["connect_args"] = {'check_same_thread': False} + engine_args['listeners'] = [AddRegexFactory()] _ENGINE = sqlalchemy.create_engine(FLAGS.sql_connection, **engine_args) diff --git a/nova/tests/compute/test_compute.py b/nova/tests/compute/test_compute.py index 0b0316e26581..827374f3766b 100644 --- a/nova/tests/compute/test_compute.py +++ b/nova/tests/compute/test_compute.py @@ -3146,14 +3146,14 @@ class ComputeAPITestCase(BaseTestCase): 'display_name': 'not-woot'}) instances = self.compute_api.get_all(c, - search_opts={'name': 'woo.*'}) + search_opts={'name': '^woo.*'}) self.assertEqual(len(instances), 2) instance_uuids = [instance['uuid'] for instance in instances] self.assertTrue(instance1['uuid'] in instance_uuids) self.assertTrue(instance2['uuid'] in instance_uuids) instances = self.compute_api.get_all(c, - search_opts={'name': 'woot.*'}) + search_opts={'name': '^woot.*'}) instance_uuids = [instance['uuid'] for instance in instances] self.assertEqual(len(instances), 1) self.assertTrue(instance1['uuid'] in instance_uuids) @@ -3166,7 +3166,7 @@ class ComputeAPITestCase(BaseTestCase): self.assertTrue(instance3['uuid'] in instance_uuids) instances = self.compute_api.get_all(c, - search_opts={'name': 'n.*'}) + search_opts={'name': '^n.*'}) self.assertEqual(len(instances), 1) instance_uuids = [instance['uuid'] for instance in instances] self.assertTrue(instance3['uuid'] in instance_uuids) @@ -3179,35 +3179,6 @@ class ComputeAPITestCase(BaseTestCase): db.instance_destroy(c, instance2['uuid']) db.instance_destroy(c, instance3['uuid']) - def test_get_all_by_instance_name_regexp(self): - """Test searching instances by name""" - self.flags(instance_name_template='instance-%d') - - c = context.get_admin_context() - instance1 = self._create_fake_instance() - instance2 = self._create_fake_instance({'id': 2}) - instance3 = self._create_fake_instance({'id': 10}) - - instances = self.compute_api.get_all(c, - search_opts={'instance_name': 'instance.*'}) - self.assertEqual(len(instances), 3) - - instances = self.compute_api.get_all(c, - search_opts={'instance_name': '.*\-\d$'}) - self.assertEqual(len(instances), 2) - instance_uuids = [instance['uuid'] for instance in instances] - self.assertTrue(instance1['uuid'] in instance_uuids) - self.assertTrue(instance2['uuid'] in instance_uuids) - - instances = self.compute_api.get_all(c, - search_opts={'instance_name': 'i.*2'}) - self.assertEqual(len(instances), 1) - self.assertEqual(instances[0]['uuid'], instance2['uuid']) - - db.instance_destroy(c, instance1['uuid']) - db.instance_destroy(c, instance2['uuid']) - db.instance_destroy(c, instance3['uuid']) - def test_get_all_by_multiple_options_at_once(self): """Test searching by multiple options at once""" c = context.get_admin_context() diff --git a/nova/tests/test_db_api.py b/nova/tests/test_db_api.py index b2b1cf9e272b..89b6480d6708 100644 --- a/nova/tests/test_db_api.py +++ b/nova/tests/test_db_api.py @@ -51,6 +51,34 @@ class DbApiTestCase(test.TestCase): result = db.instance_get_all_by_filters(self.context, {}) self.assertEqual(2, len(result)) + def test_instance_get_all_by_filters_regex(self): + self.create_instances_with_args(display_name='test1') + self.create_instances_with_args(display_name='teeeest2') + self.create_instances_with_args(display_name='diff') + result = db.instance_get_all_by_filters(self.context, + {'display_name': 't.*st.'}) + self.assertEqual(2, len(result)) + + def test_instance_get_all_by_filters_regex_unsupported_db(self): + """Ensure that the 'LIKE' operator is used for unsupported dbs.""" + self.flags(sql_connection="notdb://") + self.create_instances_with_args(display_name='test1') + self.create_instances_with_args(display_name='test.*') + self.create_instances_with_args(display_name='diff') + result = db.instance_get_all_by_filters(self.context, + {'display_name': 'test.*'}) + self.assertEqual(1, len(result)) + result = db.instance_get_all_by_filters(self.context, + {'display_name': '%test%'}) + self.assertEqual(2, len(result)) + + def test_instance_get_all_by_filters_metadata(self): + self.create_instances_with_args(metadata={'foo': 'bar'}) + self.create_instances_with_args() + result = db.instance_get_all_by_filters(self.context, + {'metadata': {'foo': 'bar'}}) + self.assertEqual(1, len(result)) + def test_instance_get_all_by_filters_unicode_value(self): self.create_instances_with_args(display_name=u'test♥') result = db.instance_get_all_by_filters(self.context,