From ab86e8deaf81d543071244a37c72420e460ccf58 Mon Sep 17 00:00:00 2001
From: Gary Kotton <gkotton@vmware.com>
Date: Mon, 15 Jan 2018 05:48:45 -0800
Subject: [PATCH] TVD: ensure that can return specific tenant/project requests

A service tenant may do a request for a speicific tenants data,
for example, ports (as in the case with a nova boot). So we need
to ensure that the filters requested by the tenant are met.

Change-Id: Ic7ff59a813347f943e6c84478d9f036c90473c9e
---
 vmware_nsx/plugins/nsx/plugin.py        | 63 ++++++++++++++++---------
 vmware_nsx/services/lbaas/nsx/plugin.py | 15 +++++-
 2 files changed, 54 insertions(+), 24 deletions(-)

diff --git a/vmware_nsx/plugins/nsx/plugin.py b/vmware_nsx/plugins/nsx/plugin.py
index 035b1a5cb3..401f1dec7c 100644
--- a/vmware_nsx/plugins/nsx/plugin.py
+++ b/vmware_nsx/plugins/nsx/plugin.py
@@ -252,8 +252,10 @@ class NsxTVDPlugin(agentschedulers_db.AZDhcpAgentSchedulerDbMixin,
                 del data[field]
 
     def _list_availability_zones(self, context, filters=None):
-        p = self._get_plugin_from_project(context, context.project_id)
-        return p._list_availability_zones(context, filters=filters)
+        p = self._get_plugin_for_request(context, filters)
+        if p:
+            return p._list_availability_zones(context, filters=filters)
+        return []
 
     def validate_availability_zones(self, context, resource_type,
                                     availability_zones):
@@ -311,12 +313,29 @@ class NsxTVDPlugin(agentschedulers_db.AZDhcpAgentSchedulerDbMixin,
         p = self._get_plugin_from_net_id(context, id)
         return p.get_network(context, id, fields=fields)
 
+    def _get_plugin_for_request(self, context, filters):
+        project_id = context.project_id
+        if filters:
+            if filters.get('tenant_id'):
+                project_id = filters.get('tenant_id')
+            elif filters.get('project_id'):
+                project_id = filters.get('project_id')
+            else:
+                # A specific filter request is made. So here we
+                # will not filter according to the plugin.
+                return
+            # If there are multiple tenants/prijects being requested then
+            # we will not filter according to the plugin
+            if isinstance(project_id, list):
+                return
+        return self._get_plugin_from_project(context, project_id)
+
     def get_networks(self, context, filters=None, fields=None,
                      sorts=None, limit=None, marker=None,
                      page_reverse=False):
         # Read project plugin to filter relevant projects according to
         # plugin
-        req_p = self._get_plugin_from_project(context, context.project_id)
+        req_p = self._get_plugin_for_request(context, filters)
         filters = filters or {}
         with db_api.context_manager.reader.using(context):
             networks = (
@@ -325,7 +344,7 @@ class NsxTVDPlugin(agentschedulers_db.AZDhcpAgentSchedulerDbMixin,
                     limit, marker, page_reverse))
             for net in networks[:]:
                 p = self._get_plugin_from_project(context, net['tenant_id'])
-                if p == req_p:
+                if p == req_p or req_p is None:
                     p._extend_get_network_dict_provider(context, net)
                 else:
                     networks.remove(net)
@@ -372,7 +391,7 @@ class NsxTVDPlugin(agentschedulers_db.AZDhcpAgentSchedulerDbMixin,
                   page_reverse=False):
         # Read project plugin to filter relevant projects according to
         # plugin
-        req_p = self._get_plugin_from_project(context, context.project_id)
+        req_p = self._get_plugin_for_request(context, filters)
         filters = filters or {}
         with db_api.context_manager.reader.using(context):
             ports = (
@@ -385,7 +404,7 @@ class NsxTVDPlugin(agentschedulers_db.AZDhcpAgentSchedulerDbMixin,
                     port_model = self._get_port(context, port['id'])
                     resource_extend.apply_funcs('ports', port, port_model)
                 p = self._get_plugin_from_net_id(context, port['network_id'])
-                if p == req_p:
+                if p == req_p or req_p is None:
                     if hasattr(p, '_extend_get_port_dict_qos_and_binding'):
                         p._extend_get_port_dict_qos_and_binding(context, port)
                     if hasattr(p,
@@ -421,14 +440,14 @@ class NsxTVDPlugin(agentschedulers_db.AZDhcpAgentSchedulerDbMixin,
         else:
             # Read project plugin to filter relevant projects according to
             # plugin
-            req_p = self._get_plugin_from_project(context, context.project_id)
+            req_p = self._get_plugin_for_request(context, filters)
             filters = filters or {}
             subnets = super(NsxTVDPlugin, self).get_subnets(
                 context, filters=filters, fields=fields, sorts=sorts,
                 limit=limit, marker=marker, page_reverse=page_reverse)
             for subnet in subnets[:]:
                 p = self._get_plugin_from_project(context, subnet['tenant_id'])
-                if p != req_p:
+                if req_p and p != req_p:
                     subnets.remove(subnet)
             return subnets
 
@@ -545,13 +564,13 @@ class NsxTVDPlugin(agentschedulers_db.AZDhcpAgentSchedulerDbMixin,
                     page_reverse=False):
         # Read project plugin to filter relevant projects according to
         # plugin
-        req_p = self._get_plugin_from_project(context, context.project_id)
+        req_p = self._get_plugin_for_request(context, filters)
         routers = super(NsxTVDPlugin, self).get_routers(
             context, filters=filters, fields=fields, sorts=sorts,
             limit=limit, marker=marker, page_reverse=page_reverse)
         for router in routers[:]:
             p = self._get_plugin_from_project(context, router['tenant_id'])
-            if p != req_p:
+            if req_p and p != req_p:
                 routers.remove(router)
         return routers
 
@@ -585,14 +604,14 @@ class NsxTVDPlugin(agentschedulers_db.AZDhcpAgentSchedulerDbMixin,
                         page_reverse=False):
         # Read project plugin to filter relevant projects according to
         # plugin
-        req_p = self._get_plugin_from_project(context, context.project_id)
+        req_p = self._get_plugin_for_request(context, filters)
         fips = super(NsxTVDPlugin, self).get_floatingips(
             context, filters=filters, fields=fields, sorts=sorts,
             limit=limit, marker=marker, page_reverse=page_reverse)
         for fip in fips[:]:
             p = self._get_plugin_from_project(context,
                                               fip['tenant_id'])
-            if p != req_p:
+            if req_p and p != req_p:
                 fips.remove(fip)
         return fips
 
@@ -633,14 +652,14 @@ class NsxTVDPlugin(agentschedulers_db.AZDhcpAgentSchedulerDbMixin,
                             marker=None, page_reverse=False, default_sg=False):
         # Read project plugin to filter relevant projects according to
         # plugin
-        req_p = self._get_plugin_from_project(context, context.project_id)
+        req_p = self._get_plugin_for_request(context, filters)
         sgs = super(NsxTVDPlugin, self).get_security_groups(
             context, filters=filters, fields=fields, sorts=sorts,
             limit=limit, marker=marker, page_reverse=page_reverse,
             default_sg=default_sg)
         for sg in sgs[:]:
             p = self._get_plugin_from_project(context, sg['tenant_id'])
-            if p != req_p:
+            if req_p and p != req_p:
                 sgs.remove(sg)
         return sgs
 
@@ -664,13 +683,13 @@ class NsxTVDPlugin(agentschedulers_db.AZDhcpAgentSchedulerDbMixin,
                                  page_reverse=False):
         # Read project plugin to filter relevant projects according to
         # plugin
-        req_p = self._get_plugin_from_project(context, context.project_id)
+        req_p = self._get_plugin_for_request(context, filters)
         rules = super(NsxTVDPlugin, self).get_security_group_rules(
             context, filters=filters, fields=fields, sorts=sorts,
             limit=limit, marker=marker, page_reverse=page_reverse)
         for rule in rules[:]:
             p = self._get_plugin_from_project(context, rule['tenant_id'])
-            if p != req_p:
+            if req_p and p != req_p:
                 rules.remove(rule)
         return rules
 
@@ -810,8 +829,8 @@ class NsxTVDPlugin(agentschedulers_db.AZDhcpAgentSchedulerDbMixin,
 
     def get_housekeepers(self, context, filters=None, fields=None, sorts=None,
                          limit=None, marker=None, page_reverse=False):
-        p = self._get_plugin_from_project(context, context.project_id)
-        if hasattr(p, 'housekeeper'):
+        p = self._get_plugin_for_request(context, filters)
+        if p and hasattr(p, 'housekeeper'):
             return p.housekeeper.list()
         return []
 
@@ -826,14 +845,14 @@ class NsxTVDPlugin(agentschedulers_db.AZDhcpAgentSchedulerDbMixin,
                            page_reverse=False):
         # Read project plugin to filter relevant projects according to
         # plugin
-        req_p = self._get_plugin_from_project(context, context.project_id)
+        req_p = self._get_plugin_for_request(context, filters)
         address_scopes = super(NsxTVDPlugin, self).get_address_scopes(
             context, filters=filters, fields=fields, sorts=sorts,
             limit=limit, marker=marker, page_reverse=page_reverse)
         for address_scope in address_scopes[:]:
             p = self._get_plugin_from_project(context,
                                               address_scope['tenant_id'])
-            if p != req_p:
+            if req_p and p != req_p:
                 address_scopes.remove(address_scope)
         return address_scopes
 
@@ -842,13 +861,13 @@ class NsxTVDPlugin(agentschedulers_db.AZDhcpAgentSchedulerDbMixin,
                         page_reverse=False):
         # Read project plugin to filter relevant projects according to
         # plugin
-        req_p = self._get_plugin_from_project(context, context.project_id)
+        req_p = self._get_plugin_for_request(context, filters)
         pools = super(NsxTVDPlugin, self).get_subnetpools(
             context, filters=filters, fields=fields, sorts=sorts,
             limit=limit, marker=marker, page_reverse=page_reverse)
         for pool in pools[:]:
             p = self._get_plugin_from_project(context,
                                               pool['tenant_id'])
-            if p != req_p:
+            if req_p and p != req_p:
                 pools.remove(pool)
         return pools
diff --git a/vmware_nsx/services/lbaas/nsx/plugin.py b/vmware_nsx/services/lbaas/nsx/plugin.py
index bef5065e53..4a8ad00eb1 100644
--- a/vmware_nsx/services/lbaas/nsx/plugin.py
+++ b/vmware_nsx/services/lbaas/nsx/plugin.py
@@ -19,7 +19,18 @@ from vmware_nsx.db import db as nsx_db
 
 
 class LoadBalancerTVDPluginv2(plugin.LoadBalancerPluginv2):
-    def _get_project_mapping(self, context, project_id):
+
+    def _get_project_mapping(self, context, filters):
+        project_id = context.project_id
+        if filters:
+            if filters.get('tenant_id'):
+                project_id = filters.get('tenant_id')
+            elif filters.get('project_id'):
+                project_id = filters.get('project_id')
+            # If multiple are requested then we revert to
+            # the context's project id
+            if isinstance(project_id, list):
+                project_id = context.project_id
         mapping = nsx_db.get_project_plugin_mapping(
                 context.session, project_id)
         if mapping:
@@ -28,7 +39,7 @@ class LoadBalancerTVDPluginv2(plugin.LoadBalancerPluginv2):
             raise exceptions.ObjectNotFound(id=project_id)
 
     def _filter_entries(self, method, context, filters=None, fields=None):
-        req_p = self._get_project_mapping(context, context.project_id)
+        req_p = self._get_project_mapping(context, filters)
         entries = method(context, filters=filters, fields=fields)
         for entry in entries[:]:
             p = self._get_project_mapping(context,