Merge "Refactor type_tunnel/gre/vxlan to reduce duplicate code"

This commit is contained in:
Jenkins 2015-06-03 23:10:11 +00:00 committed by Gerrit Code Review
commit 8b427ae869
4 changed files with 59 additions and 79 deletions

View File

@ -66,10 +66,11 @@ class GreEndpoints(model_base.BASEV2):
return "<GreTunnelEndpoint(%s)>" % self.ip_address
class GreTypeDriver(type_tunnel.TunnelTypeDriver):
class GreTypeDriver(type_tunnel.EndpointTunnelTypeDriver):
def __init__(self):
super(GreTypeDriver, self).__init__(GreAllocation)
super(GreTypeDriver, self).__init__(
GreAllocation, GreEndpoints)
def get_type(self):
return p_const.TYPE_GRE
@ -127,45 +128,13 @@ class GreTypeDriver(type_tunnel.TunnelTypeDriver):
def get_endpoints(self):
"""Get every gre endpoints from database."""
LOG.debug("get_gre_endpoints() called")
session = db_api.get_session()
gre_endpoints = session.query(GreEndpoints)
gre_endpoints = self._get_endpoints()
return [{'ip_address': gre_endpoint.ip_address,
'host': gre_endpoint.host}
for gre_endpoint in gre_endpoints]
def get_endpoint_by_host(self, host):
LOG.debug("get_endpoint_by_host() called for host %s", host)
session = db_api.get_session()
return (session.query(GreEndpoints).
filter_by(host=host).first())
def get_endpoint_by_ip(self, ip):
LOG.debug("get_endpoint_by_ip() called for ip %s", ip)
session = db_api.get_session()
return (session.query(GreEndpoints).
filter_by(ip_address=ip).first())
def add_endpoint(self, ip, host):
LOG.debug("add_gre_endpoint() called for ip %s", ip)
session = db_api.get_session()
try:
gre_endpoint = GreEndpoints(ip_address=ip, host=host)
gre_endpoint.save(session)
except db_exc.DBDuplicateEntry:
gre_endpoint = (session.query(GreEndpoints).
filter_by(ip_address=ip).one())
LOG.warning(_LW("Gre endpoint with ip %s already exists"), ip)
return gre_endpoint
def delete_endpoint(self, ip):
LOG.debug("delete_gre_endpoint() called for ip %s", ip)
session = db_api.get_session()
with session.begin(subtransactions=True):
session.query(GreEndpoints).filter_by(ip_address=ip).delete()
return self._add_endpoint(ip, host)
def get_mtu(self, physical_network=None):
mtu = super(GreTypeDriver, self).get_mtu(physical_network)

View File

@ -15,10 +15,12 @@
import abc
from oslo_config import cfg
from oslo_db import exception as db_exc
from oslo_log import log
from neutron.common import exceptions as exc
from neutron.common import topics
from neutron.db import api as db_api
from neutron.i18n import _LI, _LW
from neutron.plugins.common import utils as plugin_utils
from neutron.plugins.ml2 import driver_api as api
@ -196,6 +198,50 @@ class TunnelTypeDriver(helpers.SegmentTypeDriver):
return min(mtu) if mtu else 0
class EndpointTunnelTypeDriver(TunnelTypeDriver):
def __init__(self, segment_model, endpoint_model):
super(EndpointTunnelTypeDriver, self).__init__(segment_model)
self.endpoint_model = endpoint_model
self.segmentation_key = iter(self.primary_keys).next()
def get_endpoint_by_host(self, host):
LOG.debug("get_endpoint_by_host() called for host %s", host)
session = db_api.get_session()
return (session.query(self.endpoint_model).
filter_by(host=host).first())
def get_endpoint_by_ip(self, ip):
LOG.debug("get_endpoint_by_ip() called for ip %s", ip)
session = db_api.get_session()
return (session.query(self.endpoint_model).
filter_by(ip_address=ip).first())
def delete_endpoint(self, ip):
LOG.debug("delete_endpoint() called for ip %s", ip)
session = db_api.get_session()
with session.begin(subtransactions=True):
(session.query(self.endpoint_model).
filter_by(ip_address=ip).delete())
def _get_endpoints(self):
LOG.debug("_get_endpoints() called")
session = db_api.get_session()
return session.query(self.endpoint_model)
def _add_endpoint(self, ip, host, **kwargs):
LOG.debug("_add_endpoint() called for ip %s", ip)
session = db_api.get_session()
try:
endpoint = self.endpoint_model(ip_address=ip, host=host, **kwargs)
endpoint.save(session)
except db_exc.DBDuplicateEntry:
endpoint = (session.query(self.endpoint_model).
filter_by(ip_address=ip).one())
LOG.warning(_LW("Endpoint with ip %s already exists"), ip)
return endpoint
class TunnelRpcCallbackMixin(object):
def setup_tunnel_callback_mixin(self, notifier, type_manager):

View File

@ -14,7 +14,6 @@
# under the License.
from oslo_config import cfg
from oslo_db import exception as db_exc
from oslo_log import log
from six import moves
import sqlalchemy as sa
@ -23,7 +22,7 @@ from sqlalchemy import sql
from neutron.common import exceptions as n_exc
from neutron.db import api as db_api
from neutron.db import model_base
from neutron.i18n import _LE, _LW
from neutron.i18n import _LE
from neutron.plugins.common import constants as p_const
from neutron.plugins.ml2.drivers import type_tunnel
@ -70,10 +69,11 @@ class VxlanEndpoints(model_base.BASEV2):
return "<VxlanTunnelEndpoint(%s)>" % self.ip_address
class VxlanTypeDriver(type_tunnel.TunnelTypeDriver):
class VxlanTypeDriver(type_tunnel.EndpointTunnelTypeDriver):
def __init__(self):
super(VxlanTypeDriver, self).__init__(VxlanAllocation)
super(VxlanTypeDriver, self).__init__(
VxlanAllocation, VxlanEndpoints)
def get_type(self):
return p_const.TYPE_VXLAN
@ -132,48 +132,14 @@ class VxlanTypeDriver(type_tunnel.TunnelTypeDriver):
def get_endpoints(self):
"""Get every vxlan endpoints from database."""
LOG.debug("get_vxlan_endpoints() called")
session = db_api.get_session()
vxlan_endpoints = session.query(VxlanEndpoints)
vxlan_endpoints = self._get_endpoints()
return [{'ip_address': vxlan_endpoint.ip_address,
'udp_port': vxlan_endpoint.udp_port,
'host': vxlan_endpoint.host}
for vxlan_endpoint in vxlan_endpoints]
def get_endpoint_by_host(self, host):
LOG.debug("get_endpoint_by_host() called for host %s", host)
session = db_api.get_session()
return (session.query(VxlanEndpoints).
filter_by(host=host).first())
def get_endpoint_by_ip(self, ip):
LOG.debug("get_endpoint_by_ip() called for ip %s", ip)
session = db_api.get_session()
return (session.query(VxlanEndpoints).
filter_by(ip_address=ip).first())
def add_endpoint(self, ip, host, udp_port=p_const.VXLAN_UDP_PORT):
LOG.debug("add_vxlan_endpoint() called for ip %s", ip)
session = db_api.get_session()
try:
vxlan_endpoint = VxlanEndpoints(ip_address=ip,
udp_port=udp_port,
host=host)
vxlan_endpoint.save(session)
except db_exc.DBDuplicateEntry:
vxlan_endpoint = (session.query(VxlanEndpoints).
filter_by(ip_address=ip).one())
LOG.warning(_LW("Vxlan endpoint with ip %s already exists"), ip)
return vxlan_endpoint
def delete_endpoint(self, ip):
LOG.debug("delete_vxlan_endpoint() called for ip %s", ip)
session = db_api.get_session()
with session.begin(subtransactions=True):
session.query(VxlanEndpoints).filter_by(ip_address=ip).delete()
return self._add_endpoint(ip, host, udp_port=udp_port)
def get_mtu(self, physical_network=None):
mtu = super(VxlanTypeDriver, self).get_mtu()

View File

@ -21,6 +21,7 @@ from testtools import matchers
from neutron.common import exceptions as exc
from neutron.db import api as db
from neutron.plugins.ml2 import driver_api as api
from neutron.plugins.ml2.drivers import type_tunnel
TUNNEL_IP_ONE = "10.10.10.10"
TUNNEL_IP_TWO = "10.10.10.20"
@ -33,7 +34,6 @@ UPDATED_TUNNEL_RANGES = [(TUN_MIN + 5, TUN_MAX + 5)]
class TunnelTypeTestMixin(object):
DRIVER_MODULE = None
DRIVER_CLASS = None
TYPE = None
@ -208,8 +208,7 @@ class TunnelTypeTestMixin(object):
def test_add_endpoint_for_existing_tunnel_ip(self):
self.add_endpoint()
log = getattr(self.DRIVER_MODULE, 'LOG')
with mock.patch.object(log, 'warning') as log_warn:
with mock.patch.object(type_tunnel.LOG, 'warning') as log_warn:
self.add_endpoint()
log_warn.assert_called_once_with(mock.ANY, TUNNEL_IP_ONE)