Change passing session to context in segments db functions

We have a number of functions that expect to get session as one
of the arguments. Passing session is not correct and this prevents
using new enginefacade which expects context to be passed and
session will be injected in current context.

NeutronLibImpact

Partially-Implements blueprint: enginefacade-switch

Change-Id: Ie1c54138ceaf9ac6f0674ad2786d4aaea9c80f73
This commit is contained in:
Ann Kamyshnikova 2016-12-02 13:59:08 +03:00
parent 07dfea9a02
commit 12b0e16692
9 changed files with 49 additions and 55 deletions

View File

@ -64,17 +64,17 @@ def add_network_segment(context, network_id, segment, segment_index=0,
'network_id': record.network_id})
def get_network_segments(session, network_id, filter_dynamic=False):
def get_network_segments(context, network_id, filter_dynamic=False):
return get_networks_segments(
session, [network_id], filter_dynamic)[network_id]
context, [network_id], filter_dynamic)[network_id]
def get_networks_segments(session, network_ids, filter_dynamic=False):
def get_networks_segments(context, network_ids, filter_dynamic=False):
if not network_ids:
return {}
with session.begin(subtransactions=True):
query = (session.query(segments_model.NetworkSegment).
with context.session.begin(subtransactions=True):
query = (context.session.query(segments_model.NetworkSegment).
filter(segments_model.NetworkSegment.network_id
.in_(network_ids)).
order_by(segments_model.NetworkSegment.segment_index))
@ -87,10 +87,10 @@ def get_networks_segments(session, network_ids, filter_dynamic=False):
return result
def get_segment_by_id(session, segment_id):
with session.begin(subtransactions=True):
def get_segment_by_id(context, segment_id):
with context.session.begin(subtransactions=True):
try:
record = (session.query(segments_model.NetworkSegment).
record = (context.session.query(segments_model.NetworkSegment).
filter_by(id=segment_id).
one())
return _make_segment_dict(record)
@ -98,11 +98,11 @@ def get_segment_by_id(session, segment_id):
return
def get_dynamic_segment(session, network_id, physical_network=None,
def get_dynamic_segment(context, network_id, physical_network=None,
segmentation_id=None):
"""Return a dynamic segment for the filters provided if one exists."""
with session.begin(subtransactions=True):
query = (session.query(segments_model.NetworkSegment).
with context.session.begin(subtransactions=True):
query = (context.session.query(segments_model.NetworkSegment).
filter_by(network_id=network_id, is_dynamic=True))
if physical_network:
query = query.filter_by(physical_network=physical_network)
@ -123,10 +123,10 @@ def get_dynamic_segment(session, network_id, physical_network=None,
return None
def delete_network_segment(session, segment_id):
def delete_network_segment(context, segment_id):
"""Release a dynamic segment for the params provided if one exists."""
with session.begin(subtransactions=True):
(session.query(segments_model.NetworkSegment).
with context.session.begin(subtransactions=True):
(context.session.query(segments_model.NetworkSegment).
filter_by(id=segment_id).delete())

View File

@ -43,7 +43,7 @@ class NetworkContext(MechanismDriverContext, api.NetworkContext):
self._network = network
self._original_network = original_network
self._segments = segments_db.get_network_segments(
plugin_context.session, network['id'])
plugin_context, network['id'])
@property
def current(self):
@ -192,7 +192,7 @@ class PortContext(MechanismDriverContext, api.PortContext):
self._original_binding_levels[-1].segment_id)
def _expand_segment(self, segment_id):
segment = segments_db.get_segment_by_id(self._plugin_context.session,
segment = segments_db.get_segment_by_id(self._plugin_context,
segment_id)
if not segment:
LOG.warning(_LW("Could not expand segment %s"), segment_id)

View File

@ -306,8 +306,7 @@ class DNSExtensionDriverML2(DNSExtensionDriver):
return True
if network['router:external']:
return True
segments = segments_db.get_network_segments(context.session,
network['id'])
segments = segments_db.get_network_segments(context, network['id'])
if len(segments) > 1:
return False
provider_net = segments[0]

View File

@ -156,7 +156,7 @@ class TypeManager(stevedore.named.NamedExtensionManager):
def extend_networks_dict_provider(self, context, networks):
ids = [network['id'] for network in networks]
net_segments = segments_db.get_networks_segments(context.session, ids)
net_segments = segments_db.get_networks_segments(context, ids)
for network in networks:
segments = net_segments[network['id']]
self._extend_network_dict_provider(network, segments)
@ -278,8 +278,7 @@ class TypeManager(stevedore.named.NamedExtensionManager):
raise exc.NoNetworkAvailable()
def release_network_segments(self, context, network_id):
segments = segments_db.get_network_segments(context.session,
network_id,
segments = segments_db.get_network_segments(context, network_id,
filter_dynamic=None)
for segment in segments:
@ -300,7 +299,7 @@ class TypeManager(stevedore.named.NamedExtensionManager):
def allocate_dynamic_segment(self, context, network_id, segment):
"""Allocate a dynamic segment using a partial or full segment dict."""
dynamic_segment = segments_db.get_dynamic_segment(
context.session, network_id, segment.get(api.PHYSICAL_NETWORK),
context, network_id, segment.get(api.PHYSICAL_NETWORK),
segment.get(api.SEGMENTATION_ID))
if dynamic_segment:
@ -313,14 +312,13 @@ class TypeManager(stevedore.named.NamedExtensionManager):
else:
dynamic_segment = driver.obj.reserve_provider_segment(
context, segment)
segments_db.add_network_segment(context,
network_id, dynamic_segment,
segments_db.add_network_segment(context, network_id, dynamic_segment,
is_dynamic=True)
return dynamic_segment
def release_dynamic_segment(self, context, segment_id):
"""Delete a dynamic segment."""
segment = segments_db.get_segment_by_id(context.session, segment_id)
segment = segments_db.get_segment_by_id(context, segment_id)
if segment:
driver = self.drivers.get(segment.get(api.NETWORK_TYPE))
if driver:
@ -328,7 +326,7 @@ class TypeManager(stevedore.named.NamedExtensionManager):
driver.obj.release_segment(context.session, segment)
else:
driver.obj.release_segment(context, segment)
segments_db.delete_network_segment(context.session, segment_id)
segments_db.delete_network_segment(context, segment_id)
else:
LOG.error(_LE("Failed to release segment '%s' because "
"network type is not supported."), segment)

View File

@ -1845,8 +1845,7 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
def filter_hosts_with_network_access(
self, context, network_id, candidate_hosts):
segments = segments_db.get_network_segments(context.session,
network_id)
segments = segments_db.get_network_segments(context, network_id)
return self.mechanism_manager.filter_hosts_with_segment_access(
context, segments, candidate_hosts, self.get_agents)

View File

@ -23,7 +23,7 @@ from neutron.tests import base
class TestSegmentsDb(base.BaseTestCase):
def test_get_networks_segments_with_empty_networks(self):
session = mock.MagicMock()
net_segs = segments_db.get_networks_segments(session, [])
self.assertFalse(session.query.called)
context = mock.MagicMock()
net_segs = segments_db.get_networks_segments(context, [])
self.assertFalse(context.session.query.called)
self.assertEqual({}, net_segs)

View File

@ -276,8 +276,7 @@ class TestSegment(SegmentTestCase):
network_id=network['id'],
segmentation_id=200)
network_segments = segments_db.get_network_segments(cxt.session,
network['id'])
network_segments = segments_db.get_network_segments(cxt, network['id'])
self.assertEqual([], network_segments)
def test_create_segments_in_certain_order(self):
@ -290,7 +289,7 @@ class TestSegment(SegmentTestCase):
network_id=network['id'], segmentation_id=201)
segment3 = self.segment(
network_id=network['id'], segmentation_id=202)
network_segments = segments_db.get_network_segments(cxt.session,
network_segments = segments_db.get_network_segments(cxt,
network['id'])
self.assertEqual(segment1['segment']['id'],
network_segments[0]['id'])

View File

@ -74,7 +74,7 @@ class Ml2DBTestCase(testlib_api.SqlTestCase):
is_dynamic=is_seg_dynamic)
net_segments = segments_db.get_network_segments(
self.ctx.session, network_id,
self.ctx, network_id,
filter_dynamic=is_seg_dynamic)
net_segments = self._sort_segments(net_segments)
@ -120,7 +120,7 @@ class Ml2DBTestCase(testlib_api.SqlTestCase):
net1segs = self._create_segments(segments1, network_id='net1')
net2segs = self._create_segments(segments2, network_id='net2')
segs = segments_db.get_networks_segments(
self.ctx.session, ['net1', 'net2'])
self.ctx, ['net1', 'net2'])
self.assertEqual(net1segs, self._sort_segments(segs['net1']))
self.assertEqual(net2segs, self._sort_segments(segs['net2']))
@ -128,7 +128,7 @@ class Ml2DBTestCase(testlib_api.SqlTestCase):
self._create_segments([], network_id='net1')
self._create_segments([], network_id='net2')
segs = segments_db.get_networks_segments(
self.ctx.session, ['net1', 'net2'])
self.ctx, ['net1', 'net2'])
self.assertEqual([], segs['net1'])
self.assertEqual([], segs['net2'])
@ -140,13 +140,13 @@ class Ml2DBTestCase(testlib_api.SqlTestCase):
net_segment = self._create_segments([segment])[0]
segment_uuid = net_segment[api.ID]
net_segment = segments_db.get_segment_by_id(self.ctx.session,
net_segment = segments_db.get_segment_by_id(self.ctx,
segment_uuid)
self.assertEqual(segment, net_segment)
def test_get_segment_by_id_result_not_found(self):
segment_uuid = uuidutils.generate_uuid()
net_segment = segments_db.get_segment_by_id(self.ctx.session,
net_segment = segments_db.get_segment_by_id(self.ctx,
segment_uuid)
self.assertIsNone(net_segment)
@ -158,9 +158,9 @@ class Ml2DBTestCase(testlib_api.SqlTestCase):
net_segment = self._create_segments([segment])[0]
segment_uuid = net_segment[api.ID]
segments_db.delete_network_segment(self.ctx.session, segment_uuid)
segments_db.delete_network_segment(self.ctx, segment_uuid)
# Get segment and verify its empty
net_segment = segments_db.get_segment_by_id(self.ctx.session,
net_segment = segments_db.get_segment_by_id(self.ctx,
segment_uuid)
self.assertIsNone(net_segment)

View File

@ -1775,7 +1775,7 @@ class TestMultiSegmentNetworks(Ml2PluginV2TestCase):
self.driver.type_manager.allocate_dynamic_segment(
self.context, network_id, segment)
dynamic_segment = segments_db.get_dynamic_segment(
self.context.session, network_id, 'physnet1')
self.context, network_id, 'physnet1')
self.assertEqual('vlan', dynamic_segment[driver_api.NETWORK_TYPE])
self.assertEqual('physnet1',
dynamic_segment[driver_api.PHYSICAL_NETWORK])
@ -1786,7 +1786,7 @@ class TestMultiSegmentNetworks(Ml2PluginV2TestCase):
self.driver.type_manager.allocate_dynamic_segment(
self.context, network_id, segment2)
dynamic_segment = segments_db.get_dynamic_segment(
self.context.session, network_id, segmentation_id='1234')
self.context, network_id, segmentation_id='1234')
self.assertEqual('vlan', dynamic_segment[driver_api.NETWORK_TYPE])
self.assertEqual('physnet3',
dynamic_segment[driver_api.PHYSICAL_NETWORK])
@ -1804,14 +1804,14 @@ class TestMultiSegmentNetworks(Ml2PluginV2TestCase):
self.driver.type_manager.allocate_dynamic_segment(
self.context, network_id, segment)
dynamic_segment = segments_db.get_dynamic_segment(
self.context.session, network_id, 'physnet1')
self.context, network_id, 'physnet1')
self.assertEqual('vlan', dynamic_segment[driver_api.NETWORK_TYPE])
self.assertEqual('physnet1',
dynamic_segment[driver_api.PHYSICAL_NETWORK])
dynamic_segmentation_id = dynamic_segment[driver_api.SEGMENTATION_ID]
self.assertGreater(dynamic_segmentation_id, 0)
dynamic_segment1 = segments_db.get_dynamic_segment(
self.context.session, network_id, 'physnet1')
self.context, network_id, 'physnet1')
dynamic_segment1_id = dynamic_segment1[driver_api.SEGMENTATION_ID]
self.assertEqual(dynamic_segmentation_id, dynamic_segment1_id)
segment2 = {driver_api.NETWORK_TYPE: 'vlan',
@ -1819,7 +1819,7 @@ class TestMultiSegmentNetworks(Ml2PluginV2TestCase):
self.driver.type_manager.allocate_dynamic_segment(
self.context, network_id, segment2)
dynamic_segment2 = segments_db.get_dynamic_segment(
self.context.session, network_id, 'physnet2')
self.context, network_id, 'physnet2')
dynamic_segmentation2_id = dynamic_segment2[driver_api.SEGMENTATION_ID]
self.assertNotEqual(dynamic_segmentation_id, dynamic_segmentation2_id)
@ -1835,7 +1835,7 @@ class TestMultiSegmentNetworks(Ml2PluginV2TestCase):
self.driver.type_manager.allocate_dynamic_segment(
self.context, network_id, segment)
dynamic_segment = segments_db.get_dynamic_segment(
self.context.session, network_id, 'physnet1')
self.context, network_id, 'physnet1')
self.assertEqual('vlan', dynamic_segment[driver_api.NETWORK_TYPE])
self.assertEqual('physnet1',
dynamic_segment[driver_api.PHYSICAL_NETWORK])
@ -1844,7 +1844,7 @@ class TestMultiSegmentNetworks(Ml2PluginV2TestCase):
self.driver.type_manager.release_dynamic_segment(
self.context, dynamic_segment[driver_api.ID])
self.assertIsNone(segments_db.get_dynamic_segment(
self.context.session, network_id, 'physnet1'))
self.context, network_id, 'physnet1'))
def test_create_network_provider(self):
data = {'network': {'name': 'net1',
@ -1972,7 +1972,7 @@ class TestMultiSegmentNetworks(Ml2PluginV2TestCase):
self.driver.type_manager.allocate_dynamic_segment(
self.context, network_id, segment)
dynamic_segment = segments_db.get_dynamic_segment(
self.context.session, network_id, 'physnet2')
self.context, network_id, 'physnet2')
self.assertEqual('vlan', dynamic_segment[driver_api.NETWORK_TYPE])
self.assertEqual('physnet2',
dynamic_segment[driver_api.PHYSICAL_NETWORK])
@ -1985,9 +1985,9 @@ class TestMultiSegmentNetworks(Ml2PluginV2TestCase):
res = req.get_response(self.api)
self.assertEqual(2, rs.call_count)
self.assertEqual([], segments_db.get_network_segments(
self.context.session, network_id))
self.context, network_id))
self.assertIsNone(segments_db.get_dynamic_segment(
self.context.session, network_id, 'physnet2'))
self.context, network_id, 'physnet2'))
def test_release_segment_no_type_driver(self):
data = {'network': {'name': 'net1',
@ -2681,10 +2681,9 @@ class TestML2Segments(Ml2PluginV2TestCase):
with self.network() as network:
network = network['network']
for stale_seg in segments_db.get_network_segments(self.context.session,
for stale_seg in segments_db.get_network_segments(self.context,
network['id']):
segments_db.delete_network_segment(self.context.session,
stale_seg['id'])
segments_db.delete_network_segment(self.context, stale_seg['id'])
for seg in [seg1, seg2, seg3]:
seg['network_id'] = network['id']
@ -2735,7 +2734,7 @@ class TestML2Segments(Ml2PluginV2TestCase):
binding, None)
plugin._bind_port_if_needed(mech_context)
segment = segments_db.get_network_segments(
self.context.session, port['port']['network_id'])[0]
self.context, port['port']['network_id'])[0]
segment['network_id'] = port['port']['network_id']
self.assertRaises(c_exc.CallbackFailure, registry.notify,
resources.SEGMENT, events.BEFORE_DELETE,