diff --git a/networking_l2gw/db/l2gateway/l2gateway_db.py b/networking_l2gw/db/l2gateway/l2gateway_db.py index f1fdf09..3fce034 100644 --- a/networking_l2gw/db/l2gateway/l2gateway_db.py +++ b/networking_l2gw/db/l2gateway/l2gateway_db.py @@ -25,6 +25,8 @@ from networking_l2gw.services.l2gateway import exceptions as l2gw_exc from neutron_lib.callbacks import events from neutron_lib.callbacks import registry from neutron_lib.callbacks import resources +from neutron_lib.db import api as db_api +from neutron_lib.db import model_query from neutron_lib import exceptions from neutron_lib.plugins import directory from oslo_log import log as logging @@ -42,19 +44,28 @@ class L2GatewayMixin(l2gateway.L2GatewayPluginBase, connection_resource = constants.CONNECTION_RESOURCE_NAME config.register_l2gw_opts_helper() + @db_api.retry_if_session_inactive() + @db_api.CONTEXT_READER def _get_l2_gateway(self, context, gw_id): - gw = context.session.query(models.L2Gateway).get(gw_id) - if not gw: + try: + gw = model_query.get_by_id(context, models.L2Gateway, gw_id) + except sa_orm_exc.NoResultFound: raise l2gw_exc.L2GatewayNotFound(gateway_id=gw_id) return gw + @db_api.retry_if_session_inactive() + @db_api.CONTEXT_READER def _get_l2_gateways(self, context): - return context.session.query(models.L2Gateway).all() + return model_query.get_collection(context, models.L2Gateway, + dict_func=None) + @db_api.retry_if_session_inactive() + @db_api.CONTEXT_READER def _get_l2_gw_interfaces(self, context, id): return context.session.query(models.L2GatewayInterface).filter_by( device_id=id).all() + @db_api.CONTEXT_READER def _is_vlan_configured_on_any_interface_for_l2gw(self, context, l2gw_id): @@ -70,6 +81,7 @@ class L2GatewayMixin(l2gateway.L2GatewayPluginBase, return True return False + @db_api.retry_if_session_inactive() def _get_l2_gateway_devices(self, context, l2gw_id): return context.session.query(models.L2GatewayDevice).filter_by( l2_gateway_id=l2gw_id).all() @@ -143,7 +155,7 @@ class L2GatewayMixin(l2gateway.L2GatewayPluginBase, gw = l2_gateway[self.gateway_resource] tenant_id = self._get_tenant_id_for_create(context, gw) devices = gw['devices'] - with context.session.begin(subtransactions=True): + with db_api.CONTEXT_WRITER.using(context): gw_db = models.L2Gateway( id=gw.get('id', uuidutils.generate_uuid()), tenant_id=tenant_id, @@ -181,6 +193,7 @@ class L2GatewayMixin(l2gateway.L2GatewayPluginBase, context.session.query(models.L2GatewayDevice).all() return self._make_l2_gateway_dict(gw_db) + @db_api.CONTEXT_WRITER def update_l2_gateway(self, context, id, l2_gateway): """Update L2Gateway.""" gw = l2_gateway[self.gateway_resource] @@ -228,6 +241,7 @@ class L2GatewayMixin(l2gateway.L2GatewayPluginBase, 0) context.session.add(interface_db) + @db_api.CONTEXT_READER def get_l2_gateway(self, context, id, fields=None): """get the l2 gateway by id.""" self._admin_check(context, 'GET') @@ -238,10 +252,11 @@ class L2GatewayMixin(l2gateway.L2GatewayPluginBase, """delete the l2 gateway by id.""" gw_db = self._get_l2_gateway(context, id) if gw_db: - with context.session.begin(subtransactions=True): + with db_api.CONTEXT_WRITER.using(context): context.session.delete(gw_db) LOG.debug("l2 gateway '%s' was deleted.", id) + @db_api.CONTEXT_READER def get_l2_gateways(self, context, filters=None, fields=None, sorts=None, limit=None, @@ -269,7 +284,7 @@ class L2GatewayMixin(l2gateway.L2GatewayPluginBase, def _delete_l2_gateway_interfaces(self, context, int_db_list): """delete the l2 interfaces by id.""" - with context.session.begin(subtransactions=True): + with db_api.CONTEXT_WRITER.using(context): for interfaces in int_db_list: context.session.delete(interfaces) LOG.debug("l2 gateway interfaces was deleted.") @@ -286,7 +301,7 @@ class L2GatewayMixin(l2gateway.L2GatewayPluginBase, if constants.SEG_ID in gw_connection: segmentation_id = gw_connection.get(constants.SEG_ID) nw_map[constants.SEG_ID] = segmentation_id - with context.session.begin(subtransactions=True): + with db_api.CONTEXT_WRITER.using(context): gw_db = self._get_l2_gateway(context, l2_gw_id) tenant_id = self._get_tenant_id_for_create(context, gw_db) nw_map['tenant_id'] = tenant_id @@ -308,11 +323,13 @@ class L2GatewayMixin(l2gateway.L2GatewayPluginBase, return self._make_l2gw_connections_dict(gw_db) return self._make_l2gw_connections_dict(conn_db[0]) + @db_api.CONTEXT_READER def get_l2_gateway_connections_count(self, context, filters=None): return len(self._get_collection(context, models.L2GatewayConnection, self._make_l2gw_connections_dict, filters=filters)) + @db_api.CONTEXT_READER def get_l2_gateway_connections(self, context, filters=None, fields=None, sorts=None, limit=None, marker=None, @@ -328,6 +345,7 @@ class L2GatewayMixin(l2gateway.L2GatewayPluginBase, marker_obj=marker_obj, page_reverse=page_reverse) + @db_api.CONTEXT_READER def get_l2_gateway_connection(self, context, id, fields=None): """Get l2 gateway connection.""" self._admin_check(context, 'GET') @@ -428,6 +446,7 @@ class L2GatewayMixin(l2gateway.L2GatewayPluginBase, id="") return con + @db_api.CONTEXT_READER def _get_l2gw_ids_by_interface_switch(self, context, interface_name, switch_name): """Get l2 gateway ids by interface and switch.""" @@ -503,7 +522,7 @@ class L2GatewayMixin(l2gateway.L2GatewayPluginBase, devices = gw['devices'] if not devices: return - with context.session.begin(subtransactions=True): + with db_api.CONTEXT_READER.using(context): # Attemp to retrieve l2gw gw_db = self._get_l2_gateway(context, id) if devices: @@ -554,13 +573,14 @@ class L2GatewayMixin(l2gateway.L2GatewayPluginBase, l2_gw_id) network_id = l2gw_validators.validate_network_mapping_list(nw_map, is_vlan) - with context.session.begin(subtransactions=True): + with db_api.CONTEXT_WRITER.using(context): if self._retrieve_gateway_connections(context, l2_gw_id, nw_map): raise l2gw_exc.L2GatewayConnectionExists(mapping=nw_map, gateway_id=l2_gw_id) + @db_api.CONTEXT_READER def validate_l2_gateway_connection_for_delete(self, context, l2_gateway_conn_id): self._admin_check(context, 'DELETE') diff --git a/networking_l2gw/db/l2gateway/ovsdb/lib.py b/networking_l2gw/db/l2gateway/ovsdb/lib.py index 4f06352..198be68 100644 --- a/networking_l2gw/db/l2gateway/ovsdb/lib.py +++ b/networking_l2gw/db/l2gateway/ovsdb/lib.py @@ -17,6 +17,8 @@ from oslo_utils import timeutils from sqlalchemy import asc from sqlalchemy.orm import exc +from neutron_lib.db import api as db_api + from networking_l2gw.db.l2gateway.ovsdb import models LOG = logging.getLogger(__name__) @@ -490,8 +492,8 @@ def get_pending_ucast_mac_remote(context, ovsdb_identifier, mac, def get_all_pending_remote_macs_in_asc_order(context, ovsdb_identifier): """Get all the pending remote macs in ascending order of timestamp.""" - session = context.session - with session.begin(): + with db_api.CONTEXT_READER.using(context): + session = context.session return session.query( models.PendingUcastMacsRemote ).filter_by(ovsdb_identifier=ovsdb_identifier @@ -501,8 +503,8 @@ def get_all_pending_remote_macs_in_asc_order(context, ovsdb_identifier): def get_all_ucast_mac_remote_by_ls(context, record_dict): """Get ucast macs remote that matches ls_id and ovsdb_identifier.""" - session = context.session - with session.begin(): + with db_api.CONTEXT_READER.using(context): + session = context.session return session.query(models.UcastMacsRemotes).filter_by( ovsdb_identifier=record_dict['ovsdb_identifier'], logical_switch_id=record_dict['logical_switch_id']).all() diff --git a/networking_l2gw/services/l2gateway/plugin.py b/networking_l2gw/services/l2gateway/plugin.py index 11c912e..9d43dc6 100644 --- a/networking_l2gw/services/l2gateway/plugin.py +++ b/networking_l2gw/services/l2gateway/plugin.py @@ -25,6 +25,7 @@ from networking_l2gw.services.l2gateway.common import config from networking_l2gw.services.l2gateway.common import constants from networking_l2gw.services.l2gateway import exceptions as exc +from neutron_lib.db import api as db_api from neutron_lib import exceptions as n_exc from oslo_log import log as logging @@ -94,7 +95,7 @@ class L2GatewayPlugin(l2gateway_db.L2GatewayMixin): """Create the L2Gateway.""" self.validate_l2_gateway_for_create(context, l2_gateway) self.driver.create_l2_gateway(context, l2_gateway) - with context.session.begin(subtransactions=True): + with db_api.CONTEXT_WRITER.using(context): l2_gateway_instance = super(L2GatewayPlugin, self).create_l2_gateway(context, l2_gateway) @@ -155,7 +156,7 @@ class L2GatewayPlugin(l2gateway_db.L2GatewayMixin): context, l2_gateway_connection) self.driver.create_l2_gateway_connection(context, l2_gateway_connection) - with context.session.begin(subtransactions=True): + with db_api.CONTEXT_WRITER.using(context): l2_gateway_conn_instance = super( L2GatewayPlugin, self).create_l2_gateway_connection( context, l2_gateway_connection) @@ -184,7 +185,7 @@ class L2GatewayPlugin(l2gateway_db.L2GatewayMixin): context, l2_gateway_connection_id) self.driver.delete_l2_gateway_connection( context, l2_gateway_connection_id) - with context.session.begin(subtransactions=True): + with db_api.CONTEXT_WRITER.using(context): super(L2GatewayPlugin, self).delete_l2_gateway_connection( context, l2_gateway_connection_id) self.driver.delete_l2_gateway_connection_precommit( diff --git a/networking_l2gw/tests/unit/db/ovsdb/test_lib.py b/networking_l2gw/tests/unit/db/ovsdb/test_lib.py index e0dd421..e4ef8e7 100644 --- a/networking_l2gw/tests/unit/db/ovsdb/test_lib.py +++ b/networking_l2gw/tests/unit/db/ovsdb/test_lib.py @@ -19,6 +19,7 @@ from oslo_utils import uuidutils from neutron.tests.unit import testlib_api from neutron_lib import context +from neutron_lib.db import api as db_api from networking_l2gw.db.l2gateway.ovsdb import lib from networking_l2gw.db.l2gateway.ovsdb import models @@ -292,7 +293,7 @@ class OvsdbLibTestCase(testlib_api.SqlTestCase): if mac and logical_switch_uuid: record_dict['mac'] = mac record_dict['logical_switch_id'] = logical_switch_uuid - with self.ctx.session.begin(subtransactions=True): + with db_api.CONTEXT_WRITER.using(self.ctx): entry = models.UcastMacsRemotes( uuid=record_dict['uuid'], mac=record_dict['mac'], @@ -305,10 +306,10 @@ class OvsdbLibTestCase(testlib_api.SqlTestCase): def test_get_ucast_mac_remote(self): record_dict = self._get_ucast_mac_remote_dict() - with self.ctx.session.begin(subtransactions=True): - entry = self._create_ucast_mac_remote(record_dict) + entry = self._create_ucast_mac_remote(record_dict) result = lib.get_ucast_mac_remote(self.ctx, record_dict) - self.assertEqual(entry, result) + for ent_key, ent_val in entry.items(): + self.assertEqual(ent_val, result[ent_key]) def test_add_ucast_mac_remote(self): record_dict = self._get_ucast_mac_remote_dict() @@ -450,15 +451,15 @@ class OvsdbLibTestCase(testlib_api.SqlTestCase): def test_get_ucast_mac_remote_by_mac_and_ls(self): record_dict = self._get_ucast_mac_remote_dict() - with self.ctx.session.begin(subtransactions=True): - entry = self._create_ucast_mac_remote(record_dict, - '00:11:22:33:44:55:66', - 'ls123') + entry = self._create_ucast_mac_remote(record_dict, + '00:11:22:33:44:55:66', + 'ls123') record_dict['mac'] = '00:11:22:33:44:55:66' record_dict['logical_switch_uuid'] = 'ls123' result = lib.get_ucast_mac_remote_by_mac_and_ls(self.ctx, record_dict) - self.assertEqual(entry, result) + for ent_key, ent_val in entry.items(): + self.assertEqual(ent_val, result[ent_key]) def test_get_ucast_mac_remote_by_mac_and_ls_when_not_found(self): record_dict = self._get_ucast_mac_remote_dict() @@ -529,13 +530,15 @@ class OvsdbLibTestCase(testlib_api.SqlTestCase): record_dict1 = self._get_pending_mac_dict(timestamp1) timestamp2 = timeutils.utcnow() record_dict2 = self._get_pending_mac_dict(timestamp2) - with self.ctx.session.begin(subtransactions=True): - entry1 = self._create_pending_mac(record_dict1) - entry2 = self._create_pending_mac(record_dict2) + entry = [] + with db_api.CONTEXT_WRITER.using(self.ctx): + entry.append(self._create_pending_mac(record_dict1)) + entry.append(self._create_pending_mac(record_dict2)) result = lib.get_all_pending_remote_macs_in_asc_order( self.ctx, record_dict1['ovsdb_identifier']) - self.assertEqual(result[0], entry1) - self.assertEqual(result[1], entry2) + for index, res in enumerate(result): + for k, v in res.items(): + self.assertEqual(res[k], entry[index][k]) def test_delete_pending_ucast_mac_remote(self): timestamp = timeutils.utcnow() diff --git a/networking_l2gw/tests/unit/db/test_l2gw_db.py b/networking_l2gw/tests/unit/db/test_l2gw_db.py index 15654b8..c8dba9a 100644 --- a/networking_l2gw/tests/unit/db/test_l2gw_db.py +++ b/networking_l2gw/tests/unit/db/test_l2gw_db.py @@ -19,6 +19,7 @@ from neutron.tests.unit import testlib_api from neutron_lib.callbacks import events from neutron_lib.callbacks import resources from neutron_lib import context +from neutron_lib.db import api as db_api from networking_l2gw.db.l2gateway import l2gateway_db from networking_l2gw.services.l2gateway.common import constants @@ -48,7 +49,7 @@ class L2GWTestCase(testlib_api.SqlTestCase): def _create_l2gateway(self, l2gateway): """Create l2gateway helper method.""" - with self.ctx.session.begin(subtransactions=True): + with db_api.CONTEXT_WRITER.using(self.ctx): return self.mixin.create_l2_gateway(self.ctx, l2gateway) def _get_l2_gateway_data(self, name, device_name):