diff --git a/quark/db/api.py b/quark/db/api.py index 48c5708..ebfa21b 100644 --- a/quark/db/api.py +++ b/quark/db/api.py @@ -122,8 +122,9 @@ def _model_query(context, model, filters, fields=None): for key, value in filters.items(): if key in in_filters: - model_type = getattr(model, key) - model_filters.append(model_type.in_(value)) + if value: + model_type = getattr(model, key) + model_filters.append(model_type.in_(value)) elif key in eq_filters: model_type = getattr(model, key) model_filters.append(model_type == value) @@ -136,10 +137,12 @@ def _model_query(context, model, filters, fields=None): etypes = [] for etype in value: etypes.append(protocols.translate_ethertype(etype)) - model_filters.append(model.ethertype.in_(etypes)) + if etypes: + model_filters.append(model.ethertype.in_(etypes)) elif key == "ip_address": - model_filters.append(model.address.in_( - [ip.ipv6().value for ip in value])) + if value: + model_filters.append(model.address.in_( + [ip.ipv6().value for ip in value])) elif key == 'protocol': pnums = [] for version in (protocols.PROTOCOLS_V4, protocols.PROTOCOLS_V6): @@ -155,11 +158,13 @@ def _model_query(context, model, filters, fields=None): model_filters.append(model.port_id == value) elif key == "tenant_id": if model == models.IPAddress: - model_filters.append(model.used_by_tenant_id.in_(value)) + if value: + model_filters.append(model.used_by_tenant_id.in_(value)) elif model in _NO_TENANT_MODELS: pass else: - model_filters.append(model.tenant_id.in_(value)) + if value: + model_filters.append(model.tenant_id.in_(value)) return model_filters @@ -780,15 +785,26 @@ def _subnet_find(context, limit, sorts, marker, page_reverse, fields, if INVERT_DEFAULTS in defaults: invert_defaults = True defaults.pop(0) + + # when 'invert_defaults' were the only entry in defaults, + # defaults will be empty now. The next 4 lines optimize + # performance by avoiding running the in_ filter on an empty set: + # like so: models.Subnet.id.in_([]) + if defaults: + subnet_filter = models.Subnet.id.in_(defaults) + else: + # if defaults is an empty list, just create a False + # BinaryExpression + subnet_filter = models.Subnet.id != models.Subnet.id + if filters and invert_defaults: - query = query.filter(and_(not_(models.Subnet.id.in_(defaults)), + query = query.filter(and_(not_(subnet_filter), and_(*model_filters))) elif not provider_query and filters and not invert_defaults: - query = query.filter(or_(models.Subnet.id.in_(defaults), - and_(*model_filters))) + query = query.filter(or_(subnet_filter, and_(*model_filters))) elif not invert_defaults: - query = query.filter(models.Subnet.id.in_(defaults)) + query = query.filter(subnet_filter) else: query = query.filter(*model_filters) diff --git a/quark/plugin_modules/subnets.py b/quark/plugin_modules/subnets.py index 46dd087..bd02d3f 100644 --- a/quark/plugin_modules/subnets.py +++ b/quark/plugin_modules/subnets.py @@ -69,10 +69,13 @@ def _validate_subnet_cidr(context, network_id, new_subnet_cidr): raise n_exc.BadRequest(resource="subnet", msg="Invalid or missing cidr") + filters = { + 'network_id': network_id, + 'shared': [False] + } # Using admin context here, in case we actually share networks later - subnet_list = db_api.subnet_find(context.elevated(), None, None, None, - False, network_id=network_id, - shared=[False]) + subnet_list = db_api.subnet_find(context=context.elevated(), **filters) + for subnet in subnet_list: if (netaddr.IPSet([subnet.cidr]) & new_subnet_ipset): # don't give out details of the overlapping subnet diff --git a/quark/tests/plugin_modules/test_subnets.py b/quark/tests/plugin_modules/test_subnets.py index a9bf291..fceda96 100644 --- a/quark/tests/plugin_modules/test_subnets.py +++ b/quark/tests/plugin_modules/test_subnets.py @@ -408,15 +408,19 @@ class TestQuarkCreateSubnet(test_quark_plugin.TestQuarkPlugin): mock.patch("quark.db.api.route_create"), mock.patch("quark.db.api.subnet_find"), mock.patch("neutron.common.rpc.get_notifier"), - _allocation_pools_mock() + _allocation_pools_mock(), + mock.patch("sqlalchemy.orm.session.SessionTransaction.commit"), + mock.patch( + "sqlalchemy.orm.unitofwork.UOWTransaction.register_object") ) as (subnet_create, net_find, dns_create, route_create, subnet_find, - get_notifier, alloc_pools_method): + get_notifier, alloc_pools_method, commit, register_object): subnet_create.return_value = subnet_mod net_find.return_value = network route_create.side_effect = route_models dns_create.side_effect = dns_models alloc_pools_method.__get__ = mock.Mock( return_value=allocation_pools) + register_object.return_value = True yield subnet_create, dns_create, route_create def test_create_subnet(self): @@ -860,8 +864,9 @@ class TestQuarkAllocationPoolCache(test_quark_plugin.TestQuarkPlugin): mock.patch("quark.db.api.route_find"), mock.patch("quark.db.api.route_update"), mock.patch("quark.db.api.route_create"), - ) as (subnet_find, subnet_update, - dns_create, route_find, route_update, route_create): + mock.patch("sqlalchemy.orm.session.SessionTransaction.commit") + ) as (subnet_find, subnet_update, dns_create, route_find, + route_update, route_create, commit): subnet_find.return_value = subnet_mod if has_subnet: route_find.return_value = (subnet_mod["routes"][0] if @@ -965,9 +970,12 @@ class TestQuarkUpdateSubnet(test_quark_plugin.TestQuarkPlugin): mock.patch("quark.db.api.route_find"), mock.patch("quark.db.api.route_update"), mock.patch("quark.db.api.route_create"), + mock.patch("sqlalchemy.orm.session.SessionTransaction.commit"), + mock.patch( + "sqlalchemy.orm.unitofwork.UOWTransaction.register_object") ) as (subnet_find, subnet_update, dns_create, - route_find, route_update, route_create): + route_find, route_update, route_create, commit, register_object): subnet_find.return_value = subnet_mod if has_subnet: route_find.return_value = (subnet_mod["routes"][0] if @@ -983,6 +991,7 @@ class TestQuarkUpdateSubnet(test_quark_plugin.TestQuarkPlugin): if new_ip_policy: new_subnet_mod["ip_policy"] = new_ip_policy subnet_update.return_value = new_subnet_mod + register_object.return_value = True yield dns_create, route_update, route_create def test_update_subnet_not_found(self): @@ -1471,9 +1480,10 @@ class TestQuarkCreateSubnetAttrFilters(test_quark_plugin.TestQuarkPlugin): mock.patch("quark.db.api.route_create"), mock.patch("quark.plugin_views._make_subnet_dict"), mock.patch("quark.db.api.subnet_find"), - mock.patch("neutron.common.rpc.get_notifier") + mock.patch("neutron.common.rpc.get_notifier"), + mock.patch("sqlalchemy.orm.session.SessionTransaction.commit") ) as (subnet_create, net_find, dns_create, route_create, sub_dict, - subnet_find, get_notifier): + subnet_find, get_notifier, commit): route_create.return_value = models.Route() yield subnet_create, net_find