Merge "Reuse caller's session in ML2 DB methods"

This commit is contained in:
Jenkins 2015-05-21 17:04:28 +00:00 committed by Gerrit Code Review
commit 7d9d38773c
12 changed files with 66 additions and 59 deletions

View File

@ -76,10 +76,10 @@ class SecurityGroupServerRpcCallback(object):
def plugin(self):
return manager.NeutronManager.get_plugin()
def _get_devices_info(self, devices):
def _get_devices_info(self, context, devices):
return dict(
(port['id'], port)
for port in self.plugin.get_ports_from_devices(devices)
for port in self.plugin.get_ports_from_devices(context, devices)
if port and not port['device_owner'].startswith('network:')
)
@ -93,7 +93,7 @@ class SecurityGroupServerRpcCallback(object):
:returns: port correspond to the devices with security group rules
"""
devices_info = kwargs.get('devices')
ports = self._get_devices_info(devices_info)
ports = self._get_devices_info(context, devices_info)
return self.plugin.security_group_rules_for_ports(context, ports)
def security_group_info_for_devices(self, context, **kwargs):
@ -110,7 +110,7 @@ class SecurityGroupServerRpcCallback(object):
Note that sets are serialized into lists by rpc code.
"""
devices_info = kwargs.get('devices')
ports = self._get_devices_info(devices_info)
ports = self._get_devices_info(context, devices_info)
return self.plugin.security_group_info_for_ports(context, ports)

View File

@ -38,7 +38,7 @@ DHCP_RULE_PORT = {4: (67, 68, q_const.IPv4), 6: (547, 546, q_const.IPv6)}
class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
"""Mixin class to add agent-based security group implementation."""
def get_port_from_device(self, device):
def get_port_from_device(self, context, device):
"""Get port dict from device name on an agent.
Subclass must provide this method or get_ports_from_devices.
@ -59,13 +59,14 @@ class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
"or get_ports_from_devices.")
% self.__class__.__name__)
def get_ports_from_devices(self, devices):
def get_ports_from_devices(self, context, devices):
"""Bulk method of get_port_from_device.
Subclasses may override this to provide better performance for DB
queries, backend calls, etc.
"""
return [self.get_port_from_device(device) for device in devices]
return [self.get_port_from_device(context, device)
for device in devices]
def create_security_group_rule(self, context, security_group_rule):
bulk_rule = {'security_group_rules': [security_group_rule]}

View File

@ -19,7 +19,6 @@ from sqlalchemy import or_
from sqlalchemy.orm import exc
from neutron.common import constants as n_const
from neutron.db import api as db_api
from neutron.db import models_v2
from neutron.db import securitygroups_db as sg_db
from neutron.extensions import portbindings
@ -244,14 +243,14 @@ def get_port(session, port_id):
return
def get_port_from_device_mac(device_mac):
def get_port_from_device_mac(context, device_mac):
LOG.debug("get_port_from_device_mac() called for mac %s", device_mac)
session = db_api.get_session()
qry = session.query(models_v2.Port).filter_by(mac_address=device_mac)
qry = context.session.query(models_v2.Port).filter_by(
mac_address=device_mac)
return qry.first()
def get_ports_and_sgs(port_ids):
def get_ports_and_sgs(context, port_ids):
"""Get ports from database with security group info."""
# break large queries into smaller parts
@ -259,25 +258,24 @@ def get_ports_and_sgs(port_ids):
LOG.debug("Number of ports %(pcount)s exceeds the maximum per "
"query %(maxp)s. Partitioning queries.",
{'pcount': len(port_ids), 'maxp': MAX_PORTS_PER_QUERY})
return (get_ports_and_sgs(port_ids[:MAX_PORTS_PER_QUERY]) +
get_ports_and_sgs(port_ids[MAX_PORTS_PER_QUERY:]))
return (get_ports_and_sgs(context, port_ids[:MAX_PORTS_PER_QUERY]) +
get_ports_and_sgs(context, port_ids[MAX_PORTS_PER_QUERY:]))
LOG.debug("get_ports_and_sgs() called for port_ids %s", port_ids)
if not port_ids:
# if port_ids is empty, avoid querying to DB to ask it for nothing
return []
ports_to_sg_ids = get_sg_ids_grouped_by_port(port_ids)
ports_to_sg_ids = get_sg_ids_grouped_by_port(context, port_ids)
return [make_port_dict_with_security_groups(port, sec_groups)
for port, sec_groups in ports_to_sg_ids.iteritems()]
def get_sg_ids_grouped_by_port(port_ids):
def get_sg_ids_grouped_by_port(context, port_ids):
sg_ids_grouped_by_port = {}
session = db_api.get_session()
sg_binding_port = sg_db.SecurityGroupPortBinding.port_id
with session.begin(subtransactions=True):
with context.session.begin(subtransactions=True):
# partial UUIDs must be individually matched with startswith.
# full UUIDs may be matched directly in an IN statement
partial_uuids = set(port_id for port_id in port_ids
@ -288,8 +286,8 @@ def get_sg_ids_grouped_by_port(port_ids):
if full_uuids:
or_criteria.append(models_v2.Port.id.in_(full_uuids))
query = session.query(models_v2.Port,
sg_db.SecurityGroupPortBinding.security_group_id)
query = context.session.query(
models_v2.Port, sg_db.SecurityGroupPortBinding.security_group_id)
query = query.outerjoin(sg_db.SecurityGroupPortBinding,
models_v2.Port.id == sg_binding_port)
query = query.filter(or_(*or_criteria))

View File

@ -1460,11 +1460,12 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
port_host = db.get_port_binding_host(context.session, port_id)
return (port_host == host)
def get_ports_from_devices(self, devices):
port_ids_to_devices = dict((self._device_to_port_id(device), device)
for device in devices)
def get_ports_from_devices(self, context, devices):
port_ids_to_devices = dict(
(self._device_to_port_id(context, device), device)
for device in devices)
port_ids = port_ids_to_devices.keys()
ports = db.get_ports_and_sgs(port_ids)
ports = db.get_ports_and_sgs(context, port_ids)
for port in ports:
# map back to original requested id
port_id = next((port_id for port_id in port_ids
@ -1474,7 +1475,7 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
return ports
@staticmethod
def _device_to_port_id(device):
def _device_to_port_id(context, device):
# REVISIT(rkukura): Consider calling into MechanismDrivers to
# process device names, or having MechanismDrivers supply list
# of device prefixes to strip.
@ -1484,7 +1485,7 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
# REVISIT(irenab): Consider calling into bound MD to
# handle the get_device_details RPC
if not uuidutils.is_uuid_like(device):
port = db.get_port_from_device_mac(device)
port = db.get_port_from_device_mac(context, device)
if port:
return port.id
return device

View File

@ -67,7 +67,7 @@ class RpcCallbacks(type_tunnel.TunnelRpcCallbackMixin):
{'device': device, 'agent_id': agent_id, 'host': host})
plugin = manager.NeutronManager.get_plugin()
port_id = plugin._device_to_port_id(device)
port_id = plugin._device_to_port_id(rpc_context, device)
port_context = plugin.get_bound_port_context(rpc_context,
port_id,
host,
@ -144,7 +144,7 @@ class RpcCallbacks(type_tunnel.TunnelRpcCallbackMixin):
"%(agent_id)s",
{'device': device, 'agent_id': agent_id})
plugin = manager.NeutronManager.get_plugin()
port_id = plugin._device_to_port_id(device)
port_id = plugin._device_to_port_id(rpc_context, device)
port_exists = True
if (host and not plugin.port_bound_to_host(rpc_context,
port_id, host)):
@ -173,7 +173,7 @@ class RpcCallbacks(type_tunnel.TunnelRpcCallbackMixin):
LOG.debug("Device %(device)s up at agent %(agent_id)s",
{'device': device, 'agent_id': agent_id})
plugin = manager.NeutronManager.get_plugin()
port_id = plugin._device_to_port_id(device)
port_id = plugin._device_to_port_id(rpc_context, device)
if (host and not plugin.port_bound_to_host(rpc_context,
port_id, host)):
LOG.debug("Device %(device)s not bound to the"

View File

@ -56,7 +56,7 @@ IPv6 = 6
class SecurityGroupServerRpcMixin(sg_db_rpc.SecurityGroupServerRpcMixin):
@staticmethod
def get_port_from_device(device):
def get_port_from_device(context, device):
port = nvsd_db.get_port_from_device(device)
if port:
port['device'] = device

View File

@ -95,7 +95,7 @@ class SecurityGroupRpcTestPlugin(test_sg.SecurityGroupTestPlugin,
self.notify_security_groups_member_updated(context, port)
del self.devices[id]
def get_port_from_device(self, device):
def get_port_from_device(self, context, device):
device = self.devices.get(device)
if device:
device['security_group_rules'] = []

View File

@ -201,7 +201,8 @@ class Ml2DBTestCase(testlib_api.SqlTestCase):
self._setup_neutron_network(network_id)
port = self._setup_neutron_port(network_id, port_id)
observed_port = ml2_db.get_port_from_device_mac(port['mac_address'])
observed_port = ml2_db.get_port_from_device_mac(self.ctx,
port['mac_address'])
self.assertEqual(port_id, observed_port.id)
def test_get_locked_port_and_binding(self):

View File

@ -621,23 +621,26 @@ class TestMl2PluginOnly(Ml2PluginV2TestCase):
('qvo567890', '567890')]
for device, expected in input_output:
self.assertEqual(expected,
ml2_plugin.Ml2Plugin._device_to_port_id(device))
ml2_plugin.Ml2Plugin._device_to_port_id(
self.context, device))
def test__device_to_port_id_mac_address(self):
with self.port() as p:
mac = p['port']['mac_address']
port_id = p['port']['id']
self.assertEqual(port_id,
ml2_plugin.Ml2Plugin._device_to_port_id(mac))
ml2_plugin.Ml2Plugin._device_to_port_id(
self.context, mac))
def test__device_to_port_id_not_uuid_not_mac(self):
dev = '1234567'
self.assertEqual(dev, ml2_plugin.Ml2Plugin._device_to_port_id(dev))
self.assertEqual(dev, ml2_plugin.Ml2Plugin._device_to_port_id(
self.context, dev))
def test__device_to_port_id_UUID(self):
port_id = uuidutils.generate_uuid()
self.assertEqual(port_id,
ml2_plugin.Ml2Plugin._device_to_port_id(port_id))
self.assertEqual(port_id, ml2_plugin.Ml2Plugin._device_to_port_id(
self.context, port_id))
class TestMl2DvrPortsV2(TestMl2PortsV2):

View File

@ -79,14 +79,14 @@ class RpcCallbacksTestCase(base.BaseTestCase):
self.plugin.get_bound_port_context.return_value = None
self.assertEqual(
{'device': 'fake_device'},
self.callbacks.get_device_details('fake_context',
self.callbacks.get_device_details(mock.Mock(),
device='fake_device'))
def test_get_device_details_port_context_without_bounded_segment(self):
self.plugin.get_bound_port_context().bottom_bound_segment = None
self.assertEqual(
{'device': 'fake_device'},
self.callbacks.get_device_details('fake_context',
self.callbacks.get_device_details(mock.Mock(),
device='fake_device'))
def test_get_device_details_port_status_equal_new_status(self):
@ -103,7 +103,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
port['admin_state_up'] = admin_state_up
port['status'] = status
self.plugin.update_port_status.reset_mock()
self.callbacks.get_device_details('fake_context')
self.callbacks.get_device_details(mock.Mock())
self.assertEqual(status == new_status,
not self.plugin.update_port_status.called)
@ -113,7 +113,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
self.plugin.get_bound_port_context().current = port
self.plugin.get_bound_port_context().network.current = (
{"id": "fake_network"})
self.callbacks.get_device_details('fake_context', host='fake_host',
self.callbacks.get_device_details(mock.Mock(), host='fake_host',
cached_networks=cached_networks)
self.assertTrue('fake_port' in cached_networks)
@ -123,7 +123,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
port_context.current = port
port_context.host = 'fake'
self.plugin.update_port_status.reset_mock()
self.callbacks.get_device_details('fake_context',
self.callbacks.get_device_details(mock.Mock(),
host='fake_host')
self.assertFalse(self.plugin.update_port_status.called)
@ -132,7 +132,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
port_context = self.plugin.get_bound_port_context()
port_context.current = port
self.plugin.update_port_status.reset_mock()
self.callbacks.get_device_details('fake_context')
self.callbacks.get_device_details(mock.Mock())
self.assertTrue(self.plugin.update_port_status.called)
def test_get_devices_details_list(self):
@ -159,8 +159,8 @@ class RpcCallbacksTestCase(base.BaseTestCase):
def _test_update_device_not_bound_to_host(self, func):
self.plugin.port_bound_to_host.return_value = False
self.plugin._device_to_port_id.return_value = 'fake_port_id'
res = func('fake_context', device='fake_device', host='fake_host')
self.plugin.port_bound_to_host.assert_called_once_with('fake_context',
res = func(mock.Mock(), device='fake_device', host='fake_host')
self.plugin.port_bound_to_host.assert_called_once_with(mock.ANY,
'fake_port_id',
'fake_host')
return res
@ -180,18 +180,18 @@ class RpcCallbacksTestCase(base.BaseTestCase):
self.plugin._device_to_port_id.return_value = 'fake_port_id'
self.assertEqual(
{'device': 'fake_device', 'exists': False},
self.callbacks.update_device_down('fake_context',
self.callbacks.update_device_down(mock.Mock(),
device='fake_device',
host='fake_host'))
self.plugin.update_port_status.assert_called_once_with(
'fake_context', 'fake_port_id', constants.PORT_STATUS_DOWN,
mock.ANY, 'fake_port_id', constants.PORT_STATUS_DOWN,
'fake_host')
def test_update_device_down_call_update_port_status_failed(self):
self.plugin.update_port_status.side_effect = exc.StaleDataError
self.assertEqual({'device': 'fake_device', 'exists': False},
self.callbacks.update_device_down(
'fake_context', device='fake_device'))
mock.Mock(), device='fake_device'))
class RpcApiTestCase(base.BaseTestCase):

View File

@ -19,6 +19,7 @@ import math
import mock
from neutron.common import constants as const
from neutron import context
from neutron.extensions import securitygroup as ext_sg
from neutron import manager
from neutron.tests import tools
@ -51,6 +52,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
test_sg_rpc.SGNotificationTestMixin):
def setUp(self):
super(TestMl2SecurityGroups, self).setUp()
self.ctx = context.get_admin_context()
plugin = manager.NeutronManager.get_plugin()
plugin.start_rpc_listeners()
@ -75,7 +77,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
]
plugin = manager.NeutronManager.get_plugin()
# should match full ID and starting chars
ports = plugin.get_ports_from_devices(
ports = plugin.get_ports_from_devices(self.ctx,
[orig_ports[0]['id'], orig_ports[1]['id'][0:8],
orig_ports[2]['id']])
self.assertEqual(len(orig_ports), len(ports))
@ -92,7 +94,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
def test_security_group_get_ports_from_devices_with_bad_id(self):
plugin = manager.NeutronManager.get_plugin()
ports = plugin.get_ports_from_devices(['bad_device_id'])
ports = plugin.get_ports_from_devices(self.ctx, ['bad_device_id'])
self.assertFalse(ports)
def test_security_group_no_db_calls_with_no_ports(self):
@ -100,7 +102,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
with mock.patch(
'neutron.plugins.ml2.db.get_sg_ids_grouped_by_port'
) as get_mock:
self.assertFalse(plugin.get_ports_from_devices([]))
self.assertFalse(plugin.get_ports_from_devices(self.ctx, []))
self.assertFalse(get_mock.called)
def test_large_port_count_broken_into_parts(self):
@ -114,10 +116,10 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
mock.patch('neutron.plugins.ml2.db.get_sg_ids_grouped_by_port',
return_value={}),
) as (max_mock, get_mock):
plugin.get_ports_from_devices(
plugin.get_ports_from_devices(self.ctx,
['%s%s' % (const.TAP_DEVICE_PREFIX, i)
for i in range(ports_to_query)])
all_call_args = map(lambda x: x[1][0], get_mock.mock_calls)
all_call_args = map(lambda x: x[1][1], get_mock.mock_calls)
last_call_args = all_call_args.pop()
# all but last should be getting MAX_PORTS_PER_QUERY ports
self.assertTrue(
@ -139,14 +141,14 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
# have one matching 'IN' critiera for all of the IDs
with contextlib.nested(
mock.patch('neutron.plugins.ml2.db.or_'),
mock.patch('neutron.plugins.ml2.db.db_api.get_session')
) as (or_mock, sess_mock):
qmock = sess_mock.return_value.query
mock.patch('sqlalchemy.orm.Session.query')
) as (or_mock, qmock):
fmock = qmock.return_value.outerjoin.return_value.filter
# return no ports to exit the method early since we are mocking
# the query
fmock.return_value = []
plugin.get_ports_from_devices([test_base._uuid(),
plugin.get_ports_from_devices(self.ctx,
[test_base._uuid(),
test_base._uuid()])
# the or_ function should only have one argument
or_mock.assert_called_once_with(mock.ANY)

View File

@ -89,7 +89,8 @@ class TestOneConvergenceSecurityGroups(OneConvergenceSecurityGroupsTestCase,
req.get_response(self.api))
port_id = res['port']['id']
plugin = manager.NeutronManager.get_plugin()
port_dict = plugin.get_port_from_device(port_id)
port_dict = plugin.get_port_from_device(mock.Mock(),
port_id)
self.assertEqual(port_id, port_dict['id'])
self.assertEqual([security_group_id],
port_dict[ext_sg.SECURITYGROUPS])
@ -101,5 +102,5 @@ class TestOneConvergenceSecurityGroups(OneConvergenceSecurityGroupsTestCase,
def test_security_group_get_port_from_device_with_no_port(self):
plugin = manager.NeutronManager.get_plugin()
port_dict = plugin.get_port_from_device('bad_device_id')
port_dict = plugin.get_port_from_device(mock.Mock(), 'bad_device_id')
self.assertIsNone(port_dict)