Merge "Refactor REGEX filters to eliminate 500 errors"

This commit is contained in:
Jenkins 2016-12-12 20:10:16 +00:00 committed by Gerrit Code Review
commit ae63a7a5a8
2 changed files with 66 additions and 11 deletions

View File

@ -2309,17 +2309,51 @@ def _tag_instance_filter(context, query, filters):
return query return query
def _get_regexp_op_for_connection(db_connection): def _db_connection_type(db_connection):
"""Returns a lowercase symbol for the db type.
This is useful when we need to change what we are doing per DB
(like handling regexes). In a CellsV2 world it probably needs to
do something better than use the database configuration string.
"""
db_string = db_connection.split(':')[0].split('+')[0] db_string = db_connection.split(':')[0].split('+')[0]
return db_string.lower()
def _safe_regex_mysql(raw_string):
"""Make regex safe to mysql.
Certain items like '|' are interpreted raw by mysql REGEX. If you
search for a single | then you trigger an error because it's
expecting content on either side.
For consistency sake we escape all '|'. This does mean we wouldn't
support something like foo|bar to match completely different
things, however, one can argue putting such complicated regex into
name search probably means you are doing this wrong.
"""
return raw_string.replace('|', '\\|')
def _get_regexp_ops(connection):
"""Return safety filter and db opts for regex."""
regexp_op_map = { regexp_op_map = {
'postgresql': '~', 'postgresql': '~',
'mysql': 'REGEXP', 'mysql': 'REGEXP',
'sqlite': 'REGEXP' 'sqlite': 'REGEXP'
} }
return regexp_op_map.get(db_string, 'LIKE') regex_safe_filters = {
'mysql': _safe_regex_mysql
}
db_type = _db_connection_type(connection)
return (regex_safe_filters.get(db_type, lambda x: x),
regexp_op_map.get(db_type, 'LIKE'))
def _regex_instance_filter(query, filters): def _regex_instance_filter(query, filters):
"""Applies regular expression filtering to an Instance query. """Applies regular expression filtering to an Instance query.
Returns the updated query. Returns the updated query.
@ -2329,7 +2363,7 @@ def _regex_instance_filter(query, filters):
""" """
model = models.Instance model = models.Instance
db_regexp_op = _get_regexp_op_for_connection(CONF.database.connection) safe_regex_filter, db_regexp_op = _get_regexp_ops(CONF.database.connection)
for filter_name in filters: for filter_name in filters:
try: try:
column_attr = getattr(model, filter_name) column_attr = getattr(model, filter_name)
@ -2345,6 +2379,7 @@ def _regex_instance_filter(query, filters):
query = query.filter(column_attr.op(db_regexp_op)( query = query.filter(column_attr.op(db_regexp_op)(
u'%' + filter_val + u'%')) u'%' + filter_val + u'%'))
else: else:
filter_val = safe_regex_filter(filter_val)
query = query.filter(column_attr.op(db_regexp_op)( query = query.filter(column_attr.op(db_regexp_op)(
filter_val)) filter_val))
return query return query

View File

@ -297,8 +297,8 @@ def _create_aggregate_with_hosts(context=context.get_admin_context(),
return result return result
@mock.patch.object(sqlalchemy_api, '_get_regexp_op_for_connection', @mock.patch.object(sqlalchemy_api, '_get_regexp_ops',
return_value='LIKE') return_value=(lambda x: x, 'LIKE'))
class UnsupportedDbRegexpTestCase(DbTestCase): class UnsupportedDbRegexpTestCase(DbTestCase):
def test_instance_get_all_by_filters_paginate(self, mock_get_regexp): def test_instance_get_all_by_filters_paginate(self, mock_get_regexp):
@ -1090,21 +1090,25 @@ class SqlAlchemyDbApiNoDbTestCase(test.NoDBTestCase):
self.assertEqual(test1, expected_dict) self.assertEqual(test1, expected_dict)
def test_get_regexp_op_for_database_sqlite(self): def test_get_regexp_op_for_database_sqlite(self):
op = sqlalchemy_api._get_regexp_op_for_connection('sqlite:///') filter, op = sqlalchemy_api._get_regexp_ops('sqlite:///')
self.assertEqual('|', filter('|'))
self.assertEqual('REGEXP', op) self.assertEqual('REGEXP', op)
def test_get_regexp_op_for_database_mysql(self): def test_get_regexp_op_for_database_mysql(self):
op = sqlalchemy_api._get_regexp_op_for_connection( filter, op = sqlalchemy_api._get_regexp_ops(
'mysql+pymysql://root@localhost') 'mysql+pymysql://root@localhost')
self.assertEqual('\\|', filter('|'))
self.assertEqual('REGEXP', op) self.assertEqual('REGEXP', op)
def test_get_regexp_op_for_database_postgresql(self): def test_get_regexp_op_for_database_postgresql(self):
op = sqlalchemy_api._get_regexp_op_for_connection( filter, op = sqlalchemy_api._get_regexp_ops(
'postgresql://localhost') 'postgresql://localhost')
self.assertEqual('|', filter('|'))
self.assertEqual('~', op) self.assertEqual('~', op)
def test_get_regexp_op_for_database_unknown(self): def test_get_regexp_op_for_database_unknown(self):
op = sqlalchemy_api._get_regexp_op_for_connection('notdb:///') filter, op = sqlalchemy_api._get_regexp_ops('notdb:///')
self.assertEqual('|', filter('|'))
self.assertEqual('LIKE', op) self.assertEqual('LIKE', op)
@mock.patch.object(sqlalchemy_api.main_context_manager._factory, @mock.patch.object(sqlalchemy_api.main_context_manager._factory,
@ -1145,6 +1149,22 @@ class SqlAlchemyDbApiNoDbTestCase(test.NoDBTestCase):
mock_get.assert_called_once_with(mock.sentinel.elevated, 'foo') mock_get.assert_called_once_with(mock.sentinel.elevated, 'foo')
ctxt.elevated.assert_called_once_with(read_deleted='yes') ctxt.elevated.assert_called_once_with(read_deleted='yes')
def test_replace_sub_expression(self):
ret = sqlalchemy_api._safe_regex_mysql('|')
self.assertEqual('\\|', ret)
ret = sqlalchemy_api._safe_regex_mysql('||')
self.assertEqual('\\|\\|', ret)
ret = sqlalchemy_api._safe_regex_mysql('a||')
self.assertEqual('a\\|\\|', ret)
ret = sqlalchemy_api._safe_regex_mysql('|a|')
self.assertEqual('\\|a\\|', ret)
ret = sqlalchemy_api._safe_regex_mysql('||a')
self.assertEqual('\\|\\|a', ret)
class SqlAlchemyDbApiTestCase(DbTestCase): class SqlAlchemyDbApiTestCase(DbTestCase):
def test_instance_get_all_by_host(self): def test_instance_get_all_by_host(self):