diff --git a/neutron/db/db_base_plugin_common.py b/neutron/db/db_base_plugin_common.py index 9e00f103a70..a2ebcd25181 100644 --- a/neutron/db/db_base_plugin_common.py +++ b/neutron/db/db_base_plugin_common.py @@ -28,6 +28,7 @@ from neutron.common import exceptions from neutron.common import utils from neutron.db import common_db_mixin from neutron.db import models_v2 +from neutron.objects import subnet as subnet_obj from neutron.objects import subnetpool as subnetpool_obj LOG = logging.getLogger(__name__) @@ -233,9 +234,8 @@ class DbBasePluginCommon(common_db_mixin.CommonDbMixin): return port def _get_dns_by_subnet(self, context, subnet_id): - dns_qry = context.session.query(models_v2.DNSNameServer) - return dns_qry.filter_by(subnet_id=subnet_id).order_by( - models_v2.DNSNameServer.order).all() + return subnet_obj.DNSNameServer.get_objects(context, + subnet_id=subnet_id) def _get_route_by_subnet(self, context, subnet_id): route_qry = context.session.query(models_v2.SubnetRoute) diff --git a/neutron/db/ipam_backend_mixin.py b/neutron/db/ipam_backend_mixin.py index fc6c1a7c513..1aacb643ce9 100644 --- a/neutron/db/ipam_backend_mixin.py +++ b/neutron/db/ipam_backend_mixin.py @@ -38,6 +38,7 @@ from neutron.db import segments_db from neutron.extensions import portbindings from neutron.extensions import segment from neutron.ipam import utils as ipam_utils +from neutron.objects import subnet as subnet_obj from neutron.services.segments import db as segment_svc_db from neutron.services.segments import exceptions as segment_exc @@ -148,14 +149,14 @@ class IpamBackendMixin(db_base_plugin_common.DbBasePluginCommon): # when update subnet's DNS nameservers. And store new # nameservers with order one by one. for dns in old_dns_list: - context.session.delete(dns) + dns.delete() for order, server in enumerate(new_dns_addr_list): - dns = models_v2.DNSNameServer( - address=server, - order=order, - subnet_id=id) - context.session.add(dns) + dns = subnet_obj.DNSNameServer(context, + address=server, + order=order, + subnet_id=id) + dns.create() del s["dns_nameservers"] return new_dns_addr_list @@ -475,11 +476,11 @@ class IpamBackendMixin(db_base_plugin_common.DbBasePluginCommon): # by one when create subnet with DNS nameservers if validators.is_attr_set(dns_nameservers): for order, server in enumerate(dns_nameservers): - dns = models_v2.DNSNameServer( - address=server, - order=order, - subnet_id=subnet.id) - context.session.add(dns) + dns = subnet_obj.DNSNameServer(context, + address=server, + order=order, + subnet_id=subnet.id) + dns.create() if validators.is_attr_set(host_routes): for rt in host_routes: diff --git a/neutron/objects/base.py b/neutron/objects/base.py index 284a219b6ef..6a6d327760d 100644 --- a/neutron/objects/base.py +++ b/neutron/objects/base.py @@ -111,6 +111,12 @@ class Pager(object): context, model, id=self.marker) return res + def __str__(self): + return str(self.__dict__) + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + @six.add_metaclass(abc.ABCMeta) class NeutronObject(obj_base.VersionedObject, diff --git a/neutron/objects/subnet.py b/neutron/objects/subnet.py index 1029235be19..8d119aa2c3b 100644 --- a/neutron/objects/subnet.py +++ b/neutron/objects/subnet.py @@ -37,6 +37,18 @@ class DNSNameServer(base.NeutronDbObject): 'order': obj_fields.IntegerField() } + @classmethod + def get_objects(cls, context, _pager=None, **kwargs): + """Fetch DNSNameServer objects with default sort by 'order' field. + """ + if not _pager: + _pager = base.Pager() + if not _pager.sorts: + # (NOTE) True means ASC, False is DESC + _pager.sorts = [('order', True)] + return super(DNSNameServer, cls).get_objects(context, _pager, + **kwargs) + @obj_base.VersionedObjectRegistry.register class Route(base.NeutronDbObject): diff --git a/neutron/tests/unit/objects/test_base.py b/neutron/tests/unit/objects/test_base.py index d1acc14a38a..073ab07a288 100644 --- a/neutron/tests/unit/objects/test_base.py +++ b/neutron/tests/unit/objects/test_base.py @@ -341,6 +341,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): super(BaseObjectIfaceTestCase, self).setUp() self.model_map = collections.defaultdict(list) self.model_map[self._test_class.db_model] = self.db_objs + self.pager_map = collections.defaultdict(lambda: None) def test_get_object(self): with mock.patch.object(obj_db_api, 'get_object', @@ -378,7 +379,8 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): field].objname)[0] mock_calls.append( mock.call( - self.context, obj_class.db_model, _pager=None, + self.context, obj_class.db_model, + _pager=self.pager_map[obj_class.obj_name()], **{k: db_obj[v] for k, v in obj_class.foreign_keys.items()})) return mock_calls @@ -390,7 +392,8 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): objs = self._test_class.get_objects(self.context) self._validate_objects(self.db_objs, objs) mock_calls = [ - mock.call(self.context, self._test_class.db_model, _pager=None) + mock.call(self.context, self._test_class.db_model, + _pager=self.pager_map[self._test_class.obj_name()]) ] mock_calls.extend(self._get_synthetic_fields_get_objects_calls( self.db_objs)) @@ -404,10 +407,10 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): objs = self._test_class.get_objects(self.context, **self.valid_field_filter) self._validate_objects(self.db_objs, objs) - mock_calls = [ mock.call( - self.context, self._test_class.db_model, _pager=None, + self.context, self._test_class.db_model, + _pager=self.pager_map[self._test_class.obj_name()], **self._test_class.modify_fields_to_db(self.valid_field_filter) ) ] @@ -887,3 +890,13 @@ class RegisterFilterHookOnModelTestCase(UniqueObjectBase): base.register_filter_hook_on_model( FakeNeutronObject.db_model, filter_name) self.assertIn(filter_name, self.registered_object.extra_filter_names) + + +class PagerTestCase(test_base.BaseTestCase): + def test_comparison(self): + pager = base.Pager(sorts=[('order', True)]) + pager2 = base.Pager(sorts=[('order', True)]) + self.assertEqual(pager, pager2) + + pager3 = base.Pager() + self.assertNotEqual(pager, pager3) diff --git a/neutron/tests/unit/objects/test_subnet.py b/neutron/tests/unit/objects/test_subnet.py index b5690e26540..87397ba9129 100644 --- a/neutron/tests/unit/objects/test_subnet.py +++ b/neutron/tests/unit/objects/test_subnet.py @@ -11,7 +11,9 @@ # under the License. import itertools +from operator import itemgetter +from neutron.objects import base as obj_base from neutron.objects import subnet from neutron.tests.unit.objects import test_base as obj_test_base from neutron.tests.unit import testlib_api @@ -40,6 +42,11 @@ class DNSNameServerObjectIfaceTestCase(obj_test_base.BaseObjectIfaceTestCase): _test_class = subnet.DNSNameServer + def setUp(self): + super(DNSNameServerObjectIfaceTestCase, self).setUp() + self.pager_map[self._test_class.obj_name()] = ( + obj_base.Pager(sorts=[('order', True)])) + class DNSNameServerDbObjectTestCase(obj_test_base.BaseDbObjectTestCase, testlib_api.SqlTestCase): @@ -48,11 +55,59 @@ class DNSNameServerDbObjectTestCase(obj_test_base.BaseDbObjectTestCase, def setUp(self): super(DNSNameServerDbObjectTestCase, self).setUp() + # (NOTE) If two object have the same value for a field and + # they are sorted using that field, the order is not deterministic. + # To avoid breaking the tests we ensure unique values for every field + while not self._is_objects_unique(): + self.db_objs = list(self.get_random_fields() for _ in range(3)) + self.obj_fields = [self._test_class.modify_fields_from_db(db_obj) + for db_obj in self.db_objs] self._create_test_network() self._create_test_subnet(self._network) for obj in itertools.chain(self.db_objs, self.obj_fields): obj['subnet_id'] = self._subnet['id'] + def _is_objects_unique(self): + order_set = set([x['order'] for x in self.db_objs]) + subnet_id_set = set([x['subnet_id'] for x in self.db_objs]) + address_set = set([x['address'] for x in self.db_objs]) + return 3 == len(order_set) == len(subnet_id_set) == len(address_set) + + def _create_dnsnameservers(self): + for obj in self.obj_fields: + dns = self._make_object(obj) + dns.create() + + def test_get_objects_sort_by_order_asc(self): + self._create_dnsnameservers() + objs = self._test_class.get_objects(self.context) + fields_sorted = sorted([dict(obj) for obj in self.obj_fields], + key=itemgetter('order')) + self.assertEqual( + fields_sorted, + [obj_test_base.get_obj_db_fields(obj) for obj in objs]) + + def test_get_objects_sort_by_order_desc(self): + self._create_dnsnameservers() + pager = obj_base.Pager(sorts=[('order', False)]) + objs = self._test_class.get_objects(self.context, _pager=pager, + subnet_id=self._subnet.id) + fields_sorted = sorted([dict(obj) for obj in self.obj_fields], + key=itemgetter('order'), reverse=True) + self.assertEqual( + fields_sorted, + [obj_test_base.get_obj_db_fields(obj) for obj in objs]) + + def test_get_objects_sort_by_address_asc_using_pager(self): + self._create_dnsnameservers() + pager = obj_base.Pager(sorts=[('address', True)]) + objs = self._test_class.get_objects(self.context, _pager=pager) + fields_sorted = sorted([dict(obj) for obj in self.obj_fields], + key=itemgetter('address')) + self.assertEqual( + fields_sorted, + [obj_test_base.get_obj_db_fields(obj) for obj in objs]) + class RouteObjectIfaceTestCase(obj_test_base.BaseObjectIfaceTestCase): @@ -76,6 +131,11 @@ class SubnetObjectIfaceTestCase(obj_test_base.BaseObjectIfaceTestCase): _test_class = subnet.Subnet + def setUp(self): + super(SubnetObjectIfaceTestCase, self).setUp() + self.pager_map[subnet.DNSNameServer.obj_name()] = ( + obj_base.Pager(sorts=[('order', True)])) + class SubnetDbObjectTestCase(obj_test_base.BaseDbObjectTestCase, testlib_api.SqlTestCase): @@ -87,3 +147,18 @@ class SubnetDbObjectTestCase(obj_test_base.BaseDbObjectTestCase, self._create_test_network() for obj in itertools.chain(self.db_objs, self.obj_fields): obj['network_id'] = self._network['id'] + + def test_get_dns_nameservers_in_order(self): + obj = self._make_object(self.obj_fields[0]) + obj.create() + dns_nameservers = [(2, '1.2.3.4'), (1, '5.6.7.8'), (4, '7.7.7.7')] + for order, address in dns_nameservers: + dns = subnet.DNSNameServer(self.context, order=order, + address=address, + subnet_id=obj.id) + dns.create() + + new = self._test_class.get_object(self.context, id=obj.id) + self.assertEqual(1, new.dns_nameservers[0].order) + self.assertEqual(2, new.dns_nameservers[1].order) + self.assertEqual(4, new.dns_nameservers[-1].order)