diff --git a/openstack/common/scheduler/base_filter.py b/openstack/common/scheduler/base_filter.py index f794e2d6c..249b7f5f5 100644 --- a/openstack/common/scheduler/base_filter.py +++ b/openstack/common/scheduler/base_filter.py @@ -43,6 +43,17 @@ class BaseFilter(object): if self._filter_one(obj, filter_properties): yield obj + # Set to true in a subclass if a filter only needs to be run once + # for each request rather than for each instance + run_filter_once_per_request = False + + def run_filter_for_index(self, index): + """Return True if the filter needs to be run for the "index-th" + instance in a request. Only need to override this if a filter + needs anything other than "first only" or "all" behaviour. + """ + return not (self.run_filter_once_per_request and index > 0) + class BaseFilterHandler(base_handler.BaseHandler): """Base class to handle loading filter classes. @@ -51,23 +62,34 @@ class BaseFilterHandler(base_handler.BaseHandler): """ def get_filtered_objects(self, filter_classes, objs, - filter_properties): + filter_properties, index=0): + """Get objects after filter + + :param filter_classes: filters that will be used to filter the + objects + :param objs: objects that will be filtered + :param filter_properties: client filter properties + :param index: This value needs to be increased in the caller + function of get_filtered_objects when handling + each resource. + """ list_objs = list(objs) LOG.debug("Starting with %d host(s)", len(list_objs)) for filter_cls in filter_classes: cls_name = filter_cls.__name__ filter_class = filter_cls() - objs = filter_class.filter_all(list_objs, filter_properties) - if objs is None: - LOG.debug("Filter %(cls_name)s says to stop filtering", - {'cls_name': cls_name}) - return - list_objs = list(objs) - msg = (_("Filter %(cls_name)s returned %(obj_len)d host(s)") - % {'cls_name': cls_name, 'obj_len': len(list_objs)}) - if not list_objs: - LOG.info(msg) - break - LOG.debug(msg) + if filter_class.run_filter_for_index(index): + objs = filter_class.filter_all(list_objs, filter_properties) + if objs is None: + LOG.debug("Filter %(cls_name)s says to stop filtering", + {'cls_name': cls_name}) + return + list_objs = list(objs) + msg = (_("Filter %(cls_name)s returned %(obj_len)d host(s)") + % {'cls_name': cls_name, 'obj_len': len(list_objs)}) + if not list_objs: + LOG.info(msg) + break + LOG.debug(msg) return list_objs diff --git a/openstack/common/scheduler/filters/availability_zone_filter.py b/openstack/common/scheduler/filters/availability_zone_filter.py index a389ded88..63b9051a8 100644 --- a/openstack/common/scheduler/filters/availability_zone_filter.py +++ b/openstack/common/scheduler/filters/availability_zone_filter.py @@ -19,6 +19,9 @@ from openstack.common.scheduler import filters class AvailabilityZoneFilter(filters.BaseHostFilter): """Filters Hosts by availability zone.""" + # Availability zones do not change within a request + run_filter_once_per_request = True + def host_passes(self, host_state, filter_properties): spec = filter_properties.get('request_spec', {}) props = spec.get('resource_properties', {}) diff --git a/tests/unit/scheduler/test_base_filter.py b/tests/unit/scheduler/test_base_filter.py index 885e557b6..04bdcefd4 100644 --- a/tests/unit/scheduler/test_base_filter.py +++ b/tests/unit/scheduler/test_base_filter.py @@ -93,6 +93,7 @@ class FakeFilter5(BaseFakeFilter): Should not be included. """ + run_filter_once_per_request = True pass @@ -126,26 +127,44 @@ class TestBaseFilterHandler(test.BaseTestCase): result = self.handler.get_all_classes() self.assertEqual(expected, result) - def _get_filtered_objects(self): + def _get_filtered_objects(self, filter_classes, index=0): filter_objs_initial = [1, 2, 3, 4] filter_properties = {'x': 'y'} - filter_classes = [FakeFilter1, FakeFilter2, FakeFilter3, FakeFilter4] return self.handler.get_filtered_objects(filter_classes, filter_objs_initial, - filter_properties) + filter_properties, + index) def test_get_filtered_objects_return_none(self): + filter_classes = [FakeFilter1, FakeFilter2, FakeFilter3, FakeFilter4] + def fake_filter_all(self, list_objs, filter_properties): return with contextlib.nested( mock.patch.object(FakeFilter3, 'filter_all', fake_filter_all), mock.patch.object(FakeFilter4, 'filter_all') ) as (fake3_filter_all, fake4_filter_all): - result = self._get_filtered_objects() + result = self._get_filtered_objects(filter_classes) self.assertIsNone(result) self.assertFalse(fake4_filter_all.called) def test_get_filtered_objects(self): filter_objs_expected = [1, 2, 3, 4] - result = self._get_filtered_objects() + filter_classes = [FakeFilter1, FakeFilter2, FakeFilter3, FakeFilter4] + result = self._get_filtered_objects(filter_classes) self.assertEqual(filter_objs_expected, result) + + def test_get_filtered_objects_with_filter_run_once(self): + filter_objs_expected = [1, 2, 3, 4] + filter_classes = [FakeFilter5] + + with mock.patch.object(FakeFilter5, 'filter_all', + return_value=filter_objs_expected + ) as fake5_filter_all: + result = self._get_filtered_objects(filter_classes, index=1) + self.assertEqual(filter_objs_expected, result) + self.assertFalse(fake5_filter_all.called) + + result = self._get_filtered_objects(filter_classes) + self.assertEqual(filter_objs_expected, result) + self.assertTrue(fake5_filter_all.called)