Merge "Fix service-list filter"

This commit is contained in:
Jenkins 2016-02-24 15:02:44 +00:00 committed by Gerrit Code Review
commit 71388796d3
8 changed files with 50 additions and 31 deletions

@ -99,7 +99,8 @@ def _list_hosts(req, service=None):
"""Returns a summary list of hosts.""" """Returns a summary list of hosts."""
curr_time = timeutils.utcnow(with_timezone=True) curr_time = timeutils.utcnow(with_timezone=True)
context = req.environ['cinder.context'] context = req.environ['cinder.context']
services = objects.ServiceList.get_all(context, False) filters = {'disabled': False}
services = objects.ServiceList.get_all(context, filters)
zone = '' zone = ''
if 'zone' in req.GET: if 'zone' in req.GET:
zone = req.GET['zone'] zone = req.GET['zone']

@ -82,27 +82,20 @@ class ServiceController(wsgi.Controller):
authorize(context, action='index') authorize(context, action='index')
detailed = self.ext_mgr.is_loaded('os-extended-services') detailed = self.ext_mgr.is_loaded('os-extended-services')
now = timeutils.utcnow(with_timezone=True) now = timeutils.utcnow(with_timezone=True)
services = objects.ServiceList.get_all(context)
host = '' filters = {}
if 'host' in req.GET: if 'host' in req.GET:
host = req.GET['host'] filters['host'] = req.GET['host']
service = '' if 'binary' in req.GET:
if 'service' in req.GET: filters['binary'] = req.GET['binary']
service = req.GET['service'] elif 'service' in req.GET:
filters['binary'] = req.GET['service']
versionutils.report_deprecated_feature(LOG, _( versionutils.report_deprecated_feature(LOG, _(
"Query by service parameter is deprecated. " "Query by service parameter is deprecated. "
"Please use binary parameter instead.")) "Please use binary parameter instead."))
binary = ''
if 'binary' in req.GET:
binary = req.GET['binary']
if host: services = objects.ServiceList.get_all(context, filters)
services = [s for s in services if s.host == host]
# NOTE(uni): deprecating service request key, binary takes precedence
binary_key = binary or service
if binary_key:
services = [s for s in services if s.binary == binary_key]
svcs = [] svcs = []
for svc in services: for svc in services:

@ -106,9 +106,9 @@ def service_get_by_host_and_topic(context, host, topic):
return IMPL.service_get_by_host_and_topic(context, host, topic) return IMPL.service_get_by_host_and_topic(context, host, topic)
def service_get_all(context, disabled=None): def service_get_all(context, filters=None):
"""Get all services.""" """Get all services."""
return IMPL.service_get_all(context, disabled) return IMPL.service_get_all(context, filters)
def service_get_all_by_topic(context, topic, disabled=None): def service_get_all_by_topic(context, topic, disabled=None):

@ -373,11 +373,22 @@ def service_get(context, service_id):
@require_admin_context @require_admin_context
def service_get_all(context, disabled=None): def service_get_all(context, filters=None):
if filters and not is_valid_model_filters(models.Service, filters):
return []
query = model_query(context, models.Service) query = model_query(context, models.Service)
if disabled is not None: try:
query = query.filter_by(disabled=disabled) host = filters.pop('host')
host_attr = models.Service.host
conditions = or_(host_attr == host, host_attr.op('LIKE')(host + '@%'))
query = query.filter(conditions)
except KeyError:
pass
if filters:
query = query.filter_by(**filters)
return query.all() return query.all()

@ -70,7 +70,7 @@ def stub_utcnow(with_timezone=False):
return datetime.datetime(2013, 7, 3, 0, 0, 2, tzinfo=tzinfo) return datetime.datetime(2013, 7, 3, 0, 0, 2, tzinfo=tzinfo)
def stub_service_get_all(self, req): def stub_service_get_all(context, filters=None):
return SERVICE_LIST return SERVICE_LIST

@ -132,7 +132,13 @@ class FakeRequestWithHostBinary(object):
def fake_service_get_all(context, filters=None): def fake_service_get_all(context, filters=None):
return fake_services_list filters = filters or {}
host = filters.get('host')
binary = filters.get('binary')
return [s for s in fake_services_list
if (not host or s['host'] == host or
s['host'].startswith(host + '@'))
and (not binary or s['binary'] == binary)]
def fake_service_get_by_host_binary(context, host, binary): def fake_service_get_by_host_binary(context, host, binary):

@ -142,8 +142,9 @@ class TestServiceList(test_objects.BaseObjectsTestCase):
db_service = fake_service.fake_db_service() db_service = fake_service.fake_db_service()
service_get_all.return_value = [db_service] service_get_all.return_value = [db_service]
services = objects.ServiceList.get_all(self.context, 'foo') filters = {'host': 'host', 'binary': 'foo', 'disabled': False}
service_get_all.assert_called_once_with(self.context, 'foo') services = objects.ServiceList.get_all(self.context, filters)
service_get_all.assert_called_once_with(self.context, filters)
self.assertEqual(1, len(services)) self.assertEqual(1, len(services))
TestService._compare(self, db_service, services[0]) TestService._compare(self, db_service, services[0])

@ -178,18 +178,25 @@ class DBAPIServiceTestCase(BaseTest):
def test_service_get_all(self): def test_service_get_all(self):
values = [ values = [
{'host': 'host1', 'topic': 'topic1'}, {'host': 'host1', 'binary': 'b1'},
{'host': 'host2', 'topic': 'topic2'}, {'host': 'host1@ceph', 'binary': 'b2'},
{'host': 'host2', 'binary': 'b2'},
{'disabled': True} {'disabled': True}
] ]
services = [self._create_service(vals) for vals in values] services = [self._create_service(vals) for vals in values]
disabled_services = [services[-1]] disabled_services = [services[-1]]
non_disabled_services = services[:-1] non_disabled_services = services[:-1]
expected = services[:2]
expected_bin = services[1:3]
compares = [ compares = [
(services, db.service_get_all(self.ctxt)), (services, db.service_get_all(self.ctxt, {})),
(disabled_services, db.service_get_all(self.ctxt, True)), (expected, db.service_get_all(self.ctxt, {'host': 'host1'})),
(non_disabled_services, db.service_get_all(self.ctxt, False)) (expected_bin, db.service_get_all(self.ctxt, {'binary': 'b2'})),
(disabled_services, db.service_get_all(self.ctxt,
{'disabled': True})),
(non_disabled_services, db.service_get_all(self.ctxt,
{'disabled': False})),
] ]
for comp in compares: for comp in compares:
self._assertEqualListsOfObjects(*comp) self._assertEqualListsOfObjects(*comp)