diff --git a/neutron/db/common_db_mixin.py b/neutron/db/common_db_mixin.py index d7eedd53d4b..2a65420e5e4 100644 --- a/neutron/db/common_db_mixin.py +++ b/neutron/db/common_db_mixin.py @@ -129,7 +129,8 @@ class CommonDbMixin(object): query_filter = None if self.model_query_scope(context, model): if hasattr(model, 'rbac_entries'): - rbac_model, join_params = self._get_rbac_query_params(model) + rbac_model, join_params = self._get_rbac_query_params( + model)[:2] query = query.outerjoin(*join_params) query_filter = ( (model.tenant_id == context.tenant_id) | @@ -185,16 +186,24 @@ class CommonDbMixin(object): @staticmethod def _get_rbac_query_params(model): - """Return the class and join params for the rbac relationship.""" + """Return the parameters required to query an model's RBAC entries. + + Returns a tuple of 3 containing: + 1. the relevant RBAC model for a given model + 2. the join parameters required to query the RBAC entries for the model + 3. the ID column of the passed in model that matches the object_id + in the rbac entries. + """ try: cls = model.rbac_entries.property.mapper.class_ - return (cls, (cls, )) + return (cls, (cls, ), model.id) except AttributeError: # an association proxy is being used (e.g. subnets # depends on network's rbac entries) rbac_model = (model.rbac_entries.target_class. rbac_entries.property.mapper.class_) - return (rbac_model, model.rbac_entries.attr) + return (rbac_model, model.rbac_entries.attr, + model.rbac_entries.remote_attr.class_.id) def _apply_filters_to_query(self, query, model, filters, context=None): if filters: @@ -213,17 +222,29 @@ class CommonDbMixin(object): elif key == 'shared' and hasattr(model, 'rbac_entries'): # translate a filter on shared into a query against the # object's rbac entries - rbac, join_params = self._get_rbac_query_params(model) + rbac, join_params, oid_col = self._get_rbac_query_params( + model) query = query.outerjoin(*join_params, aliased=True) matches = [rbac.target_tenant == '*'] if context: matches.append(rbac.target_tenant == context.tenant_id) - is_shared = and_( - ~rbac.object_id.is_(None), - rbac.action == 'access_as_shared', - or_(*matches) - ) - query = query.filter(is_shared if value[0] else ~is_shared) + # any 'access_as_shared' records that match the + # wildcard or requesting tenant + is_shared = and_(rbac.action == 'access_as_shared', + or_(*matches)) + if not value[0]: + # NOTE(kevinbenton): we need to find objects that don't + # have an entry that matches the criteria above so + # we use a subquery to exclude them. + # We can't just filter the inverse of the query above + # because that will still give us a network shared to + # our tenant (or wildcard) if it's shared to another + # tenant. + is_shared = ~oid_col.in_( + query.session.query(rbac.object_id). + filter(is_shared) + ) + query = query.filter(is_shared) for _nam, hooks in six.iteritems(self._model_query_hooks.get(model, {})): result_filter = hooks.get('result_filters', None) diff --git a/neutron/tests/api/admin/test_shared_network_extension.py b/neutron/tests/api/admin/test_shared_network_extension.py index b60ebf9a949..87bf8ebb91c 100644 --- a/neutron/tests/api/admin/test_shared_network_extension.py +++ b/neutron/tests/api/admin/test_shared_network_extension.py @@ -341,3 +341,22 @@ class RBACSharedNetworksTest(base.BaseAdminNetworkTest): with testtools.ExpectedException(lib_exc.Forbidden): self.client.update_rbac_policy(pol['rbac_policy']['id'], target_tenant='*') + + @test.attr(type='smoke') + @test.idempotent_id('86c3529b-1231-40de-803c-aeeeeeee7fff') + def test_filtering_works_with_rbac_records_present(self): + resp = self._make_admin_net_and_subnet_shared_to_tenant_id( + self.client.tenant_id) + net = resp['network'] + sub = resp['subnet'] + self.admin_client.create_rbac_policy( + object_type='network', object_id=net['id'], + action='access_as_shared', target_tenant='*') + for state, assertion in ((False, self.assertNotIn), + (True, self.assertIn)): + nets = [n['id'] for n in + self.admin_client.list_networks(shared=state)['networks']] + assertion(net['id'], nets) + subs = [s['id'] for s in + self.admin_client.list_subnets(shared=state)['subnets']] + assertion(sub['id'], subs)