Merge "Reuse caller's session in ML2 DB methods"
This commit is contained in:
commit
7d9d38773c
|
@ -76,10 +76,10 @@ class SecurityGroupServerRpcCallback(object):
|
|||
def plugin(self):
|
||||
return manager.NeutronManager.get_plugin()
|
||||
|
||||
def _get_devices_info(self, devices):
|
||||
def _get_devices_info(self, context, devices):
|
||||
return dict(
|
||||
(port['id'], port)
|
||||
for port in self.plugin.get_ports_from_devices(devices)
|
||||
for port in self.plugin.get_ports_from_devices(context, devices)
|
||||
if port and not port['device_owner'].startswith('network:')
|
||||
)
|
||||
|
||||
|
@ -93,7 +93,7 @@ class SecurityGroupServerRpcCallback(object):
|
|||
:returns: port correspond to the devices with security group rules
|
||||
"""
|
||||
devices_info = kwargs.get('devices')
|
||||
ports = self._get_devices_info(devices_info)
|
||||
ports = self._get_devices_info(context, devices_info)
|
||||
return self.plugin.security_group_rules_for_ports(context, ports)
|
||||
|
||||
def security_group_info_for_devices(self, context, **kwargs):
|
||||
|
@ -110,7 +110,7 @@ class SecurityGroupServerRpcCallback(object):
|
|||
Note that sets are serialized into lists by rpc code.
|
||||
"""
|
||||
devices_info = kwargs.get('devices')
|
||||
ports = self._get_devices_info(devices_info)
|
||||
ports = self._get_devices_info(context, devices_info)
|
||||
return self.plugin.security_group_info_for_ports(context, ports)
|
||||
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ DHCP_RULE_PORT = {4: (67, 68, q_const.IPv4), 6: (547, 546, q_const.IPv6)}
|
|||
class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
|
||||
"""Mixin class to add agent-based security group implementation."""
|
||||
|
||||
def get_port_from_device(self, device):
|
||||
def get_port_from_device(self, context, device):
|
||||
"""Get port dict from device name on an agent.
|
||||
|
||||
Subclass must provide this method or get_ports_from_devices.
|
||||
|
@ -59,13 +59,14 @@ class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
|
|||
"or get_ports_from_devices.")
|
||||
% self.__class__.__name__)
|
||||
|
||||
def get_ports_from_devices(self, devices):
|
||||
def get_ports_from_devices(self, context, devices):
|
||||
"""Bulk method of get_port_from_device.
|
||||
|
||||
Subclasses may override this to provide better performance for DB
|
||||
queries, backend calls, etc.
|
||||
"""
|
||||
return [self.get_port_from_device(device) for device in devices]
|
||||
return [self.get_port_from_device(context, device)
|
||||
for device in devices]
|
||||
|
||||
def create_security_group_rule(self, context, security_group_rule):
|
||||
bulk_rule = {'security_group_rules': [security_group_rule]}
|
||||
|
|
|
@ -19,7 +19,6 @@ from sqlalchemy import or_
|
|||
from sqlalchemy.orm import exc
|
||||
|
||||
from neutron.common import constants as n_const
|
||||
from neutron.db import api as db_api
|
||||
from neutron.db import models_v2
|
||||
from neutron.db import securitygroups_db as sg_db
|
||||
from neutron.extensions import portbindings
|
||||
|
@ -244,14 +243,14 @@ def get_port(session, port_id):
|
|||
return
|
||||
|
||||
|
||||
def get_port_from_device_mac(device_mac):
|
||||
def get_port_from_device_mac(context, device_mac):
|
||||
LOG.debug("get_port_from_device_mac() called for mac %s", device_mac)
|
||||
session = db_api.get_session()
|
||||
qry = session.query(models_v2.Port).filter_by(mac_address=device_mac)
|
||||
qry = context.session.query(models_v2.Port).filter_by(
|
||||
mac_address=device_mac)
|
||||
return qry.first()
|
||||
|
||||
|
||||
def get_ports_and_sgs(port_ids):
|
||||
def get_ports_and_sgs(context, port_ids):
|
||||
"""Get ports from database with security group info."""
|
||||
|
||||
# break large queries into smaller parts
|
||||
|
@ -259,25 +258,24 @@ def get_ports_and_sgs(port_ids):
|
|||
LOG.debug("Number of ports %(pcount)s exceeds the maximum per "
|
||||
"query %(maxp)s. Partitioning queries.",
|
||||
{'pcount': len(port_ids), 'maxp': MAX_PORTS_PER_QUERY})
|
||||
return (get_ports_and_sgs(port_ids[:MAX_PORTS_PER_QUERY]) +
|
||||
get_ports_and_sgs(port_ids[MAX_PORTS_PER_QUERY:]))
|
||||
return (get_ports_and_sgs(context, port_ids[:MAX_PORTS_PER_QUERY]) +
|
||||
get_ports_and_sgs(context, port_ids[MAX_PORTS_PER_QUERY:]))
|
||||
|
||||
LOG.debug("get_ports_and_sgs() called for port_ids %s", port_ids)
|
||||
|
||||
if not port_ids:
|
||||
# if port_ids is empty, avoid querying to DB to ask it for nothing
|
||||
return []
|
||||
ports_to_sg_ids = get_sg_ids_grouped_by_port(port_ids)
|
||||
ports_to_sg_ids = get_sg_ids_grouped_by_port(context, port_ids)
|
||||
return [make_port_dict_with_security_groups(port, sec_groups)
|
||||
for port, sec_groups in ports_to_sg_ids.iteritems()]
|
||||
|
||||
|
||||
def get_sg_ids_grouped_by_port(port_ids):
|
||||
def get_sg_ids_grouped_by_port(context, port_ids):
|
||||
sg_ids_grouped_by_port = {}
|
||||
session = db_api.get_session()
|
||||
sg_binding_port = sg_db.SecurityGroupPortBinding.port_id
|
||||
|
||||
with session.begin(subtransactions=True):
|
||||
with context.session.begin(subtransactions=True):
|
||||
# partial UUIDs must be individually matched with startswith.
|
||||
# full UUIDs may be matched directly in an IN statement
|
||||
partial_uuids = set(port_id for port_id in port_ids
|
||||
|
@ -288,8 +286,8 @@ def get_sg_ids_grouped_by_port(port_ids):
|
|||
if full_uuids:
|
||||
or_criteria.append(models_v2.Port.id.in_(full_uuids))
|
||||
|
||||
query = session.query(models_v2.Port,
|
||||
sg_db.SecurityGroupPortBinding.security_group_id)
|
||||
query = context.session.query(
|
||||
models_v2.Port, sg_db.SecurityGroupPortBinding.security_group_id)
|
||||
query = query.outerjoin(sg_db.SecurityGroupPortBinding,
|
||||
models_v2.Port.id == sg_binding_port)
|
||||
query = query.filter(or_(*or_criteria))
|
||||
|
|
|
@ -1460,11 +1460,12 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
|
|||
port_host = db.get_port_binding_host(context.session, port_id)
|
||||
return (port_host == host)
|
||||
|
||||
def get_ports_from_devices(self, devices):
|
||||
port_ids_to_devices = dict((self._device_to_port_id(device), device)
|
||||
for device in devices)
|
||||
def get_ports_from_devices(self, context, devices):
|
||||
port_ids_to_devices = dict(
|
||||
(self._device_to_port_id(context, device), device)
|
||||
for device in devices)
|
||||
port_ids = port_ids_to_devices.keys()
|
||||
ports = db.get_ports_and_sgs(port_ids)
|
||||
ports = db.get_ports_and_sgs(context, port_ids)
|
||||
for port in ports:
|
||||
# map back to original requested id
|
||||
port_id = next((port_id for port_id in port_ids
|
||||
|
@ -1474,7 +1475,7 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
|
|||
return ports
|
||||
|
||||
@staticmethod
|
||||
def _device_to_port_id(device):
|
||||
def _device_to_port_id(context, device):
|
||||
# REVISIT(rkukura): Consider calling into MechanismDrivers to
|
||||
# process device names, or having MechanismDrivers supply list
|
||||
# of device prefixes to strip.
|
||||
|
@ -1484,7 +1485,7 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
|
|||
# REVISIT(irenab): Consider calling into bound MD to
|
||||
# handle the get_device_details RPC
|
||||
if not uuidutils.is_uuid_like(device):
|
||||
port = db.get_port_from_device_mac(device)
|
||||
port = db.get_port_from_device_mac(context, device)
|
||||
if port:
|
||||
return port.id
|
||||
return device
|
||||
|
|
|
@ -67,7 +67,7 @@ class RpcCallbacks(type_tunnel.TunnelRpcCallbackMixin):
|
|||
{'device': device, 'agent_id': agent_id, 'host': host})
|
||||
|
||||
plugin = manager.NeutronManager.get_plugin()
|
||||
port_id = plugin._device_to_port_id(device)
|
||||
port_id = plugin._device_to_port_id(rpc_context, device)
|
||||
port_context = plugin.get_bound_port_context(rpc_context,
|
||||
port_id,
|
||||
host,
|
||||
|
@ -144,7 +144,7 @@ class RpcCallbacks(type_tunnel.TunnelRpcCallbackMixin):
|
|||
"%(agent_id)s",
|
||||
{'device': device, 'agent_id': agent_id})
|
||||
plugin = manager.NeutronManager.get_plugin()
|
||||
port_id = plugin._device_to_port_id(device)
|
||||
port_id = plugin._device_to_port_id(rpc_context, device)
|
||||
port_exists = True
|
||||
if (host and not plugin.port_bound_to_host(rpc_context,
|
||||
port_id, host)):
|
||||
|
@ -173,7 +173,7 @@ class RpcCallbacks(type_tunnel.TunnelRpcCallbackMixin):
|
|||
LOG.debug("Device %(device)s up at agent %(agent_id)s",
|
||||
{'device': device, 'agent_id': agent_id})
|
||||
plugin = manager.NeutronManager.get_plugin()
|
||||
port_id = plugin._device_to_port_id(device)
|
||||
port_id = plugin._device_to_port_id(rpc_context, device)
|
||||
if (host and not plugin.port_bound_to_host(rpc_context,
|
||||
port_id, host)):
|
||||
LOG.debug("Device %(device)s not bound to the"
|
||||
|
|
|
@ -56,7 +56,7 @@ IPv6 = 6
|
|||
class SecurityGroupServerRpcMixin(sg_db_rpc.SecurityGroupServerRpcMixin):
|
||||
|
||||
@staticmethod
|
||||
def get_port_from_device(device):
|
||||
def get_port_from_device(context, device):
|
||||
port = nvsd_db.get_port_from_device(device)
|
||||
if port:
|
||||
port['device'] = device
|
||||
|
|
|
@ -95,7 +95,7 @@ class SecurityGroupRpcTestPlugin(test_sg.SecurityGroupTestPlugin,
|
|||
self.notify_security_groups_member_updated(context, port)
|
||||
del self.devices[id]
|
||||
|
||||
def get_port_from_device(self, device):
|
||||
def get_port_from_device(self, context, device):
|
||||
device = self.devices.get(device)
|
||||
if device:
|
||||
device['security_group_rules'] = []
|
||||
|
|
|
@ -201,7 +201,8 @@ class Ml2DBTestCase(testlib_api.SqlTestCase):
|
|||
self._setup_neutron_network(network_id)
|
||||
port = self._setup_neutron_port(network_id, port_id)
|
||||
|
||||
observed_port = ml2_db.get_port_from_device_mac(port['mac_address'])
|
||||
observed_port = ml2_db.get_port_from_device_mac(self.ctx,
|
||||
port['mac_address'])
|
||||
self.assertEqual(port_id, observed_port.id)
|
||||
|
||||
def test_get_locked_port_and_binding(self):
|
||||
|
|
|
@ -621,23 +621,26 @@ class TestMl2PluginOnly(Ml2PluginV2TestCase):
|
|||
('qvo567890', '567890')]
|
||||
for device, expected in input_output:
|
||||
self.assertEqual(expected,
|
||||
ml2_plugin.Ml2Plugin._device_to_port_id(device))
|
||||
ml2_plugin.Ml2Plugin._device_to_port_id(
|
||||
self.context, device))
|
||||
|
||||
def test__device_to_port_id_mac_address(self):
|
||||
with self.port() as p:
|
||||
mac = p['port']['mac_address']
|
||||
port_id = p['port']['id']
|
||||
self.assertEqual(port_id,
|
||||
ml2_plugin.Ml2Plugin._device_to_port_id(mac))
|
||||
ml2_plugin.Ml2Plugin._device_to_port_id(
|
||||
self.context, mac))
|
||||
|
||||
def test__device_to_port_id_not_uuid_not_mac(self):
|
||||
dev = '1234567'
|
||||
self.assertEqual(dev, ml2_plugin.Ml2Plugin._device_to_port_id(dev))
|
||||
self.assertEqual(dev, ml2_plugin.Ml2Plugin._device_to_port_id(
|
||||
self.context, dev))
|
||||
|
||||
def test__device_to_port_id_UUID(self):
|
||||
port_id = uuidutils.generate_uuid()
|
||||
self.assertEqual(port_id,
|
||||
ml2_plugin.Ml2Plugin._device_to_port_id(port_id))
|
||||
self.assertEqual(port_id, ml2_plugin.Ml2Plugin._device_to_port_id(
|
||||
self.context, port_id))
|
||||
|
||||
|
||||
class TestMl2DvrPortsV2(TestMl2PortsV2):
|
||||
|
|
|
@ -79,14 +79,14 @@ class RpcCallbacksTestCase(base.BaseTestCase):
|
|||
self.plugin.get_bound_port_context.return_value = None
|
||||
self.assertEqual(
|
||||
{'device': 'fake_device'},
|
||||
self.callbacks.get_device_details('fake_context',
|
||||
self.callbacks.get_device_details(mock.Mock(),
|
||||
device='fake_device'))
|
||||
|
||||
def test_get_device_details_port_context_without_bounded_segment(self):
|
||||
self.plugin.get_bound_port_context().bottom_bound_segment = None
|
||||
self.assertEqual(
|
||||
{'device': 'fake_device'},
|
||||
self.callbacks.get_device_details('fake_context',
|
||||
self.callbacks.get_device_details(mock.Mock(),
|
||||
device='fake_device'))
|
||||
|
||||
def test_get_device_details_port_status_equal_new_status(self):
|
||||
|
@ -103,7 +103,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
|
|||
port['admin_state_up'] = admin_state_up
|
||||
port['status'] = status
|
||||
self.plugin.update_port_status.reset_mock()
|
||||
self.callbacks.get_device_details('fake_context')
|
||||
self.callbacks.get_device_details(mock.Mock())
|
||||
self.assertEqual(status == new_status,
|
||||
not self.plugin.update_port_status.called)
|
||||
|
||||
|
@ -113,7 +113,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
|
|||
self.plugin.get_bound_port_context().current = port
|
||||
self.plugin.get_bound_port_context().network.current = (
|
||||
{"id": "fake_network"})
|
||||
self.callbacks.get_device_details('fake_context', host='fake_host',
|
||||
self.callbacks.get_device_details(mock.Mock(), host='fake_host',
|
||||
cached_networks=cached_networks)
|
||||
self.assertTrue('fake_port' in cached_networks)
|
||||
|
||||
|
@ -123,7 +123,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
|
|||
port_context.current = port
|
||||
port_context.host = 'fake'
|
||||
self.plugin.update_port_status.reset_mock()
|
||||
self.callbacks.get_device_details('fake_context',
|
||||
self.callbacks.get_device_details(mock.Mock(),
|
||||
host='fake_host')
|
||||
self.assertFalse(self.plugin.update_port_status.called)
|
||||
|
||||
|
@ -132,7 +132,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
|
|||
port_context = self.plugin.get_bound_port_context()
|
||||
port_context.current = port
|
||||
self.plugin.update_port_status.reset_mock()
|
||||
self.callbacks.get_device_details('fake_context')
|
||||
self.callbacks.get_device_details(mock.Mock())
|
||||
self.assertTrue(self.plugin.update_port_status.called)
|
||||
|
||||
def test_get_devices_details_list(self):
|
||||
|
@ -159,8 +159,8 @@ class RpcCallbacksTestCase(base.BaseTestCase):
|
|||
def _test_update_device_not_bound_to_host(self, func):
|
||||
self.plugin.port_bound_to_host.return_value = False
|
||||
self.plugin._device_to_port_id.return_value = 'fake_port_id'
|
||||
res = func('fake_context', device='fake_device', host='fake_host')
|
||||
self.plugin.port_bound_to_host.assert_called_once_with('fake_context',
|
||||
res = func(mock.Mock(), device='fake_device', host='fake_host')
|
||||
self.plugin.port_bound_to_host.assert_called_once_with(mock.ANY,
|
||||
'fake_port_id',
|
||||
'fake_host')
|
||||
return res
|
||||
|
@ -180,18 +180,18 @@ class RpcCallbacksTestCase(base.BaseTestCase):
|
|||
self.plugin._device_to_port_id.return_value = 'fake_port_id'
|
||||
self.assertEqual(
|
||||
{'device': 'fake_device', 'exists': False},
|
||||
self.callbacks.update_device_down('fake_context',
|
||||
self.callbacks.update_device_down(mock.Mock(),
|
||||
device='fake_device',
|
||||
host='fake_host'))
|
||||
self.plugin.update_port_status.assert_called_once_with(
|
||||
'fake_context', 'fake_port_id', constants.PORT_STATUS_DOWN,
|
||||
mock.ANY, 'fake_port_id', constants.PORT_STATUS_DOWN,
|
||||
'fake_host')
|
||||
|
||||
def test_update_device_down_call_update_port_status_failed(self):
|
||||
self.plugin.update_port_status.side_effect = exc.StaleDataError
|
||||
self.assertEqual({'device': 'fake_device', 'exists': False},
|
||||
self.callbacks.update_device_down(
|
||||
'fake_context', device='fake_device'))
|
||||
mock.Mock(), device='fake_device'))
|
||||
|
||||
|
||||
class RpcApiTestCase(base.BaseTestCase):
|
||||
|
|
|
@ -19,6 +19,7 @@ import math
|
|||
import mock
|
||||
|
||||
from neutron.common import constants as const
|
||||
from neutron import context
|
||||
from neutron.extensions import securitygroup as ext_sg
|
||||
from neutron import manager
|
||||
from neutron.tests import tools
|
||||
|
@ -51,6 +52,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
|
|||
test_sg_rpc.SGNotificationTestMixin):
|
||||
def setUp(self):
|
||||
super(TestMl2SecurityGroups, self).setUp()
|
||||
self.ctx = context.get_admin_context()
|
||||
plugin = manager.NeutronManager.get_plugin()
|
||||
plugin.start_rpc_listeners()
|
||||
|
||||
|
@ -75,7 +77,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
|
|||
]
|
||||
plugin = manager.NeutronManager.get_plugin()
|
||||
# should match full ID and starting chars
|
||||
ports = plugin.get_ports_from_devices(
|
||||
ports = plugin.get_ports_from_devices(self.ctx,
|
||||
[orig_ports[0]['id'], orig_ports[1]['id'][0:8],
|
||||
orig_ports[2]['id']])
|
||||
self.assertEqual(len(orig_ports), len(ports))
|
||||
|
@ -92,7 +94,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
|
|||
|
||||
def test_security_group_get_ports_from_devices_with_bad_id(self):
|
||||
plugin = manager.NeutronManager.get_plugin()
|
||||
ports = plugin.get_ports_from_devices(['bad_device_id'])
|
||||
ports = plugin.get_ports_from_devices(self.ctx, ['bad_device_id'])
|
||||
self.assertFalse(ports)
|
||||
|
||||
def test_security_group_no_db_calls_with_no_ports(self):
|
||||
|
@ -100,7 +102,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
|
|||
with mock.patch(
|
||||
'neutron.plugins.ml2.db.get_sg_ids_grouped_by_port'
|
||||
) as get_mock:
|
||||
self.assertFalse(plugin.get_ports_from_devices([]))
|
||||
self.assertFalse(plugin.get_ports_from_devices(self.ctx, []))
|
||||
self.assertFalse(get_mock.called)
|
||||
|
||||
def test_large_port_count_broken_into_parts(self):
|
||||
|
@ -114,10 +116,10 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
|
|||
mock.patch('neutron.plugins.ml2.db.get_sg_ids_grouped_by_port',
|
||||
return_value={}),
|
||||
) as (max_mock, get_mock):
|
||||
plugin.get_ports_from_devices(
|
||||
plugin.get_ports_from_devices(self.ctx,
|
||||
['%s%s' % (const.TAP_DEVICE_PREFIX, i)
|
||||
for i in range(ports_to_query)])
|
||||
all_call_args = map(lambda x: x[1][0], get_mock.mock_calls)
|
||||
all_call_args = map(lambda x: x[1][1], get_mock.mock_calls)
|
||||
last_call_args = all_call_args.pop()
|
||||
# all but last should be getting MAX_PORTS_PER_QUERY ports
|
||||
self.assertTrue(
|
||||
|
@ -139,14 +141,14 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
|
|||
# have one matching 'IN' critiera for all of the IDs
|
||||
with contextlib.nested(
|
||||
mock.patch('neutron.plugins.ml2.db.or_'),
|
||||
mock.patch('neutron.plugins.ml2.db.db_api.get_session')
|
||||
) as (or_mock, sess_mock):
|
||||
qmock = sess_mock.return_value.query
|
||||
mock.patch('sqlalchemy.orm.Session.query')
|
||||
) as (or_mock, qmock):
|
||||
fmock = qmock.return_value.outerjoin.return_value.filter
|
||||
# return no ports to exit the method early since we are mocking
|
||||
# the query
|
||||
fmock.return_value = []
|
||||
plugin.get_ports_from_devices([test_base._uuid(),
|
||||
plugin.get_ports_from_devices(self.ctx,
|
||||
[test_base._uuid(),
|
||||
test_base._uuid()])
|
||||
# the or_ function should only have one argument
|
||||
or_mock.assert_called_once_with(mock.ANY)
|
||||
|
|
|
@ -89,7 +89,8 @@ class TestOneConvergenceSecurityGroups(OneConvergenceSecurityGroupsTestCase,
|
|||
req.get_response(self.api))
|
||||
port_id = res['port']['id']
|
||||
plugin = manager.NeutronManager.get_plugin()
|
||||
port_dict = plugin.get_port_from_device(port_id)
|
||||
port_dict = plugin.get_port_from_device(mock.Mock(),
|
||||
port_id)
|
||||
self.assertEqual(port_id, port_dict['id'])
|
||||
self.assertEqual([security_group_id],
|
||||
port_dict[ext_sg.SECURITYGROUPS])
|
||||
|
@ -101,5 +102,5 @@ class TestOneConvergenceSecurityGroups(OneConvergenceSecurityGroupsTestCase,
|
|||
def test_security_group_get_port_from_device_with_no_port(self):
|
||||
|
||||
plugin = manager.NeutronManager.get_plugin()
|
||||
port_dict = plugin.get_port_from_device('bad_device_id')
|
||||
port_dict = plugin.get_port_from_device(mock.Mock(), 'bad_device_id')
|
||||
self.assertIsNone(port_dict)
|
||||
|
|
Loading…
Reference in New Issue