Add type annotations

Add type annotations to the different parameters
and return values to modernize the python used.
This will also introduce mypy as another tool
for static code analysis which will currently not
run in CI

Change-Id: Ic09e47673f916328568c413d0e8485d36c283c24
This commit is contained in:
Niklas Schwarz 2024-03-06 09:57:55 +01:00
parent 8704b2c7c0
commit f63f99d822
40 changed files with 1105 additions and 708 deletions

View File

@ -15,26 +15,22 @@ ignore=.git,tests,openstack
disable=
# "F" Fatal errors that prevent further processing
import-error,
import-untyped,
# "I" Informational noise
locally-disabled,
# "E" Error for important programming issues (likely bugs)
access-member-before-definition,
bad-super-call,
maybe-no-member,
no-member,
no-method-argument,
no-self-argument,
no-name-in-module,
not-callable,
no-value-for-parameter,
super-on-old-class,
too-few-format-args,
unsubscriptable-object,
# "W" Warnings for stylistic problems or minor programming issues
abstract-method,
anomalous-backslash-in-string,
anomalous-unicode-escape-in-string,
arguments-differ,
attribute-defined-outside-init,
bad-builtin,
# bad-builtin,
bad-indentation,
broad-except,
dangerous-default-value,
@ -57,7 +53,6 @@ disable=
unnecessary-lambda,
unnecessary-pass,
unpacking-non-sequence,
unreachable,
unused-argument,
unused-import,
unused-variable,
@ -67,17 +62,13 @@ disable=
bad-continuation,
invalid-name,
missing-docstring,
old-style-class,
superfluous-parens,
# "R" Refactor recommendations
abstract-class-little-used,
abstract-class-not-used,
consider-using-set-comprehension,
cyclic-import,
duplicate-code,
inconsistent-return-statements,
interface-not-implemented,
no-else-raise,
no-else-return,
no-self-use,
too-few-public-methods,
too-many-ancestors,
@ -89,7 +80,6 @@ disable=
too-many-public-methods,
too-many-return-statements,
too-many-statements,
useless-object-inheritance
[BASIC]
# Variable names can be 1 to 31 characters long, with lowercase and underscores

View File

@ -13,11 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import typing as ty
import uuid
from neutron.agent.linux import external_process
from neutron.common.ovn import utils as ovn_utils
from neutron.conf.plugins.ml2.drivers.ovn import ovn_conf as config
from oslo_config import cfg
from oslo_log import log as logging
from oslo_service import service
from ovsdbapp.backend.ovs_idl import event as row_event
@ -32,44 +34,6 @@ LOG = logging.getLogger(__name__)
OVN_VPNAGENT_UUID_NAMESPACE = uuid.UUID('e1ce3b12-b1e0-4c81-ba27-07c0fec9c12b')
class ChassisCreateEventBase(row_event.RowEvent):
"""Row create event - Chassis name == our_chassis.
On connection, we get a dump of all chassis so if we catch a creation
of our own chassis it has to be a reconnection. In this case, we need
to do a full sync to make sure that we capture all changes while the
connection to OVSDB was down.
"""
table = None
def __init__(self, vpn_agent):
self.agent = vpn_agent
self.first_time = True
events = (self.ROW_CREATE,)
super().__init__(
events, self.table, (('name', '=', self.agent.chassis),))
self.event_name = self.__class__.__name__
def run(self, event, row, old):
if self.first_time:
self.first_time = False
else:
# NOTE(lucasagomes): Re-register the ovn vpn agent
# with the local chassis in case its entry was re-created
# (happens when restarting the ovn-controller)
self.agent.register_vpn_agent()
LOG.info("Connection to OVSDB established, doing a full sync")
self.agent.sync()
class ChassisCreateEvent(ChassisCreateEventBase):
table = 'Chassis'
class ChassisPrivateCreateEvent(ChassisCreateEventBase):
table = 'Chassis_Private'
class SbGlobalUpdateEvent(row_event.RowEvent):
"""Row update event on SB_Global table."""
@ -90,7 +54,7 @@ class SbGlobalUpdateEvent(row_event.RowEvent):
class OvnVpnAgent(service.Service):
def __init__(self, conf):
def __init__(self, conf: cfg.ConfigOpts):
super().__init__()
self.conf = conf
vlog.use_python_logger(max_level=config.get_ovn_ovsdb_log_level())
@ -102,13 +66,13 @@ class OvnVpnAgent(service.Service):
self.device_drivers = self.service.load_device_drivers(self.conf.host)
def _load_config(self):
self.chassis = self._get_own_chassis_name()
self.chassis: ty.Optional[str] = self._get_own_chassis_name()
try:
self.chassis_id = uuid.UUID(self.chassis)
except ValueError:
# OVS system-id could be a non UUID formatted string.
self.chassis_id = uuid.uuid5(OVN_VPNAGENT_UUID_NAMESPACE,
self.chassis)
self.chassis if self.chassis else '')
LOG.debug("Loaded chassis name %s (UUID: %s).",
self.chassis, self.chassis_id)
@ -156,12 +120,51 @@ class OvnVpnAgent(service.Service):
self.sb_idl.db_add(table, self.chassis, 'external_ids',
ext_ids).execute(check_error=True)
def _get_own_chassis_name(self):
def _get_own_chassis_name(self) -> ty.Optional[str]:
"""Return the external_ids:system-id value of the Open_vSwitch table.
As long as ovn-controller is running on this node, the key is
guaranteed to exist and will include the chassis name.
"""
ext_ids = self.ovs_idl.db_get(
ext_ids: ty.Optional[ty.Dict[str, str]] = self.ovs_idl.db_get(
'Open_vSwitch', '.', 'external_ids').execute()
return ext_ids['system-id']
return ext_ids['system-id'] if ext_ids else None
class ChassisCreateEventBase(row_event.RowEvent):
"""Row create event - Chassis name == our_chassis.
On connection, we get a dump of all chassis so if we catch a creation
of our own chassis it has to be a reconnection. In this case, we need
to do a full sync to make sure that we capture all changes while the
connection to OVSDB was down.
"""
table: ty.Optional[str] = None
def __init__(self, vpn_agent: OvnVpnAgent):
self.agent = vpn_agent
self.first_time: bool = True
events: ty.Tuple[str] = (self.ROW_CREATE,)
super().__init__(events, self.table,
(('name', '=', self.agent.chassis),))
self.event_name = self.__class__.__name__
def run(self, event, row, old):
if self.first_time:
self.first_time = False
else:
# NOTE(lucasagomes): Re-register the ovn vpn agent
# with the local chassis in case its entry was re-created
# (happens when restarting the ovn-controller)
self.agent.register_vpn_agent()
LOG.info("Connection to OVSDB established, doing a full sync")
self.agent.sync()
class ChassisCreateEvent(ChassisCreateEventBase):
table = "Chassis"
class ChassisPrivateCreateEvent(ChassisCreateEventBase):
table = "Chassis_Private"

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing as ty
from neutron.conf.plugins.ml2.drivers.ovn import ovn_conf as config
from neutron.plugins.ml2.drivers.ovn.mech_driver.ovsdb import impl_idl_ovn
@ -28,7 +29,7 @@ LOG = logging.getLogger(__name__)
class VPNAgentOvnSbIdl(ovsdb_monitor.OvnIdl):
SCHEMA = 'OVN_Southbound'
SCHEMA: str = 'OVN_Southbound'
def __init__(self, chassis=None, events=None, tables=None):
connection_string = config.get_ovn_sb_connection()
@ -39,8 +40,8 @@ class VPNAgentOvnSbIdl(ovsdb_monitor.OvnIdl):
for table in tables:
helper.register_table(table)
try:
super().__init__(
None, connection_string, helper, leader_only=False)
super().__init__(None, connection_string,
helper, leader_only=False)
except TypeError:
# TODO(bpetermann) We can remove this when we require ovs>=2.12.0
super().__init__(None, connection_string, helper)
@ -54,7 +55,7 @@ class VPNAgentOvnSbIdl(ovsdb_monitor.OvnIdl):
@tenacity.retry(
wait=tenacity.wait_exponential(max=180),
reraise=True)
def _get_ovsdb_helper(self, connection_string):
def _get_ovsdb_helper(self, connection_string: str):
return idlutils.get_schema_helper(connection_string, self.SCHEMA)
def start(self):
@ -62,14 +63,18 @@ class VPNAgentOvnSbIdl(ovsdb_monitor.OvnIdl):
self, timeout=config.get_ovn_ovsdb_timeout())
return impl_idl_ovn.OvsdbSbOvnIdl(conn)
def post_connect(self):
pass
class VPNAgentOvsIdl(object):
class VPNAgentOvsIdl:
def start(self):
connection_string = config.cfg.CONF.ovs.ovsdb_connection
connection_string: str = config.cfg.CONF.ovs.ovsdb_connection
helper = idlutils.get_schema_helper(connection_string,
'Open_vSwitch')
tables = ('Open_vSwitch', 'Bridge', 'Port', 'Interface')
tables: ty.Tuple[str, str, str, str] = ('Open_vSwitch', 'Bridge',
'Port', 'Interface')
for table in tables:
helper.register_table(table)
ovs_idl = idl.Idl(

View File

@ -12,7 +12,10 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import typing as ty
from neutron.api.rpc.agentnotifiers import utils as ag_utils
from neutron_lib import context
from neutron_lib import rpc as n_rpc
import oslo_messaging
@ -23,25 +26,31 @@ from neutron_vpnaas.services.vpn.common import topics
AGENT_NOTIFY_MAX_ATTEMPTS = 2
class VPNAgentNotifyAPI(object):
class VPNAgentNotifyAPI:
"""API for plugin to notify VPN agent."""
def __init__(self, topic=topics.IPSEC_AGENT_TOPIC):
target = oslo_messaging.Target(topic=topic, version='1.0')
self.client = n_rpc.get_client(target)
def agent_updated(self, context, admin_state_up, host):
def agent_updated(
self, context: context.ContextBase,
admin_state_up: bool, host: str):
cctxt = self.client.prepare(server=host)
cctxt.cast(context, 'agent_updated',
payload={'admin_state_up': admin_state_up})
def vpnservice_removed_from_agent(self, context, router_id, host):
def vpnservice_removed_from_agent(
self, context: context.ContextBase,
router_id: str, host: str):
"""Notify agent about removed VPN service(s) of a router."""
cctxt = self.client.prepare(server=host)
cctxt.cast(context, 'vpnservice_removed_from_agent',
router_id=router_id)
def vpnservice_added_to_agent(self, context, router_ids, host):
def vpnservice_added_to_agent(
self, context: context.ContextBase,
router_ids: ty.List[str], host: str):
"""Notify agent about added VPN service(s) of router(s)."""
# need to use call here as we want to be sure agent received
# notification and router will not be "lost". However using call()

View File

@ -27,8 +27,8 @@ from neutron_vpnaas.db.migration import alembic_migrations
MYSQL_ENGINE = None
config = context.config
neutron_config = config.neutron_config
logging_config.fileConfig(config.config_file_name)
neutron_config = config.neutron_config # type: ignore
logging_config.fileConfig(config.config_file_name) # type: ignore
target_metadata = model_base.BASEV2.metadata
@ -46,7 +46,7 @@ def set_mysql_engine():
def run_migrations_offline():
set_mysql_engine()
kwargs = dict()
kwargs = {}
if neutron_config.database.connection:
kwargs['url'] = neutron_config.database.connection
else:

View File

@ -13,9 +13,12 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import typing as ty
import random
from neutron.extensions import agent as nagent
from neutron.extensions import l3
from neutron.extensions import router_availability_zone as router_az
from neutron import worker as neutron_worker
from neutron_lib import context as ncontext
@ -31,8 +34,10 @@ import sqlalchemy as sa
from sqlalchemy import func
from neutron_vpnaas._i18n import _
from neutron_vpnaas.api.rpc.agentnotifiers import vpn_rpc_agent_api as nfy_api
from neutron_vpnaas.db.vpn import vpn_models
from neutron_vpnaas.extensions import vpn_agentschedulers
from neutron_vpnaas.scheduler.vpn_agent_scheduler import VPNScheduler
from neutron_vpnaas.services.vpn.common.constants import AGENT_TYPE_VPN
@ -71,15 +76,15 @@ class VPNAgentSchedulerDbMixin(
using the VPN agent.
"""
vpn_scheduler = None
agent_notifiers = {}
vpn_scheduler: ty.Optional[VPNScheduler] = None
agent_notifiers: ty.Dict[str, nfy_api.VPNAgentNotifyAPI] = {}
@property
def l3_plugin(self):
def l3_plugin(self) -> l3.RouterPluginBase:
return directory.get_plugin(plugin_const.L3)
@property
def core_plugin(self):
def core_plugin(self) -> nagent.AgentPluginBase:
return directory.get_plugin()
def add_periodic_vpn_agent_status_check(self):
@ -96,7 +101,7 @@ class VPNAgentSchedulerDbMixin(
check_worker = neutron_worker.PeriodicWorker(
self.reschedule_vpnservices_from_down_agents,
interval, initial_delay)
self.add_worker(check_worker)
self.add_worker(check_worker) # type: ignore
def reschedule_vpnservices_from_down_agents(self):
"""Reschedule VPN services from down VPN agents.
@ -111,9 +116,10 @@ class VPNAgentSchedulerDbMixin(
for binding in down_bindings:
if binding.vpn_agent_id in agents_back_online:
continue
agent = self.core_plugin.get_agent(context,
agent: ty.Optional[nagent.Agent] = self.core_plugin.get_agent(
context,
binding.vpn_agent_id)
if agent['alive']:
if agent and agent['alive']:
agents_back_online.add(binding.vpn_agent_id)
continue
@ -137,7 +143,8 @@ class VPNAgentSchedulerDbMixin(
"rescheduling.")
@db_api.CONTEXT_READER
def get_down_router_bindings(self, context):
def get_down_router_bindings(self,
context: ncontext.Context) -> ty.List[RouterVPNAgentBinding]:
vpn_agents = self.get_vpn_agents(context, active=False)
if not vpn_agents:
return []
@ -148,7 +155,9 @@ class VPNAgentSchedulerDbMixin(
RouterVPNAgentBinding.vpn_agent_id.in_(vpn_agent_ids))
return query.all()
def validate_agent_router_combination(self, context, agent, router):
def validate_agent_router_combination(self, context: ncontext.ContextBase,
agent: nagent.Agent,
router: ty.Dict[str, ty.Any]):
"""Validate if the router can be correctly assigned to the agent.
:raises: InvalidVPNAgent if attempting to assign router to an
@ -158,7 +167,9 @@ class VPNAgentSchedulerDbMixin(
raise vpn_agentschedulers.InvalidVPNAgent(id=agent['id'])
@db_api.CONTEXT_READER
def check_agent_router_scheduling_needed(self, context, agent, router):
def check_agent_router_scheduling_needed(self, context: ncontext.Context,
agent: ty.Dict[str, ty.Any],
router: ty.Dict[str, ty.Any]):
"""Check if the scheduling of router's VPN services is needed.
:raises: RouterHostedByVPNAgent if router is already assigned
@ -180,7 +191,8 @@ class VPNAgentSchedulerDbMixin(
router_id=router_id,
agent_id=bindings[0].vpn_agent_id)
def create_router_to_agent_binding(self, context, router_id, agent_id):
def create_router_to_agent_binding(self, context: ncontext.Context,
router_id: str, agent_id: str):
"""Create router to VPN agent binding."""
try:
with db_api.CONTEXT_WRITER.using(context):
@ -203,11 +215,12 @@ class VPNAgentSchedulerDbMixin(
{'router_id': router_id, 'agent_id': agent_id})
return True
def add_router_to_vpn_agent(self, context, agent_id, router_id):
def add_router_to_vpn_agent(self, context: ncontext.Context,
agent_id: str, router_id: str):
"""Add a VPN agent to host VPN services of a router."""
with db_api.CONTEXT_WRITER.using(context):
router = self.l3_plugin.get_router(context, router_id)
agent = self.core_plugin.get_agent(context, agent_id)
agent: nagent.Agent = self.core_plugin.get_agent(context, agent_id)
self.validate_agent_router_combination(context, agent, router)
if not self.check_agent_router_scheduling_needed(
context, agent, router):
@ -232,7 +245,8 @@ class VPNAgentSchedulerDbMixin(
self.vpn_router_agent_binding_changed(
context, router_id, agent['host'])
def remove_router_from_vpn_agent(self, context, agent_id, router_id):
def remove_router_from_vpn_agent(self, context: ncontext.Context,
agent_id: str, router_id: str):
"""Remove the router from VPN agent.
After removal, the VPN service(s) of the router will be non-hosted
@ -248,7 +262,8 @@ class VPNAgentSchedulerDbMixin(
vpn_notifier.vpnservice_removed_from_agent(
context, router_id, agent['host'])
def _unbind_router(self, context, router_id, agent_id):
def _unbind_router(self, context: ncontext.Context,
router_id: str, agent_id: str):
with db_api.CONTEXT_WRITER.using(context):
query = context.session.query(RouterVPNAgentBinding)
query = query.filter(
@ -256,7 +271,8 @@ class VPNAgentSchedulerDbMixin(
RouterVPNAgentBinding.vpn_agent_id == agent_id)
return query.delete()
def reschedule_router(self, context, router_id, cur_agent):
def reschedule_router(self, context: ncontext.Context, router_id: str,
cur_agent: nagent.Agent):
"""Reschedule router to a new VPN agent
Remove the router from the agent currently hosting it and
@ -282,8 +298,10 @@ class VPNAgentSchedulerDbMixin(
self.vpn_router_agent_binding_changed(
context, router_id, new_agent['host'])
def _notify_agents_router_rescheduled(self, context, router_id,
old_agent, new_agent):
def _notify_agents_router_rescheduled(self, context: ncontext.Context,
router_id: str,
old_agent: ty.Dict[str, ty.Any],
new_agent: ty.Dict[str, ty.Any]):
vpn_notifier = self.agent_notifiers.get(AGENT_TYPE_VPN)
if not vpn_notifier:
return
@ -303,7 +321,8 @@ class VPNAgentSchedulerDbMixin(
router_id=router_id)
@db_api.CONTEXT_READER
def list_routers_on_vpn_agent(self, context, agent_id):
def list_routers_on_vpn_agent(self, context: ncontext.Context,
agent_id: str):
query = context.session.query(RouterVPNAgentBinding.router_id)
query = query.filter(RouterVPNAgentBinding.vpn_agent_id == agent_id)
@ -312,13 +331,15 @@ class VPNAgentSchedulerDbMixin(
return {'routers':
self.l3_plugin.get_routers(context,
filters={'id': router_ids})}
else:
# Exception will be thrown if the requested agent does not exist.
self.core_plugin.get_agent(context, agent_id)
return {'routers': []}
@db_api.CONTEXT_READER
def get_vpn_agents_hosting_routers(self, context, router_ids, active=None):
def get_vpn_agents_hosting_routers(self, context: ncontext.Context,
router_ids: ty.Optional[ty.List[str]],
active: ty.Optional[bool] = None):
if not router_ids:
return []
query = context.session.query(RouterVPNAgentBinding)
@ -332,29 +353,39 @@ class VPNAgentSchedulerDbMixin(
if agent['alive'] == active]
return vpn_agents
def list_vpn_agents_hosting_router(self, context, router_id):
def list_vpn_agents_hosting_router(self, context: ncontext.Context,
router_id: str):
vpn_agents = self.get_vpn_agents_hosting_routers(context, [router_id])
return {'agents': vpn_agents}
def get_vpn_agents(self, context, active=None, host=None):
def get_vpn_agents(self, context: ncontext.Context,
active: ty.Optional[bool] = None,
host: ty.Optional[str] = None) -> \
ty.Optional[ty.List[nagent.Agent]]:
filters = {'agent_type': [AGENT_TYPE_VPN]}
if host is not None:
filters['host'] = [host]
vpn_agents = self.core_plugin.get_agents(context, filters=filters)
vpn_agents: ty.Optional[ty.List[nagent.Agent]] = \
self.core_plugin.get_agents(context, filters=filters)
if active is None:
return vpn_agents
else:
if not vpn_agents:
return None
return [vpn_agent
for vpn_agent in vpn_agents
if vpn_agent['alive'] == active]
def get_vpn_agent_on_host(self, context, host, active=None):
def get_vpn_agent_on_host(self, context: ncontext.Context,
host: str, active: ty.Optional[bool] = None):
agents = self.get_vpn_agents(context, active=active, host=host)
if agents:
return agents[0]
@db_api.CONTEXT_READER
def get_unscheduled_vpn_routers(self, context, router_ids=None):
def get_unscheduled_vpn_routers(self, context: ncontext.Context,
router_ids: ty.Optional[ty.List[str]] = None):
"""Get IDs of routers which have unscheduled VPN services."""
query = context.session.query(vpn_models.VPNService.router_id)
query = query.outerjoin(
@ -366,12 +397,14 @@ class VPNAgentSchedulerDbMixin(
vpn_models.VPNService.router_id.in_(router_ids))
return [router_id for router_id, in query.all()]
def auto_schedule_routers(self, context, vpn_agent):
def auto_schedule_routers(self, context: ncontext.Context, vpn_agent):
if self.vpn_scheduler:
return self.vpn_scheduler.auto_schedule_routers(
self, context, vpn_agent)
def schedule_router(self, context, router, candidates=None):
def schedule_router(self, context: ncontext.Context,
router, candidates: ty.Optional[ty.List] = None) -> \
ty.Optional[nagent.Agent]:
"""Schedule VPN services of a router to a VPN agent.
Returns the chosen agent; None if another server scheduled the
@ -381,9 +414,11 @@ class VPNAgentSchedulerDbMixin(
if self.vpn_scheduler:
return self.vpn_scheduler.schedule(
self, context, router, candidates=candidates)
return None
@db_api.CONTEXT_READER
def get_vpn_agent_with_min_routers(self, context, agent_ids):
def get_vpn_agent_with_min_routers(self, context: ncontext.Context,
agent_ids: ty.Optional[ty.List[str]]):
"""Return VPN agent with the least number of routers."""
if not agent_ids:
return None
@ -397,15 +432,18 @@ class VPNAgentSchedulerDbMixin(
unused_agent_ids = set(agent_ids) - set(used_agent_ids)
if unused_agent_ids:
return unused_agent_ids.pop()
else:
return used_agent_ids[0]
def get_hosts_to_notify(self, context, router_id):
def get_hosts_to_notify(self, context: ncontext.Context, router_id):
"""Returns all hosts to send notification about router update"""
agents = self.get_vpn_agents_hosting_routers(context, [router_id],
active=True)
return [a['host'] for a in agents]
def vpn_router_agent_binding_changed(self, context: ncontext.Context,
router_id: str, host: str):
pass
class AZVPNAgentSchedulerDbMixin(VPNAgentSchedulerDbMixin,
router_az.RouterAvailabilityZonePluginBase):

View File

@ -13,12 +13,14 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import typing as ty
from neutron.db import models_v2
from neutron_lib.callbacks import events
from neutron_lib.callbacks import registry
from neutron_lib.callbacks import resources
from neutron_lib import constants as lib_constants
from neutron_lib import context
from neutron_lib.db import api as db_api
from neutron_lib.db import model_query
from neutron_lib.db import utils as db_utils
@ -58,12 +60,13 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
"""
return vpn_validator.VpnReferenceValidator()
def update_status(self, context, model, v_id, status):
def update_status(self, context: context.Context, model, v_id: str,
status: str):
with db_api.CONTEXT_WRITER.using(context):
v_db = self._get_resource(context, model, v_id)
v_db.update({'status': status})
def _get_resource(self, context, model, v_id):
def _get_resource(self, context: context.Context, model, v_id):
try:
r = model_query.get_by_id(context, model, v_id)
except exc.NoResultFound:
@ -91,7 +94,9 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
if utils.in_pending_status(status):
raise vpn_exception.VPNStateInvalidToUpdate(id=_id, state=status)
def _make_ipsec_site_connection_dict(self, ipsec_site_conn, fields=None):
def _make_ipsec_site_connection_dict(self,
ipsec_site_conn: ty.Dict[str, ty.Any],
fields=None):
res = {'id': ipsec_site_conn['id'],
'tenant_id': ipsec_site_conn['tenant_id'],
@ -123,14 +128,16 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
return db_utils.resource_fields(res, fields)
def get_endpoint_info(self, context, ipsec_sitecon):
def get_endpoint_info(self, context: context.Context, ipsec_sitecon):
"""Obtain all endpoint info, and store in connection for validation."""
ipsec_sitecon['local_epg_subnets'] = self.get_endpoint_group(
context, ipsec_sitecon['local_ep_group_id'])
ipsec_sitecon['peer_epg_cidrs'] = self.get_endpoint_group(
context, ipsec_sitecon['peer_ep_group_id'])
def validate_connection_info(self, context, validator, ipsec_sitecon,
def validate_connection_info(
self, context: context.Context,
validator: vpn_validator.VpnReferenceValidator, ipsec_sitecon,
vpnservice):
"""Collect info and validate connection.
@ -150,7 +157,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
validator.validate_ipsec_site_connection(context, ipsec_sitecon,
ip_version, vpnservice)
def create_ipsec_site_connection(self, context, ipsec_site_connection):
def create_ipsec_site_connection(self, context: context.Context,
ipsec_site_connection: ty.Dict[str, ty.Dict]):
ipsec_sitecon = ipsec_site_connection['ipsec_site_connection']
validator = self._get_validator()
validator.assign_sensible_ipsec_sitecon_defaults(ipsec_sitecon)
@ -203,7 +211,7 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
return self._make_ipsec_site_connection_dict(ipsec_site_conn_db)
def update_ipsec_site_connection(
self, context,
self, context: context.Context,
ipsec_site_conn_id, ipsec_site_connection):
ipsec_sitecon = ipsec_site_connection['ipsec_site_connection']
changed_peer_cidrs = False
@ -252,36 +260,41 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
result['peer_cidrs'] = new_peer_cidrs
return result
def delete_ipsec_site_connection(self, context, ipsec_site_conn_id):
def delete_ipsec_site_connection(self, context: context.Context,
ipsec_site_conn_id: str):
with db_api.CONTEXT_WRITER.using(context):
ipsec_site_conn_db = self._get_resource(
context, vpn_models.IPsecSiteConnection, ipsec_site_conn_id)
context.session.delete(ipsec_site_conn_db)
def _get_ipsec_site_connection(
self, context, ipsec_site_conn_id):
def _get_ipsec_site_connection(self, context: context.Context,
ipsec_site_conn_id: str) -> vpn_models.IPsecSiteConnection:
return self._get_resource(
context, vpn_models.IPsecSiteConnection, ipsec_site_conn_id)
def get_ipsec_site_connection(self, context,
ipsec_site_conn_id, fields=None):
def get_ipsec_site_connection(self, context: context.Context,
ipsec_site_conn_id: str, fields=None):
ipsec_site_conn_db = self._get_ipsec_site_connection(
context, ipsec_site_conn_id)
return self._make_ipsec_site_connection_dict(
ipsec_site_conn_db, fields)
def get_ipsec_site_connections(self, context, filters=None, fields=None):
def get_ipsec_site_connections(self, context: context.Context,
filters: ty.Optional[ty.Dict] = None,
fields=None) -> ty.List[vpn_models.IPsecSiteConnection]:
return model_query.get_collection(
context, vpn_models.IPsecSiteConnection,
self._make_ipsec_site_connection_dict,
filters=filters, fields=fields)
def update_ipsec_site_conn_status(self, context, conn_id, new_status):
def update_ipsec_site_conn_status(self, context: context.Context,
conn_id: str, new_status: str):
with db_api.CONTEXT_WRITER.using(context):
self._update_connection_status(context, conn_id, new_status, True)
def _update_connection_status(self, context, conn_id, new_status,
updated_pending):
def _update_connection_status(self, context: context.Context,
conn_id: str, new_status: str,
updated_pending: bool):
"""Update the connection status, if changed.
If the connection is not in a pending state, unconditionally update
@ -295,7 +308,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
if not utils.in_pending_status(conn_db.status) or updated_pending:
conn_db.status = new_status
def _make_ikepolicy_dict(self, ikepolicy, fields=None):
def _make_ikepolicy_dict(self, ikepolicy: ty.Dict[str, ty.Any],
fields=None) -> ty.Dict[str, ty.Any]:
res = {'id': ikepolicy['id'],
'tenant_id': ikepolicy['tenant_id'],
'name': ikepolicy['name'],
@ -313,10 +327,11 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
return db_utils.resource_fields(res, fields)
def create_ikepolicy(self, context, ikepolicy):
def create_ikepolicy(self, context: context.Context,
ikepolicy: ty.Dict[str, ty.Dict[str, ty.Any]]):
ike = ikepolicy['ikepolicy']
validator = self._get_validator()
lifetime_info = ike['lifetime']
lifetime_info: ty.Dict[str, ty.Any] = ike['lifetime']
lifetime_units = lifetime_info.get('units', 'seconds')
lifetime_value = lifetime_info.get('value', 3600)
@ -339,7 +354,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
context.session.add(ike_db)
return self._make_ikepolicy_dict(ike_db)
def update_ikepolicy(self, context, ikepolicy_id, ikepolicy):
def update_ikepolicy(self, context: context.Context, ikepolicy_id: str,
ikepolicy: ty.Dict[str, ty.Dict]):
ike = ikepolicy['ikepolicy']
validator = self._get_validator()
with db_api.CONTEXT_WRITER.using(context):
@ -359,7 +375,7 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
ike_db.update(ike)
return self._make_ikepolicy_dict(ike_db)
def delete_ikepolicy(self, context, ikepolicy_id):
def delete_ikepolicy(self, context: context.Context, ikepolicy_id: str):
with db_api.CONTEXT_WRITER.using(context):
if context.session.query(vpn_models.IPsecSiteConnection).filter_by(
ikepolicy_id=ikepolicy_id).first():
@ -369,19 +385,21 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
context.session.delete(ike_db)
@db_api.CONTEXT_READER
def get_ikepolicy(self, context, ikepolicy_id, fields=None):
def get_ikepolicy(self, context: context.Context, ikepolicy_id: str,
fields=None):
ike_db = self._get_resource(
context, vpn_models.IKEPolicy, ikepolicy_id)
return self._make_ikepolicy_dict(ike_db, fields)
@db_api.CONTEXT_READER
def get_ikepolicies(self, context, filters=None, fields=None):
def get_ikepolicies(self, context: context.Context,
filters: ty.Optional[ty.Dict] = None, fields=None):
return model_query.get_collection(context, vpn_models.IKEPolicy,
self._make_ikepolicy_dict,
filters=filters, fields=fields)
def _make_ipsecpolicy_dict(self, ipsecpolicy, fields=None):
def _make_ipsecpolicy_dict(self, ipsecpolicy: ty.Dict[str, ty.Any],
fields=None) -> ty.Dict[str, ty.Any]:
res = {'id': ipsecpolicy['id'],
'tenant_id': ipsecpolicy['tenant_id'],
'name': ipsecpolicy['name'],
@ -399,10 +417,11 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
return db_utils.resource_fields(res, fields)
def create_ipsecpolicy(self, context, ipsecpolicy):
def create_ipsecpolicy(self, context: context.Context,
ipsecpolicy: ty.Dict[str, ty.Dict[str, ty.Any]]):
ipsecp = ipsecpolicy['ipsecpolicy']
validator = self._get_validator()
lifetime_info = ipsecp['lifetime']
lifetime_info: ty.Dict[str, ty.Any] = ipsecp['lifetime']
lifetime_units = lifetime_info.get('units', 'seconds')
lifetime_value = lifetime_info.get('value', 3600)
@ -423,7 +442,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
context.session.add(ipsecp_db)
return self._make_ipsecpolicy_dict(ipsecp_db)
def update_ipsecpolicy(self, context, ipsecpolicy_id, ipsecpolicy):
def update_ipsecpolicy(self, context: context.Context, ipsecpolicy_id: str,
ipsecpolicy: ty.Dict[str, ty.Dict[str, ty.Any]]):
ipsecp = ipsecpolicy['ipsecpolicy']
validator = self._get_validator()
with db_api.CONTEXT_WRITER.using(context):
@ -444,7 +464,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
ipsecp_db.update(ipsecp)
return self._make_ipsecpolicy_dict(ipsecp_db)
def delete_ipsecpolicy(self, context, ipsecpolicy_id):
def delete_ipsecpolicy(self, context: context.Context,
ipsecpolicy_id: str):
with db_api.CONTEXT_WRITER.using(context):
if context.session.query(vpn_models.IPsecSiteConnection).filter_by(
ipsecpolicy_id=ipsecpolicy_id).first():
@ -455,18 +476,21 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
context.session.delete(ipsec_db)
@db_api.CONTEXT_READER
def get_ipsecpolicy(self, context, ipsecpolicy_id, fields=None):
def get_ipsecpolicy(self, context: context.Context,
ipsecpolicy_id: str, fields=None):
ipsec_db = self._get_resource(
context, vpn_models.IPsecPolicy, ipsecpolicy_id)
return self._make_ipsecpolicy_dict(ipsec_db, fields)
@db_api.CONTEXT_READER
def get_ipsecpolicies(self, context, filters=None, fields=None):
def get_ipsecpolicies(self, context: context.Context,
filters: ty.Optional[ty.Dict] = None, fields=None):
return model_query.get_collection(context, vpn_models.IPsecPolicy,
self._make_ipsecpolicy_dict,
filters=filters, fields=fields)
def _make_vpnservice_dict(self, vpnservice, fields=None):
def _make_vpnservice_dict(self, vpnservice: ty.Dict[str, ty.Any],
fields=None) -> ty.Dict[str, ty.Any]:
res = {'id': vpnservice['id'],
'name': vpnservice['name'],
'description': vpnservice['description'],
@ -480,7 +504,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
'status': vpnservice['status']}
return db_utils.resource_fields(res, fields)
def create_vpnservice(self, context, vpnservice):
def create_vpnservice(self, context: context.Context,
vpnservice: ty.Dict[str, ty.Dict[str, ty.Any]]):
vpns = vpnservice['vpnservice']
flavor_id = vpns.get('flavor_id', None)
validator = self._get_validator()
@ -499,8 +524,10 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
context.session.add(vpnservice_db)
return self._make_vpnservice_dict(vpnservice_db)
def set_external_tunnel_ips(self, context, vpnservice_id, v4_ip=None,
v6_ip=None):
def set_external_tunnel_ips(self, context: context.Context,
vpnservice_id: str,
v4_ip: ty.Optional[str] = None,
v6_ip: ty.Optional[str] = None):
"""Update the external tunnel IP(s) for service."""
vpns = {'external_v4_ip': v4_ip, 'external_v6_ip': v6_ip}
with db_api.CONTEXT_WRITER.using(context):
@ -509,20 +536,22 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
vpns_db.update(vpns)
return self._make_vpnservice_dict(vpns_db)
def set_vpnservice_status(self, context, vpnservice_id, status,
updated_pending_status=False):
def set_vpnservice_status(self, context: context.Context,
vpnservice_id: str, status: str,
updated_pending_status: bool = False):
vpns = {'status': status}
with db_api.CONTEXT_WRITER.using(context):
vpns_db = self._get_resource(context, vpn_models.VPNService,
vpnservice_id)
if (utils.in_pending_status(vpns_db.status) and
not updated_pending_status):
raise vpnaas.VPNStateInvalidToUpdate(
raise vpnaas.VPNStateInvalidToUpdate( # type: ignore
id=vpnservice_id, state=vpns_db.status)
vpns_db.update(vpns)
return self._make_vpnservice_dict(vpns_db)
def update_vpnservice(self, context, vpnservice_id, vpnservice):
def update_vpnservice(self, context: context.Context, vpnservice_id: str,
vpnservice: ty.Dict[str, ty.Optional[ty.Dict]]):
vpns = vpnservice['vpnservice']
with db_api.CONTEXT_WRITER.using(context):
vpns_db = self._get_resource(context, vpn_models.VPNService,
@ -532,7 +561,7 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
vpns_db.update(vpns)
return self._make_vpnservice_dict(vpns_db)
def delete_vpnservice(self, context, vpnservice_id):
def delete_vpnservice(self, context: context.Context, vpnservice_id: str):
with db_api.CONTEXT_WRITER.using(context):
if context.session.query(vpn_models.IPsecSiteConnection).filter_by(
vpnservice_id=vpnservice_id
@ -544,23 +573,26 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
context.session.delete(vpns_db)
@db_api.CONTEXT_READER
def _get_vpnservice(self, context, vpnservice_id):
def _get_vpnservice(self, context: context.Context, vpnservice_id: str):
return self._get_resource(context, vpn_models.VPNService,
vpnservice_id)
@db_api.CONTEXT_READER
def get_vpnservice(self, context, vpnservice_id, fields=None):
def get_vpnservice(self, context: context.Context, vpnservice_id: str,
fields=None):
vpns_db = self._get_resource(context, vpn_models.VPNService,
vpnservice_id)
return self._make_vpnservice_dict(vpns_db, fields)
@db_api.CONTEXT_READER
def get_vpnservices(self, context, filters=None, fields=None):
def get_vpnservices(self, context: context.Context,
filters: ty.Optional[ty.Dict] = None,
fields=None) -> ty.List:
return model_query.get_collection(context, vpn_models.VPNService,
self._make_vpnservice_dict,
filters=filters, fields=fields)
def check_router_in_use(self, context, router_id):
def check_router_in_use(self, context: context.Context, router_id: str):
vpnservices = self.get_vpnservices(
context, filters={'router_id': [router_id]})
if vpnservices:
@ -572,7 +604,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
"(%(services)s)" % {'plural': plural,
'services': services})
def check_subnet_in_use(self, context, subnet_id, router_id):
def check_subnet_in_use(self, context: context.Context, subnet_id: str,
router_id: str):
with db_api.CONTEXT_READER.using(context):
vpnservices = context.session.query(
vpn_models.VPNService).filter_by(
@ -605,7 +638,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
subnet_id=subnet_id,
ipsec_site_connection_id=connection['id'])
def check_subnet_in_use_by_endpoint_group(self, context, subnet_id):
def check_subnet_in_use_by_endpoint_group(self, context: context.Context,
subnet_id: str):
with db_api.CONTEXT_READER.using(context):
query = context.session.query(vpn_models.VPNEndpointGroup)
query = query.filter(vpn_models.VPNEndpointGroup.endpoint_type ==
@ -620,7 +654,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
raise vpn_exception.SubnetInUseByEndpointGroup(
subnet_id=subnet_id, group_id=group['id'])
def _make_endpoint_group_dict(self, endpoint_group, fields=None):
def _make_endpoint_group_dict(self, endpoint_group: ty.Dict[str, ty.Any],
fields: ty.Optional[ty.Dict] = None) -> ty.Dict:
res = {'id': endpoint_group['id'],
'tenant_id': endpoint_group['tenant_id'],
'name': endpoint_group['name'],
@ -630,7 +665,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
for ep in endpoint_group['endpoints']]}
return db_utils.resource_fields(res, fields)
def create_endpoint_group(self, context, endpoint_group):
def create_endpoint_group(self, context: context.Context,
endpoint_group: ty.Dict[str, ty.Dict[str, ty.Any]]):
group = endpoint_group['endpoint_group']
validator = self._get_validator()
with db_api.CONTEXT_WRITER.using(context):
@ -650,8 +686,9 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
context.session.add(endpoint_db)
return self._make_endpoint_group_dict(endpoint_group_db)
def update_endpoint_group(self, context, endpoint_group_id,
endpoint_group):
def update_endpoint_group(self, context: context.Context,
endpoint_group_id: str,
endpoint_group: ty.Dict[str, ty.Dict[str, ty.Any]]):
group_changes = endpoint_group['endpoint_group']
# Note: Endpoints cannot be changed, so will not do validation
with db_api.CONTEXT_WRITER.using(context):
@ -661,7 +698,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
endpoint_group_db.update(group_changes)
return self._make_endpoint_group_dict(endpoint_group_db)
def delete_endpoint_group(self, context, endpoint_group_id):
def delete_endpoint_group(self, context: context.Context,
endpoint_group_id: str):
with db_api.CONTEXT_WRITER.using(context):
self.check_endpoint_group_not_in_use(context, endpoint_group_id)
endpoint_group_db = self._get_resource(
@ -669,18 +707,21 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
context.session.delete(endpoint_group_db)
@db_api.CONTEXT_READER
def get_endpoint_group(self, context, endpoint_group_id, fields=None):
def get_endpoint_group(self, context: context.Context,
endpoint_group_id: str, fields=None):
endpoint_group_db = self._get_resource(
context, vpn_models.VPNEndpointGroup, endpoint_group_id)
return self._make_endpoint_group_dict(endpoint_group_db, fields)
@db_api.CONTEXT_READER
def get_endpoint_groups(self, context, filters=None, fields=None):
def get_endpoint_groups(self, context: context.Context,
filters: ty.Optional[ty.Dict] = None, fields=None):
return model_query.get_collection(context, vpn_models.VPNEndpointGroup,
self._make_endpoint_group_dict,
filters=filters, fields=fields)
def check_endpoint_group_not_in_use(self, context, group_id):
def check_endpoint_group_not_in_use(self, context: context.Context,
group_id: str):
query = context.session.query(vpn_models.IPsecSiteConnection)
query = query.filter(
sa.or_(
@ -690,13 +731,15 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
if query.first():
raise vpn_exception.EndpointGroupInUse(group_id=group_id)
def get_vpnservice_router_id(self, context, vpnservice_id):
def get_vpnservice_router_id(self, context: context.Context,
vpnservice_id: str):
with db_api.CONTEXT_READER.using(context):
vpnservice = self._get_vpnservice(context, vpnservice_id)
return vpnservice['router_id']
@db_api.CONTEXT_READER
def get_peer_cidrs_for_router(self, context, router_id):
def get_peer_cidrs_for_router(self, context: context.Context,
router_id: str):
filters = {'router_id': [router_id]}
vpnservices = model_query.get_collection_query(
context, vpn_models.VPNService, filters=filters).all()
@ -712,8 +755,8 @@ class VPNPluginDb(vpnaas.VPNPluginBase,
return cidrs
class VPNPluginRpcDbMixin(object):
def _build_local_subnet_cidr_map(self, context):
class VPNPluginRpcDbMixin(VPNPluginDb):
def _build_local_subnet_cidr_map(self, context: context.Context):
"""Build a dict of all local endpoint subnets, with list of CIDRs."""
query = context.session.query(models_v2.Subnet.id,
models_v2.Subnet.cidr)
@ -728,7 +771,8 @@ class VPNPluginRpcDbMixin(object):
vpn_models.VPNEndpointGroup.id)
return {sn.id: sn.cidr for sn in query.all()}
def update_status_by_agent(self, context, service_status_info_list):
def update_status_by_agent(self, context: context.Context,
service_status_info_list):
"""Updating vpnservice and vpnconnection status.
:param context: context variable
@ -768,7 +812,8 @@ class VPNPluginRpcDbMixin(object):
def vpn_router_gateway_callback(resource, event, trigger, payload=None):
# the event payload objects
vpn_plugin = directory.get_plugin(p_constants.VPN)
vpn_plugin: ty.Optional[VPNPluginDb] = \
directory.get_plugin(p_constants.VPN)
if vpn_plugin:
context = payload.context
router_id = payload.resource_id
@ -782,7 +827,8 @@ def vpn_router_gateway_callback(resource, event, trigger, payload=None):
def migration_callback(resource, event, trigger, payload):
context = payload.context
router = payload.latest_state
vpn_plugin = directory.get_plugin(p_constants.VPN)
vpn_plugin: ty.Optional[VPNPluginDb] = \
directory.get_plugin(p_constants.VPN)
if vpn_plugin:
vpn_plugin.check_router_in_use(context, router['id'])
return True
@ -792,7 +838,8 @@ def subnet_callback(resource, event, trigger, payload=None):
"""Respond to subnet based notifications - see if subnet in use."""
context = payload.context
subnet_id = payload.resource_id
vpn_plugin = directory.get_plugin(p_constants.VPN)
vpn_plugin: ty.Optional[VPNPluginDb] = \
directory.get_plugin(p_constants.VPN)
if vpn_plugin:
vpn_plugin.check_subnet_in_use_by_endpoint_group(context, subnet_id)

View File

@ -13,6 +13,7 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import typing as ty
from neutron.db.models import l3 as l3_models
from neutron.db import models_v2
@ -20,6 +21,7 @@ from neutron_lib.callbacks import events
from neutron_lib.callbacks import registry
from neutron_lib.callbacks import resources
from neutron_lib import constants as lib_constants
from neutron_lib import context
from neutron_lib.db import api as db_api
from neutron_lib.db import model_base
from neutron_lib.db import model_query
@ -33,8 +35,15 @@ from sqlalchemy import orm
from sqlalchemy.orm import exc
from neutron_vpnaas._i18n import _
from neutron_vpnaas.db.vpn import vpn_db
from neutron_vpnaas.services.vpn.common import constants as v_constants
#pylint: disable=ungrouped-imports
# Additional import for typechecking. Importing these without typechecking
# would resolve in a cyclic dependency
if ty.TYPE_CHECKING:
from neutron.db import db_base_plugin_v2 as db_plugin
#pylint: enable=ungrouped-imports
LOG = logging.getLogger(__name__)
@ -80,11 +89,11 @@ class VPNExtGW(model_base.BASEV2, model_base.HasId, model_base.HasProject):
@registry.has_registry_receivers
class VPNExtGWPlugin_db(object):
class VPNExtGWPlugin_db(vpn_db.VPNPluginDb):
"""DB class to support vpn external ports configuration."""
@property
def _core_plugin(self):
def _core_plugin(self) -> 'db_plugin.NeutronDbPluginV2':
return directory.get_plugin()
@property
@ -95,13 +104,13 @@ class VPNExtGWPlugin_db(object):
@registry.receives(resources.PORT, [events.BEFORE_DELETE])
def _prevent_vpn_port_delete_callback(resource, event,
trigger, payload=None):
vpn_plugin = directory.get_plugin(plugin_const.VPN)
vpn_plugin: VPNExtGWPlugin_db = directory.get_plugin(plugin_const.VPN)
if vpn_plugin:
vpn_plugin.prevent_vpn_port_deletion(payload.context,
payload.resource_id)
@db_api.CONTEXT_READER
def _id_used(self, context, id_column, resource_id):
def _id_used(self, context: context.Context, id_column, resource_id):
return context.session.query(VPNExtGW).filter(
sa.and_(
id_column == resource_id,
@ -109,7 +118,8 @@ class VPNExtGWPlugin_db(object):
)
).count() > 0
def prevent_vpn_port_deletion(self, context, port_id):
def prevent_vpn_port_deletion(self, context: context.Context,
port_id: str):
"""Checks to make sure a port is allowed to be deleted.
Raises an exception if this is not the case. This should be called by
@ -124,7 +134,7 @@ class VPNExtGWPlugin_db(object):
# non-existent ports don't need to be protected from deletion
return
port_id_column = {
port_id_column: ty.Optional[str] = {
v_constants.DEVICE_OWNER_VPN_ROUTER_GW: VPNExtGW.gw_port_id,
v_constants.DEVICE_OWNER_TRANSIT_NETWORK:
VPNExtGW.transit_port_id,
@ -142,12 +152,13 @@ class VPNExtGWPlugin_db(object):
@registry.receives(resources.SUBNET, [events.BEFORE_DELETE])
def _prevent_vpn_subnet_delete_callback(resource, event,
trigger, payload=None):
vpn_plugin = directory.get_plugin(plugin_const.VPN)
vpn_plugin: VPNExtGWPlugin_db = directory.get_plugin(plugin_const.VPN)
if vpn_plugin:
vpn_plugin.prevent_vpn_subnet_deletion(payload.context,
payload.resource_id)
def prevent_vpn_subnet_deletion(self, context, subnet_id):
def prevent_vpn_subnet_deletion(self, context: context.Context,
subnet_id: str):
if self._id_used(context, VPNExtGW.transit_subnet_id, subnet_id):
reason = _('Subnet is used by VPN service')
raise n_exc.SubnetInUse(subnet_id=subnet_id, reason=reason)
@ -156,16 +167,18 @@ class VPNExtGWPlugin_db(object):
@registry.receives(resources.NETWORK, [events.BEFORE_DELETE])
def _prevent_vpn_network_delete_callback(resource, event,
trigger, payload=None):
vpn_plugin = directory.get_plugin(plugin_const.VPN)
vpn_plugin: VPNExtGWPlugin_db = directory.get_plugin(plugin_const.VPN)
if vpn_plugin:
vpn_plugin.prevent_vpn_network_deletion(payload.context,
payload.resource_id)
def prevent_vpn_network_deletion(self, context, network_id):
def prevent_vpn_network_deletion(self, context: context.Context,
network_id: str):
if self._id_used(context, VPNExtGW.transit_network_id, network_id):
raise VPNNetworkInUse(network_id=network_id)
def _make_vpn_ext_gw_dict(self, gateway_db):
def _make_vpn_ext_gw_dict(self, gateway_db: ty.Optional[VPNExtGW]) -> \
ty.Optional[ty.Dict[str, ty.Any]]:
if not gateway_db:
return None
gateway = {
@ -187,7 +200,8 @@ class VPNExtGWPlugin_db(object):
gateway[key] = value
return gateway
def _get_vpn_gw_by_router_id(self, context, router_id):
def _get_vpn_gw_by_router_id(self, context: context.Context,
router_id: str) -> ty.Optional[VPNExtGW]:
try:
gateway_db = context.session.query(VPNExtGW).filter(
VPNExtGW.router_id == router_id).one()
@ -196,17 +210,20 @@ class VPNExtGWPlugin_db(object):
return gateway_db
@db_api.CONTEXT_READER
def get_vpn_gw_by_router_id(self, context, router_id):
def get_vpn_gw_by_router_id(
self, context: context.Context, router_id: str):
return self._get_vpn_gw_by_router_id(context, router_id)
@db_api.CONTEXT_READER
def get_vpn_gw_dict_by_router_id(self, context, router_id, refresh=False):
def get_vpn_gw_dict_by_router_id(self, context: context.Context,
router_id: str, refresh: bool = False):
gateway_db = self._get_vpn_gw_by_router_id(context, router_id)
if gateway_db and refresh:
context.session.refresh(gateway_db)
return self._make_vpn_ext_gw_dict(gateway_db)
def create_gateway(self, context, gateway):
def create_gateway(self, context: context.Context,
gateway: ty.Dict[str, ty.Dict[str, ty.Any]]):
info = gateway['gateway']
with db_api.CONTEXT_WRITER.using(context):
@ -223,14 +240,15 @@ class VPNExtGWPlugin_db(object):
return self._make_vpn_ext_gw_dict(gateway_db)
def update_gateway(self, context, gateway_id, gateway):
def update_gateway(self, context: context.Context, gateway_id: str,
gateway: ty.Dict[str, ty.Dict[str, ty.Any]]):
info = gateway['gateway']
with db_api.CONTEXT_WRITER.using(context):
gateway_db = model_query.get_by_id(context, VPNExtGW, gateway_id)
gateway_db.update(info)
return self._make_vpn_ext_gw_dict(gateway_db)
def delete_gateway(self, context, gateway_id):
def delete_gateway(self, context: context.Context, gateway_id: str):
with db_api.CONTEXT_WRITER.using(context):
query = context.session.query(VPNExtGW)
return query.filter(VPNExtGW.id == gateway_id).delete()

View File

@ -12,12 +12,18 @@
# License for the specific language governing permissions and limitations
# under the License.
import typing as ty
import socket
import netaddr
from neutron.db import l3_db
from neutron.db import models_v2
from neutron import neutron_plugin_base_v2
from neutron.services.l3_router import l3_router_plugin
from neutron_lib.api import validators
from neutron_lib import context
from neutron_lib import exceptions as nexception
from neutron_lib.exceptions import vpn as vpn_exception
from neutron_lib.plugins import constants as plugin_const
@ -27,7 +33,7 @@ from neutron_vpnaas._i18n import _
from neutron_vpnaas.services.vpn.common import constants
class VpnReferenceValidator(object):
class VpnReferenceValidator:
"""
Baseline validation routines for VPN resources.
@ -38,28 +44,30 @@ class VpnReferenceValidator(object):
IP_MIN_MTU = {4: 68, 6: 1280}
@property
def l3_plugin(self):
def l3_plugin(self) -> l3_router_plugin.L3RouterPlugin:
try:
return self._l3_plugin
except AttributeError:
self._l3_plugin = directory.get_plugin(plugin_const.L3)
self._l3_plugin: l3_router_plugin.L3RouterPlugin = \
directory.get_plugin(plugin_const.L3)
return self._l3_plugin
@property
def core_plugin(self):
def core_plugin(self) -> neutron_plugin_base_v2.NeutronPluginBaseV2:
try:
return self._core_plugin
except AttributeError:
self._core_plugin = directory.get_plugin()
self._core_plugin: neutron_plugin_base_v2.NeutronPluginBaseV2 = \
directory.get_plugin()
return self._core_plugin
def _check_dpd(self, ipsec_sitecon):
def _check_dpd(self, ipsec_sitecon: ty.Dict[str, ty.Union[int, ty.Any]]):
"""Ensure that DPD timeout is greater than DPD interval."""
if ipsec_sitecon['dpd_timeout'] <= ipsec_sitecon['dpd_interval']:
raise vpn_exception.IPsecSiteConnectionDpdIntervalValueError(
attr='dpd_timeout')
def _check_mtu(self, context, mtu, ip_version):
def _check_mtu(self, context: context.Context, mtu, ip_version):
if mtu < VpnReferenceValidator.IP_MIN_MTU[ip_version]:
raise vpn_exception.IPsecSiteConnectionMtuError(
mtu=mtu, version=ip_version)
@ -95,7 +103,7 @@ class VpnReferenceValidator(object):
ip_version = netaddr.IPAddress(ipsec_sitecon['peer_address']).version
self._validate_peer_address(ip_version, router)
def _get_local_subnets(self, context, endpoint_group):
def _get_local_subnets(self, context: context.Context, endpoint_group):
if endpoint_group['type'] != constants.SUBNET_ENDPOINT:
raise vpn_exception.WrongEndpointGroupType(
group_type=endpoint_group['type'], which=endpoint_group['id'],
@ -118,7 +126,7 @@ class VpnReferenceValidator(object):
"""
if len(local_subnets) == 1:
return local_subnets[0]['ip_version']
ip_versions = set([subnet['ip_version'] for subnet in local_subnets])
ip_versions = {subnet['ip_version'] for subnet in local_subnets}
if len(ip_versions) > 1:
raise vpn_exception.MixedIPVersionsForIPSecEndpoints(
group=group_id)
@ -131,7 +139,7 @@ class VpnReferenceValidator(object):
"""
if len(peer_cidrs) == 1:
return netaddr.IPNetwork(peer_cidrs[0]).version
ip_versions = set([netaddr.IPNetwork(pc).version for pc in peer_cidrs])
ip_versions = {netaddr.IPNetwork(pc).version for pc in peer_cidrs}
if len(ip_versions) > 1:
raise vpn_exception.MixedIPVersionsForIPSecEndpoints(
group=group_id)
@ -149,12 +157,13 @@ class VpnReferenceValidator(object):
"""Ensure all CIDRs have the same IP version."""
if len(peer_cidrs) == 1:
return netaddr.IPNetwork(peer_cidrs[0]).version
ip_versions = set([netaddr.IPNetwork(pc).version for pc in peer_cidrs])
ip_versions = {netaddr.IPNetwork(pc).version for pc in peer_cidrs}
if len(ip_versions) > 1:
raise vpn_exception.MixedIPVersionsForPeerCidrs()
return ip_versions.pop()
def _check_local_subnets_on_router(self, context, router, local_subnets):
def _check_local_subnets_on_router(self, context: context.Context,
router, local_subnets):
for subnet in local_subnets:
self._check_subnet_id(context, router, subnet['id'])
@ -163,7 +172,9 @@ class VpnReferenceValidator(object):
if local_ip_version != peer_ip_version:
raise vpn_exception.MixedIPVersionsForIPSecConnection()
def validate_ipsec_conn_optional_args(self, ipsec_sitecon, subnet):
def validate_ipsec_conn_optional_args(self,
ipsec_sitecon: ty.Dict[str, ty.Any],
subnet):
"""Ensure that proper combinations of optional args are used.
When VPN service has a subnet, then we must have peer_cidrs, and
@ -176,10 +187,10 @@ class VpnReferenceValidator(object):
local_epg_id = ipsec_sitecon.get('local_ep_group_id')
peer_epg_id = ipsec_sitecon.get('peer_ep_group_id')
peer_cidrs = ipsec_sitecon.get('peer_cidrs')
epgs: ty.List[str] = []
if subnet:
if not peer_cidrs:
raise vpn_exception.MissingPeerCidrs()
epgs = []
if local_epg_id:
epgs.append('local')
if peer_epg_id:
@ -192,7 +203,6 @@ class VpnReferenceValidator(object):
else:
if peer_cidrs:
raise vpn_exception.PeerCidrsInvalid()
epgs = []
if not local_epg_id:
epgs.append('local')
if not peer_epg_id:
@ -203,8 +213,9 @@ class VpnReferenceValidator(object):
raise vpn_exception.MissingRequiredEndpointGroup(
which=which, suffix=suffix)
def assign_sensible_ipsec_sitecon_defaults(self, ipsec_sitecon,
prev_conn=None):
def assign_sensible_ipsec_sitecon_defaults(self,
ipsec_sitecon: ty.Dict[str, ty.Any],
prev_conn: ty.Optional[ty.Dict] = None):
"""Provide defaults for optional items, if missing.
With endpoint groups capabilities, the peer_cidr (legacy mode)
@ -229,7 +240,7 @@ class VpnReferenceValidator(object):
prev_conn = {'dpd_action': 'hold',
'dpd_interval': 30,
'dpd_timeout': 120}
dpd = ipsec_sitecon.get('dpd', {})
dpd: ty.Dict[str, ty.Any] = ipsec_sitecon.get('dpd', {})
ipsec_sitecon['dpd_action'] = dpd.get('action',
prev_conn['dpd_action'])
ipsec_sitecon['dpd_interval'] = dpd.get('interval',
@ -237,8 +248,10 @@ class VpnReferenceValidator(object):
ipsec_sitecon['dpd_timeout'] = dpd.get('timeout',
prev_conn['dpd_timeout'])
def validate_ipsec_site_connection(self, context, ipsec_sitecon,
local_ip_version, vpnservice=None):
def validate_ipsec_site_connection(
self, context: context.Context, ipsec_sitecon: ty.Dict[str, ty.Any],
local_ip_version,
vpnservice: ty.Optional[ty.Dict[str, ty.Any]] = None):
"""Reference implementation of validation for IPSec connection.
This makes sure that IP versions are the same. For endpoint groups,
@ -254,7 +267,9 @@ class VpnReferenceValidator(object):
local_subnets = self._get_local_subnets(
context, ipsec_sitecon['local_epg_subnets'])
self._check_local_subnets_on_router(
context, vpnservice['router_id'], local_subnets)
context,
vpnservice['router_id'] if vpnservice else '',
local_subnets)
local_ip_version = self._check_local_endpoint_ip_versions(
ipsec_sitecon['local_ep_group_id'], local_subnets)
peer_cidrs = self._get_peer_cidrs(ipsec_sitecon['peer_epg_cidrs'])
@ -272,12 +287,13 @@ class VpnReferenceValidator(object):
if mtu:
self._check_mtu(context, mtu, local_ip_version)
def _check_router(self, context, router_id):
def _check_router(self, context: context.Context, router_id: str):
router = self.l3_plugin.get_router(context, router_id)
if not router.get(l3_db.EXTERNAL_GW_INFO):
raise vpn_exception.RouterIsNotExternal(router_id=router_id)
def _check_subnet_id(self, context, router_id, subnet_id):
def _check_subnet_id(self, context: context.Context,
router_id: str, subnet_id: str):
ports = self.core_plugin.get_ports(
context,
filters={
@ -288,13 +304,14 @@ class VpnReferenceValidator(object):
subnet_id=subnet_id,
router_id=router_id)
def validate_vpnservice(self, context, vpnservice):
def validate_vpnservice(self, context: context.Context,
vpnservice: ty.Dict[str, ty.Any]):
self._check_router(context, vpnservice['router_id'])
if vpnservice['subnet_id'] is not None:
self._check_subnet_id(context, vpnservice['router_id'],
vpnservice['subnet_id'])
def validate_ipsec_policy(self, context, ipsec_policy):
def validate_ipsec_policy(self, context: context.Context, ipsec_policy):
"""Reference implementation of validation for IPSec Policy.
Service driver can override and implement specific logic
@ -311,7 +328,8 @@ class VpnReferenceValidator(object):
group_type=constants.CIDR_ENDPOINT, endpoint=cidr,
why=_("Invalid CIDR"))
def _validate_subnets(self, context, subnet_ids):
def _validate_subnets(self, context: context.Context,
subnet_ids: ty.List[str]):
"""Ensure UUIDs OK and subnets exist."""
for subnet_id in subnet_ids:
msg = validators.validate_uuid(subnet_id)
@ -325,7 +343,8 @@ class VpnReferenceValidator(object):
raise vpn_exception.NonExistingSubnetInEndpointGroup(
subnet=subnet_id)
def validate_endpoint_group(self, context, endpoint_group):
def validate_endpoint_group(self, context: context.Context,
endpoint_group: ty.Dict[str, ty.Any]):
"""Reference validator for endpoint group.
Ensures that there is at least one endpoint, all the endpoints in the
@ -342,7 +361,7 @@ class VpnReferenceValidator(object):
elif group_type == constants.SUBNET_ENDPOINT:
self._validate_subnets(context, endpoints)
def validate_ike_policy(self, context, ike_policy):
def validate_ike_policy(self, context: context.Context, ike_policy):
"""Reference implementation of validation for IKE Policy.
Service driver can override and implement specific logic

View File

@ -13,6 +13,7 @@
# under the License.
import abc
import typing as ty
from neutron.api import extensions
from neutron.api.v2 import resource
@ -20,6 +21,7 @@ from neutron import policy
from neutron import wsgi
from neutron_lib.api import extensions as lib_extensions
from neutron_lib.api import faults as base
from neutron_lib import context
from neutron_lib import exceptions
from neutron_lib.plugins import constants as plugin_const
from neutron_lib.plugins import directory
@ -37,9 +39,36 @@ VPN_AGENT = 'vpn-agent'
VPN_AGENTS = VPN_AGENT + 's'
class VPNAgentSchedulerPluginBase(metaclass=abc.ABCMeta):
"""REST API to operate the VPN agent scheduler.
All methods must be in an admin context.
"""
@abc.abstractmethod
def add_router_to_vpn_agent(self, context: context.ContextBase,
id: str, router_id: str):
pass
@abc.abstractmethod
def remove_router_from_vpn_agent(self, context: context.ContextBase,
id: str, router_id: str):
pass
@abc.abstractmethod
def list_routers_on_vpn_agent(self, context: context.ContextBase, id: str):
pass
@abc.abstractmethod
def list_vpn_agents_hosting_router(self, context: context.ContextBase,
router_id: str):
pass
class VPNRouterSchedulerController(wsgi.Controller):
def get_plugin(self):
plugin = directory.get_plugin(plugin_const.VPN)
def get_plugin(self) -> VPNAgentSchedulerPluginBase:
plugin: VPNAgentSchedulerPluginBase = \
directory.get_plugin(plugin_const.VPN)
if not plugin:
LOG.error('No plugin for VPN registered to handle VPN '
'router scheduling')
@ -47,18 +76,18 @@ class VPNRouterSchedulerController(wsgi.Controller):
raise webob.exc.HTTPNotFound(msg)
return plugin
def index(self, request, **kwargs):
def index(self, request: wsgi.Request, **kwargs):
plugin = self.get_plugin()
policy.enforce(request.context,
"get_%s" % VPN_ROUTERS,
f"get_{VPN_ROUTERS}",
{})
return plugin.list_routers_on_vpn_agent(
request.context, kwargs['agent_id'])
def create(self, request, body, **kwargs):
def create(self, request: wsgi.Request, body, **kwargs):
plugin = self.get_plugin()
policy.enforce(request.context,
"create_%s" % VPN_ROUTER,
f"create_{VPN_ROUTER}",
{})
agent_id = kwargs['agent_id']
router_id = body['router_id']
@ -67,10 +96,10 @@ class VPNRouterSchedulerController(wsgi.Controller):
notify(request.context, 'vpn_agent.router.add', router_id, agent_id)
return result
def delete(self, request, id, **kwargs):
def delete(self, request: wsgi.Request, id, **kwargs):
plugin = self.get_plugin()
policy.enforce(request.context,
"delete_%s" % VPN_ROUTER,
f"delete_{VPN_ROUTER}",
{})
agent_id = kwargs['agent_id']
result = plugin.remove_router_from_vpn_agent(request.context, agent_id,
@ -80,18 +109,19 @@ class VPNRouterSchedulerController(wsgi.Controller):
class VPNAgentsHostingRouterController(wsgi.Controller):
def get_plugin(self):
plugin = directory.get_plugin(plugin_const.VPN)
def get_plugin(self) -> VPNAgentSchedulerPluginBase:
plugin: VPNAgentSchedulerPluginBase = \
directory.get_plugin(plugin_const.VPN)
if not plugin:
LOG.error('VPN plugin not registered to handle agent scheduling')
msg = 'The resource could not be found.'
raise webob.exc.HTTPNotFound(msg)
return plugin
def index(self, request, **kwargs):
def index(self, request: wsgi.Request, **kwargs):
plugin = self.get_plugin()
policy.enforce(request.context,
"get_%s" % VPN_AGENTS,
f"get_{VPN_AGENTS}",
{})
return plugin.list_vpn_agents_hosting_router(
request.context, kwargs['router_id'])
@ -102,35 +132,35 @@ class Vpn_agentschedulers(lib_extensions.ExtensionDescriptor):
"""
@classmethod
def get_name(cls):
def get_name(cls) -> str:
return "VPN Agent Scheduler"
@classmethod
def get_alias(cls):
def get_alias(cls) -> str:
return "vpn-agent-scheduler"
@classmethod
def get_description(cls):
def get_description(cls) -> str:
return "Schedule VPN services of routers among VPN agents"
@classmethod
def get_updated(cls):
def get_updated(cls) -> str:
return "2016-08-15T10:00:00-00:00"
@classmethod
def get_resources(cls):
def get_resources(cls) -> ty.List[extensions.ResourceExtension]:
"""Returns Ext Resources."""
exts = []
parent = dict(member_name="agent",
collection_name="agents")
exts: ty.List[extensions.ResourceExtension] = []
parent = {'member_name': "agent",
'collection_name': "agents"}
controller = resource.Resource(VPNRouterSchedulerController(),
base.FAULT_MAP)
exts.append(extensions.ResourceExtension(
VPN_ROUTERS, controller, parent))
parent = dict(member_name="router",
collection_name="routers")
parent = {'member_name': "router",
'collection_name': "routers"}
controller = resource.Resource(VPNAgentsHostingRouterController(),
base.FAULT_MAP)
@ -138,7 +168,7 @@ class Vpn_agentschedulers(lib_extensions.ExtensionDescriptor):
VPN_AGENTS, controller, parent))
return exts
def get_extended_resources(self, version):
def get_extended_resources(self, version) -> ty.Dict:
return {}
@ -161,30 +191,7 @@ class RouterReschedulingFailed(exceptions.Conflict):
"No eligible VPN agent found.")
class VPNAgentSchedulerPluginBase(object, metaclass=abc.ABCMeta):
"""REST API to operate the VPN agent scheduler.
All methods must be in an admin context.
"""
@abc.abstractmethod
def add_router_to_vpn_agent(self, context, id, router_id):
pass
@abc.abstractmethod
def remove_router_from_vpn_agent(self, context, id, router_id):
pass
@abc.abstractmethod
def list_routers_on_vpn_agent(self, context, id):
pass
@abc.abstractmethod
def list_vpn_agents_hosting_router(self, context, router_id):
pass
def notify(context, action, router_id, agent_id):
def notify(context: context.ContextBase, action, router_id, agent_id):
info = {'id': agent_id, 'router_id': router_id}
notifier = n_rpc.get_notifier('router')
notifier.info(context, action, {'agent': info})

View File

@ -14,10 +14,14 @@
import abc
import typing as ty
from neutron_lib.api.definitions import vpn_endpoint_groups
from neutron_lib.api import extensions
from neutron_lib import context
from neutron_lib.plugins import constants as nconstants
from neutron.api import extensions as nextensions
from neutron.api.v2 import resource_helper
@ -25,7 +29,7 @@ class Vpn_endpoint_groups(extensions.APIExtensionDescriptor):
api_definition = vpn_endpoint_groups
@classmethod
def get_resources(cls):
def get_resources(cls) -> ty.List[nextensions.ResourceExtension]:
plural_mappings = resource_helper.build_plural_mappings(
{}, vpn_endpoint_groups.RESOURCE_ATTRIBUTE_MAP)
return resource_helper.build_resource_info(
@ -36,25 +40,28 @@ class Vpn_endpoint_groups(extensions.APIExtensionDescriptor):
translate_name=True)
class VPNEndpointGroupsPluginBase(object, metaclass=abc.ABCMeta):
class VPNEndpointGroupsPluginBase(metaclass=abc.ABCMeta):
@abc.abstractmethod
def create_endpoint_group(self, context, endpoint_group):
def create_endpoint_group(self, context: context.Context, endpoint_group):
pass
@abc.abstractmethod
def update_endpoint_group(self, context, endpoint_group_id,
endpoint_group):
def update_endpoint_group(self, context: context.Context,
endpoint_group_id: str, endpoint_group):
pass
@abc.abstractmethod
def delete_endpoint_group(self, context, endpoint_group_id):
def delete_endpoint_group(self, context: context.Context,
endpoint_group_id: str):
pass
@abc.abstractmethod
def get_endpoint_group(self, context, endpoint_group_id, fields=None):
def get_endpoint_group(self, context: context.Context,
endpoint_group_id: str, fields=None):
pass
@abc.abstractmethod
def get_endpoint_groups(self, context, filters=None, fields=None):
def get_endpoint_groups(self, context: context.Context,
filters: ty.Optional[ty.Dict] = None, fields=None):
pass

View File

@ -14,13 +14,16 @@
# under the License.
import abc
import typing as ty
from neutron_lib.api.definitions import vpn
from neutron_lib.api import extensions
from neutron_lib import context
from neutron_lib import exceptions as nexception
from neutron_lib.plugins import constants as nconstants
from neutron_lib.services import base as service_base
from neutron.api import extensions as nextensions
from neutron.api.v2 import resource_helper
from neutron_vpnaas._i18n import _
@ -51,7 +54,7 @@ class Vpnaas(extensions.APIExtensionDescriptor):
api_definition = vpn
@classmethod
def get_resources(cls):
def get_resources(cls) -> ty.List[nextensions.ResourceExtension]:
special_mappings = {'ikepolicies': 'ikepolicy',
'ipsecpolicies': 'ipsecpolicy'}
plural_mappings = resource_helper.build_plural_mappings(
@ -71,90 +74,105 @@ class Vpnaas(extensions.APIExtensionDescriptor):
class VPNPluginBase(service_base.ServicePluginBase, metaclass=abc.ABCMeta):
def get_plugin_type(self):
def get_plugin_type(self) -> str:
return nconstants.VPN
def get_plugin_description(self):
def get_plugin_description(self) -> str:
return 'VPN service plugin'
@abc.abstractmethod
def get_vpnservices(self, context, filters=None, fields=None):
def get_vpnservices(self, context: context.Context,
filters: ty.Optional[ty.Dict] = None, fields=None):
pass
@abc.abstractmethod
def get_vpnservice(self, context, vpnservice_id, fields=None):
def get_vpnservice(self, context: context.Context, vpnservice_id: str,
fields=None):
pass
@abc.abstractmethod
def create_vpnservice(self, context, vpnservice):
def create_vpnservice(self, context: context.Context, vpnservice):
pass
@abc.abstractmethod
def update_vpnservice(self, context, vpnservice_id, vpnservice):
def update_vpnservice(self, context: context.Context, vpnservice_id: str,
vpnservice):
pass
@abc.abstractmethod
def delete_vpnservice(self, context, vpnservice_id):
def delete_vpnservice(self, context: context.Context, vpnservice_id: str):
pass
@abc.abstractmethod
def get_ipsec_site_connections(self, context, filters=None, fields=None):
def get_ipsec_site_connections(self, context: context.Context,
filters: ty.Optional[ty.Dict] = None,
fields=None):
pass
@abc.abstractmethod
def get_ipsec_site_connection(self, context,
ipsecsite_conn_id, fields=None):
def get_ipsec_site_connection(self, context: context.Context,
ipsecsite_conn_id: str, fields=None):
pass
@abc.abstractmethod
def create_ipsec_site_connection(self, context, ipsec_site_connection):
def create_ipsec_site_connection(self, context: context.Context,
ipsec_site_connection):
pass
@abc.abstractmethod
def update_ipsec_site_connection(self, context,
ipsecsite_conn_id, ipsec_site_connection):
def update_ipsec_site_connection(self, context: context.Context,
ipsecsite_conn_id: str,
ipsec_site_connection):
pass
@abc.abstractmethod
def delete_ipsec_site_connection(self, context, ipsecsite_conn_id):
def delete_ipsec_site_connection(self, context: context.Context,
ipsecsite_conn_id: str):
pass
@abc.abstractmethod
def get_ikepolicy(self, context, ikepolicy_id, fields=None):
def get_ikepolicy(self, context: context.Context, ikepolicy_id: str,
fields=None):
pass
@abc.abstractmethod
def get_ikepolicies(self, context, filters=None, fields=None):
def get_ikepolicies(self, context: context.Context,
filters: ty.Optional[ty.Dict], fields=None):
pass
@abc.abstractmethod
def create_ikepolicy(self, context, ikepolicy):
def create_ikepolicy(self, context: context.Context, ikepolicy):
pass
@abc.abstractmethod
def update_ikepolicy(self, context, ikepolicy_id, ikepolicy):
def update_ikepolicy(self, context: context.Context, ikepolicy_id: str,
ikepolicy):
pass
@abc.abstractmethod
def delete_ikepolicy(self, context, ikepolicy_id):
def delete_ikepolicy(self, context: context.Context, ikepolicy_id: str):
pass
@abc.abstractmethod
def get_ipsecpolicies(self, context, filters=None, fields=None):
def get_ipsecpolicies(self, context: context.Context,
filters: ty.Optional[ty.Dict] = None, fields=None):
pass
@abc.abstractmethod
def get_ipsecpolicy(self, context, ipsecpolicy_id, fields=None):
def get_ipsecpolicy(self, context: context.Context, ipsecpolicy_id: str,
fields=None):
pass
@abc.abstractmethod
def create_ipsecpolicy(self, context, ipsecpolicy):
def create_ipsecpolicy(self, context: context.Context, ipsecpolicy):
pass
@abc.abstractmethod
def update_ipsecpolicy(self, context, ipsecpolicy_id, ipsecpolicy):
def update_ipsecpolicy(self, context: context.Context, ipsecpolicy_id: str,
ipsecpolicy):
pass
@abc.abstractmethod
def delete_ipsecpolicy(self, context, ipsecpolicy_id):
def delete_ipsecpolicy(self, context: context.Context,
ipsecpolicy_id: str):
pass

View File

@ -9,6 +9,7 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import typing as ty
import neutron.conf.plugins.ml2.drivers.ovn.ovn_conf
import neutron.services.provider_configuration
@ -19,7 +20,7 @@ import neutron_vpnaas.services.vpn.device_drivers.strongswan_ipsec
import neutron_vpnaas.services.vpn.ovn_agent
def list_agent_opts():
def list_agent_opts() -> ty.List[ty.Tuple[str, ty.List]]:
return [
('vpnagent',
neutron_vpnaas.services.vpn.agent.vpn_agent_opts),
@ -33,7 +34,7 @@ def list_agent_opts():
]
def list_ovn_agent_opts():
def list_ovn_agent_opts() -> ty.List[ty.Tuple[str, ty.List]]:
return [
('vpnagent',
neutron_vpnaas.services.vpn.ovn_agent.VPN_AGENT_OPTS),
@ -51,7 +52,7 @@ def list_ovn_agent_opts():
]
def list_opts():
def list_opts() -> ty.List[ty.Tuple[str, ty.List]]:
return [
('service_providers',
neutron.services.provider_configuration.serviceprovider_opts)

View File

@ -14,33 +14,42 @@
import abc
import random
import typing as ty
from neutron.extensions import agent as nagent
from neutron.extensions import availability_zone as az_ext
from neutron.extensions import l3
from neutron_lib import context
from neutron_lib.plugins import constants as plugin_constants
from neutron_lib.plugins import directory
from oslo_config import cfg
from oslo_log import log as logging
if ty.TYPE_CHECKING:
from neutron_vpnaas.db.vpn import vpn_agentschedulers_db as scheduler_db
from neutron_vpnaas.extensions import vpn_agentschedulers
LOG = logging.getLogger(__name__)
class VPNScheduler(object, metaclass=abc.ABCMeta):
class VPNScheduler(metaclass=abc.ABCMeta):
@property
def l3_plugin(self):
def l3_plugin(self) -> l3.RouterPluginBase:
return directory.get_plugin(plugin_constants.L3)
@abc.abstractmethod
def schedule(self, plugin, context, router_id,
candidates=None, hints=None):
def schedule(self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin',
context: context.Context, router_id, candidates=None, hints=None) -> \
ty.Optional[nagent.Agent]:
"""Schedule the router to an active VPN agent.
Schedule the router only if it is not already scheduled.
"""
pass
def _get_unscheduled_routers(self, context, plugin, router_ids=None):
def _get_unscheduled_routers(self, context: context.Context,
plugin: 'scheduler_db.VPNAgentSchedulerDbMixin',
router_ids=None) -> ty.List:
"""Get the list of routers with VPN services to be scheduled.
If router IDs are omitted, look for all unscheduled routers.
@ -57,7 +66,10 @@ class VPNScheduler(object, metaclass=abc.ABCMeta):
context, filters={'id': unscheduled_router_ids})
return []
def _get_routers_can_schedule(self, context, plugin, routers, vpn_agent):
def _get_routers_can_schedule(
self, context: context.Context,
plugin: vpn_agentschedulers.VPNAgentSchedulerPluginBase,
routers, vpn_agent,):
"""Get the subset of routers whose VPN services can be scheduled on
the VPN agent.
"""
@ -65,7 +77,9 @@ class VPNScheduler(object, metaclass=abc.ABCMeta):
# all routers can be scheduled to it
return routers
def auto_schedule_routers(self, plugin, context, vpn_agent):
def auto_schedule_routers(
self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin',
context: context.Context, vpn_agent) -> ty.List[str]:
"""Schedule non-hosted routers to a VPN agent.
:returns: True if routers have been successfully assigned to the agent
@ -83,20 +97,31 @@ class VPNScheduler(object, metaclass=abc.ABCMeta):
self._bind_routers(context, plugin, target_routers, vpn_agent)
return [router['id'] for router in target_routers]
def _get_candidates(self, plugin, context, sync_router):
def _get_candidates(
self,
plugin: 'scheduler_db.VPNAgentSchedulerDbMixin',
context: context.Context,
sync_router) -> ty.Optional[ty.List[nagent.Agent]]:
"""Return VPN agents where a router could be scheduled."""
active_vpn_agents = plugin.get_vpn_agents(context, active=True)
if not active_vpn_agents:
LOG.warning('No active VPN agents')
return active_vpn_agents
def _bind_routers(self, context, plugin, routers, vpn_agent):
def _bind_routers(
self, context: context.Context,
plugin: 'scheduler_db.VPNAgentSchedulerDbMixin',
routers, vpn_agent):
for router in routers:
plugin.create_router_to_agent_binding(
context, router['id'], vpn_agent['id'])
def _schedule_router(self, plugin, context, router_id,
candidates=None):
def _schedule_router(
self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin',
context: context.Context,
router_id,
candidates: ty.Optional[ty.List[nagent.Agent]] = None) -> \
ty.Optional[nagent.Agent]:
current_vpn_agents = plugin.get_vpn_agents_hosting_routers(
context, [router_id])
if current_vpn_agents:
@ -118,9 +143,13 @@ class VPNScheduler(object, metaclass=abc.ABCMeta):
if plugin.create_router_to_agent_binding(context, router_id,
chosen_agent['id']):
return chosen_agent
return None
@abc.abstractmethod
def _choose_vpn_agent(self, plugin, context, candidates):
def _choose_vpn_agent(
self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin',
context: context.Context,
candidates: ty.List[nagent.Agent]) -> nagent.Agent:
"""Choose an agent from candidates based on a specific policy."""
pass
@ -128,24 +157,31 @@ class VPNScheduler(object, metaclass=abc.ABCMeta):
class ChanceScheduler(VPNScheduler):
"""Randomly allocate an VPN agent for a router."""
def schedule(self, plugin, context, router_id,
candidates=None):
def schedule(
self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin',
context: context.Context, router_id, candidates=None, hints=None):
return self._schedule_router(
plugin, context, router_id, candidates=candidates)
def _choose_vpn_agent(self, plugin, context, candidates):
def _choose_vpn_agent(
self, plugin: vpn_agentschedulers.VPNAgentSchedulerPluginBase,
context: context.Context,
candidates: ty.List[nagent.Agent]) -> nagent.Agent:
return random.choice(candidates)
class LeastRoutersScheduler(VPNScheduler):
"""Allocate to an VPN agent with the least number of routers bound."""
def schedule(self, plugin, context, router_id,
candidates=None):
def schedule(self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin',
context: context.Context, router_id, candidates=None, hints=None):
return self._schedule_router(
plugin, context, router_id, candidates=candidates)
def _choose_vpn_agent(self, plugin, context, candidates):
def _choose_vpn_agent(
self, plugin: 'scheduler_db.VPNAgentSchedulerDbMixin',
context: context.Context,
candidates: ty.List[nagent.Agent]) -> nagent.Agent:
candidates_dict = {c['id']: c for c in candidates}
chosen_agent_id = plugin.get_vpn_agent_with_min_routers(
context, candidates_dict.keys())
@ -158,7 +194,10 @@ class AZLeastRoutersScheduler(LeastRoutersScheduler):
return (router.get(az_ext.AZ_HINTS) or
cfg.CONF.default_availability_zones)
def _get_routers_can_schedule(self, context, plugin, routers, vpn_agent):
def _get_routers_can_schedule(
self, context: context.Context,
plugin: vpn_agentschedulers.VPNAgentSchedulerPluginBase,
routers, vpn_agent):
"""Overwrite VPNScheduler's method to filter by availability zone."""
target_routers = []
for r in routers:
@ -172,10 +211,15 @@ class AZLeastRoutersScheduler(LeastRoutersScheduler):
return super()._get_routers_can_schedule(
context, plugin, target_routers, vpn_agent)
def _get_candidates(self, plugin, context, sync_router):
def _get_candidates(
self,
plugin: 'scheduler_db.VPNAgentSchedulerDbMixin',
context: context.Context,
sync_router) -> ty.Optional[ty.List[nagent.Agent]]:
"""Overwrite VPNScheduler's method to filter by availability zone."""
all_candidates = super()._get_candidates(plugin, context, sync_router)
if all_candidates:
candidates = []
az_hints = self._get_az_hints(sync_router)
for agent in all_candidates:
@ -183,3 +227,4 @@ class AZLeastRoutersScheduler(LeastRoutersScheduler):
candidates.append(agent)
return candidates
return None

View File

@ -14,12 +14,16 @@
# License for the specific language governing permissions and limitations
# under the License.
import typing as ty
from neutron.agent.l3 import l3_agent_extension_api
from neutron_lib.agent import l3_extension
from neutron_lib import context
from oslo_config import cfg
from oslo_log import log as logging
from neutron_vpnaas._i18n import _
from neutron_vpnaas.services.vpn import device_drivers
from neutron_vpnaas.services.vpn import vpn_service
LOG = logging.getLogger(__name__)
@ -43,10 +47,11 @@ cfg.CONF.register_opts(vpn_agent_opts, 'vpnagent')
class VPNAgent(l3_extension.L3AgentExtension):
"""VPNaaS Agent support to be used by Neutron L3 agent."""
def initialize(self, connection, driver_type):
def initialize(self, connection, driver_type: device_drivers.DeviceDriver):
LOG.debug("Loading VPNaaS")
def consume_api(self, agent_api):
def consume_api(self,
agent_api: l3_agent_extension_api.L3AgentExtensionAPI):
LOG.debug("Loading consume_api for VPNaaS")
self.agent_api = agent_api
@ -58,7 +63,7 @@ class VPNAgent(l3_extension.L3AgentExtension):
self.service = vpn_service.VPNService(self)
self.device_drivers = self.service.load_device_drivers(self.host)
def add_router(self, context, data):
def add_router(self, context: context.Context, data):
"""Handles router add event"""
ri = self.agent_api.get_router_info(data['id'])
if ri is not None:
@ -69,17 +74,18 @@ class VPNAgent(l3_extension.L3AgentExtension):
LOG.debug("Router %s was concurrently deleted while "
"creating VPN for it", data['id'])
def update_router(self, context, data):
def update_router(self, context: context.Context, data):
"""Handles router update event"""
for device_driver in self.device_drivers:
device_driver.sync(context, [data])
def delete_router(self, context, data):
def delete_router(self, context: context.Context, data):
"""Handles router delete event"""
for device_driver in self.device_drivers:
device_driver.destroy_router(data['id'])
def ha_state_change(self, context, data):
def ha_state_change(self, context: context.Context,
data: ty.Dict[str, str]):
"""Enable the vpn process when router transitioned to master.
And disable vpn process for backup router.
@ -98,7 +104,7 @@ class VPNAgent(l3_extension.L3AgentExtension):
else:
process.disable()
def update_network(self, context, data):
def update_network(self, context: context.Context, data):
pass
@ -109,5 +115,4 @@ class L3WithVPNaaS(VPNAgent):
self.conf = conf
else:
self.conf = cfg.CONF
super(L3WithVPNaaS, self).__init__(
host=self.conf.host, conf=self.conf)
super().__init__(host=self.conf.host, conf=self.conf)

View File

@ -13,6 +13,8 @@
# License for the specific language governing permissions and limitations
# under the License.
import typing as ty
import configparser as ConfigParser
import errno
import os
@ -31,7 +33,7 @@ from neutron_vpnaas._i18n import _
LOG = logging.getLogger(__name__)
def setup_conf():
def setup_conf() -> cfg.ConfigOpts:
cli_opts = [
cfg.DictOpt('mount_paths',
required=True,
@ -50,9 +52,9 @@ def setup_conf():
return conf
def execute(cmd):
def execute(cmd) -> ty.Optional[int]:
if not cmd:
return
return None
cmd = list(map(str, cmd))
LOG.debug("Running command: %s", cmd)
env = os.environ.copy()
@ -106,12 +108,12 @@ def filter_command(command, rootwrap_config):
'name': exc.match.name})
sys.exit(errno.EINVAL)
except wrapper.NoFilterMatched:
LOG.error('Unauthorized command: %(cmd)s (no filter matched)',
{'cmd': command})
LOG.error("Unauthorized command: %(cmd)s (no filter matched)",
{"cmd": command})
sys.exit(errno.EPERM)
def execute_with_mount():
def execute_with_mount() -> ty.Optional[int]:
config.register_common_config_options()
conf = setup_conf()
conf()

View File

@ -14,14 +14,16 @@
# under the License.
import abc
from neutron_lib import context
class DeviceDriver(object, metaclass=abc.ABCMeta):
class DeviceDriver(metaclass=abc.ABCMeta):
def __init__(self, agent, host):
pass
@abc.abstractmethod
def sync(self, context, processes):
def sync(self, context: context.ContextBase, processes):
pass
@abc.abstractmethod

View File

@ -12,6 +12,9 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import typing as ty
import abc
import base64
import copy
@ -126,7 +129,8 @@ IPSEC_CONNS = 'ipsec_site_connections'
PSK_BASE64_PREFIX = '0s'
def _get_template(template_file):
def _get_template(
template_file: ty.Union[str, jinja2.Template]) -> jinja2.Template:
global JINJA_ENV
if not JINJA_ENV:
templateLoader = jinja2.FileSystemLoader(searchpath="/")
@ -134,7 +138,7 @@ def _get_template(template_file):
return JINJA_ENV.get_template(template_file)
class BaseSwanProcess(object, metaclass=abc.ABCMeta):
class BaseSwanProcess(metaclass=abc.ABCMeta):
"""Swan Family Process Manager
This class manages start/restart/stop ipsec process.
@ -189,12 +193,12 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta):
STATUS_IPSEC_SA_ESTABLISHED_RE2 = (
r'\d{3} #\d+: "([a-f0-9\-\/x]+).*established.*newest IPSEC')
def __init__(self, conf, process_id, vpnservice, namespace):
def __init__(self, conf, process_id: str, vpnservice, namespace: str):
self.conf = conf
self.id = process_id
self.updated_pending_status = False
self.updated_pending_status: bool = False
self.namespace = namespace
self.connection_status = {}
self.connection_status: ty.Dict[str, ty.Dict[str, ty.Any]] = {}
self.config_dir = os.path.join(
self.conf.ipsec.config_base_dir, self.id)
self.etc_dir = os.path.join(self.config_dir, 'etc')
@ -212,32 +216,31 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta):
def translate_dialect(self):
if not self.vpnservice:
return
for ipsec_site_conn in self.vpnservice['ipsec_site_connections']:
self._dialect(ipsec_site_conn, 'initiator')
self._dialect(ipsec_site_conn['ikepolicy'], 'ike_version')
for key in ['encryption_algorithm',
'auth_algorithm',
'pfs']:
self._dialect(ipsec_site_conn['ikepolicy'], key)
self._dialect(ipsec_site_conn['ipsecpolicy'], key)
if (('local_id' not in ipsec_site_conn.keys()) or
(not ipsec_site_conn['local_id'])):
ipsec_site_conn['local_id'] = ipsec_site_conn['external_ip']
for ipsec_site_conn in self.vpnservice["ipsec_site_connections"]:
self._dialect(ipsec_site_conn, "initiator")
self._dialect(ipsec_site_conn["ikepolicy"], "ike_version")
for key in ["encryption_algorithm", "auth_algorithm", "pfs"]:
self._dialect(ipsec_site_conn["ikepolicy"], key)
self._dialect(ipsec_site_conn["ipsecpolicy"], key)
if ("local_id" not in ipsec_site_conn.keys()) or (
not ipsec_site_conn["local_id"]
):
ipsec_site_conn["local_id"] = ipsec_site_conn["external_ip"]
def base64_encode_psk(self):
if not self.vpnservice:
return
for ipsec_site_conn in self.vpnservice['ipsec_site_connections']:
psk = ipsec_site_conn['psk']
for ipsec_site_conn in self.vpnservice["ipsec_site_connections"]:
psk = ipsec_site_conn["psk"]
encoded_psk = base64.b64encode(encodeutils.safe_encode(psk))
# NOTE(huntxu): base64.b64encode returns an instance of 'bytes'
# in Python 3, convert it to a str. For Python 2, after calling
# safe_decode, psk is converted into a unicode not containing any
# non-ASCII characters so it doesn't matter.
psk = encodeutils.safe_decode(encoded_psk, incoming='utf_8')
ipsec_site_conn['psk'] = PSK_BASE64_PREFIX + psk
psk = encodeutils.safe_decode(encoded_psk, incoming="utf_8")
ipsec_site_conn["psk"] = PSK_BASE64_PREFIX + psk
def get_ns_wrapper(self):
def get_ns_wrapper(self) -> str:
"""
Check if we're inside a virtualenv. If we are, then we should
respect this and launch wrapper from venv as well.
@ -249,7 +252,7 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta):
ns_wrapper = self.NS_WRAPPER
return ns_wrapper
def update_vpnservice(self, vpnservice):
def update_vpnservice(self, vpnservice: ty.Dict[str, ty.Any]):
self.vpnservice = vpnservice
self.translate_dialect()
self.base64_encode_psk()
@ -275,7 +278,7 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta):
agent_utils.execute(
cmd=["rm", "-rf", self.config_dir], run_as_root=True)
def _get_config_filename(self, kind):
def _get_config_filename(self, kind) -> str:
config_dir = self.etc_dir
return os.path.join(config_dir, kind)
@ -286,15 +289,15 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta):
dir_path = os.path.join(self.config_dir, subdir)
fileutils.ensure_tree(dir_path, 0o755)
def _gen_config_content(self, template_file, vpnservice):
def _gen_config_content(self, template_file, vpnservice) -> str:
template = _get_template(template_file)
return template.render(
{'vpnservice': vpnservice,
'state_path': self.conf.state_path})
def _get_rootwrap_config(self):
def _get_rootwrap_config(self) -> ty.Optional[str]:
if 'neutron-rootwrap' in cfg.CONF.AGENT.root_helper:
rh_tokens = cfg.CONF.AGENT.root_helper.split(' ')
rh_tokens: ty.List[str] = cfg.CONF.AGENT.root_helper.split(' ')
if len(rh_tokens) == 3 and os.path.exists(rh_tokens[2]):
return rh_tokens[2]
return None
@ -304,13 +307,13 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta):
pass
@property
def status(self):
def status(self) -> str:
if self.active:
return constants.ACTIVE
return constants.DOWN
@property
def active(self):
def active(self) -> bool:
"""Check if the process is active or not."""
if not self.namespace:
return False
@ -329,8 +332,9 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta):
# Disable the process if a vpnservice is disabled or it has no
# enabled IPSec site connections.
vpnservice_has_active_ipsec_site_conns = any(
[ipsec_site_conn['admin_state_up']
for ipsec_site_conn in self.vpnservice['ipsec_site_connections']])
ipsec_site_conn['admin_state_up']
for ipsec_site_conn in
self.vpnservice['ipsec_site_connections'])
if (not self.vpnservice['admin_state_up'] or
not vpnservice_has_active_ipsec_site_conns):
self.disable()
@ -386,7 +390,8 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta):
def stop(self):
"""Stop process."""
def _check_status_line(self, line):
def _check_status_line(self,
line: str) -> ty.Tuple[ty.Optional[str], ty.Optional[str]]:
"""Parse a line and search for status information.
If a connection has an established Security Association,
@ -404,14 +409,15 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta):
if m:
connection_id = m.group(1)
return connection_id, constants.ACTIVE
else:
m = self.STATUS_PATTERN.search(line)
if m:
connection_id = m.group(1)
return connection_id, constants.DOWN
return None, None
def _extract_and_record_connection_status(self, status_output):
def _extract_and_record_connection_status(self,
status_output: ty.Optional[str]):
if not status_output:
self.connection_status = {}
return
@ -423,8 +429,8 @@ class BaseSwanProcess(object, metaclass=abc.ABCMeta):
if conn_id:
self._record_connection_status(conn_id, conn_status)
def _record_connection_status(self, connection_id, status,
force_status_update=False):
def _record_connection_status(self, connection_id: str, status,
force_status_update: bool = False):
conn_info = self.connection_status.get(connection_id)
if not conn_info:
self.connection_status[connection_id] = {
@ -445,18 +451,20 @@ class OpenSwanProcess(BaseSwanProcess):
(2) ipsec addconn: Adds new ipsec addconn
(3) ipsec whack: control interface for IPSEC keying daemon
"""
def __init__(self, conf, process_id, vpnservice, namespace):
super(OpenSwanProcess, self).__init__(conf, process_id,
vpnservice, namespace)
self.secrets_file = os.path.join(
self.etc_dir, 'ipsec.secrets')
self.config_file = os.path.join(
self.etc_dir, 'ipsec.conf')
self.pid_path = os.path.join(
self.config_dir, 'var', 'run', 'pluto')
self.pid_file = '%s.pid' % self.pid_path
def _execute(self, cmd, check_exit_code=True, extra_ok_codes=None):
def __init__(self, conf, process_id: str, vpnservice, namespace: str):
super().__init__(conf, process_id, vpnservice, namespace)
self.secrets_file: str = os.path.join(
self.etc_dir, 'ipsec.secrets')
self.config_file: str = os.path.join(
self.etc_dir, 'ipsec.conf')
self.pid_path: str = os.path.join(
self.config_dir, 'var', 'run', 'pluto')
self.pid_file: str = f'{self.pid_path}.pid'
def _execute(self, cmd, check_exit_code: bool = True,
extra_ok_codes: ty.Optional[ty.List[int]] = None
) -> ty.Optional[str]:
"""Execute command on namespace."""
ip_wrapper = ip_lib.IPWrapper(namespace=self.namespace)
return ip_wrapper.netns.execute(cmd, check_exit_code=check_exit_code,
@ -490,7 +498,7 @@ class OpenSwanProcess(BaseSwanProcess):
shutil.copyfile(config_file_name, config_file_name + '.old')
os.chmod(config_file_name + '.old', 0o600)
def _process_running(self):
def _process_running(self) -> bool:
"""Checks if process is still running."""
# If no PID file, we assume the process is not running.
@ -502,9 +510,10 @@ class OpenSwanProcess(BaseSwanProcess):
# on throwing to tell us something. If the pid file exists,
# delve into the process information and check if it matches
# our expected command line.
with open(self.pid_file, 'r') as f:
with open(self.pid_file, 'r', encoding="C") as f:
pid = f.readline().strip()
with open('/proc/%s/cmdline' % pid) as cmd_line_file:
with open(f'/proc/{pid}/cmdline',
encoding="C") as cmd_line_file:
cmd_line = cmd_line_file.readline()
if self.pid_path in cmd_line and 'pluto' in cmd_line:
# Okay the process is probably a pluto process
@ -529,7 +538,7 @@ class OpenSwanProcess(BaseSwanProcess):
def _cleanup_control_files(self):
try:
ctl_file = '%s.ctl' % self.pid_path
ctl_file = f'{self.pid_path}.ctl'
LOG.debug('Removing %(pidfile)s and %(ctlfile)s',
{'pidfile': self.pid_file,
'ctlfile': ctl_file})
@ -545,14 +554,14 @@ class OpenSwanProcess(BaseSwanProcess):
'files for router %(router)s. %(msg)s',
{'router': self.id, 'msg': e})
def get_status(self):
def get_status(self) -> ty.Optional[str]:
return self._execute([self.binary,
'whack',
'--ctlbase',
self.pid_path,
'--status'], extra_ok_codes=[1, 3])
def _config_changed(self):
def _config_changed(self) -> bool:
secrets_file = os.path.join(
self.etc_dir, 'ipsec.secrets')
config_file = os.path.join(
@ -590,19 +599,25 @@ class OpenSwanProcess(BaseSwanProcess):
LOG.warning('Server appears to still be running, restart '
'of router %s may fail', self.id)
self.start()
return
def _resolve_fqdn(self, fqdn):
def _resolve_fqdn(self, fqdn) -> ty.Optional[str]:
# The first addrinfo member from the list returned by
# socket.getaddrinfo is used for the address resolution.
# The code doesn't filter for ipv4 or ipv6 address.
try:
addrinfo = socket.getaddrinfo(fqdn, None)[0]
addrinfo: ty.Tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
ty.Union[ty.Tuple[str, int], ty.Tuple[str, int, int, int]],
] = socket.getaddrinfo(fqdn, None)[0]
return addrinfo[-1][0]
except socket.gaierror:
LOG.exception("Peer address %s cannot be resolved", fqdn)
return None
def _get_nexthop(self, address, connection_id):
def _get_nexthop(self, address: str, connection_id: str) -> str:
# check if address is an ip address or fqdn
invalid_ip_address = validators.validate_ip_address(address)
if invalid_ip_address:
@ -615,28 +630,31 @@ class OpenSwanProcess(BaseSwanProcess):
else:
ip_addr = address
routes = self._execute(['ip', 'route', 'get', ip_addr])
if routes.find('via') >= 0:
if routes and routes.find('via') >= 0:
return routes.split(' ')[2]
return address
def _virtual_privates(self, vpnservice):
def _virtual_privates(self,
vpnservice: ty.Dict[str, ty.List[ty.Dict[str, ty.Any]]]) -> str:
"""Returns line of virtual_privates.
virtual_private contains the networks
that are allowed as subnet for the remote client.
"""
virtual_privates = []
nets = []
virtual_privates: ty.List[str] = []
nets: ty.List = []
for ipsec_site_conn in vpnservice['ipsec_site_connections']:
nets += ipsec_site_conn['local_cidrs']
nets += ipsec_site_conn['peer_cidrs']
for net in nets:
version = netaddr.IPNetwork(net).version
virtual_privates.append('%%v%s:%s' % (version, net))
virtual_privates.append(f'%v{version}:{net}')
virtual_privates.sort()
return ','.join(virtual_privates)
def _gen_config_content(self, template_file, vpnservice):
def _gen_config_content(self,
template_file: ty.Union[str, jinja2.Template],
vpnservice) -> str:
template = _get_template(template_file)
virtual_privates = self._virtual_privates(vpnservice)
return template.render(
@ -660,7 +678,7 @@ class OpenSwanProcess(BaseSwanProcess):
def add_ipsec_connection(self, nexthop, conn_id):
self._execute([self.binary,
'addconn',
'--ctlbase', '%s.ctl' % self.pid_path,
'--ctlbase', f'{self.pid_path}.ctl',
'--defaultroutenexthop', nexthop,
'--config', self.config_file, conn_id
])
@ -745,9 +763,9 @@ class OpenSwanProcess(BaseSwanProcess):
self.initiate_connection(ipsec_site_conn['id'])
self._copy_configs()
def get_established_connections(self):
connections = []
status_output = self.get_status()
def get_established_connections(self) -> ty.List[str]:
connections: ty.List[str] = []
status_output: ty.Optional[str] = self.get_status()
if not status_output:
return connections
@ -781,7 +799,7 @@ class OpenSwanProcess(BaseSwanProcess):
self.connection_status = {}
class IPsecVpnDriverApi(object):
class IPsecVpnDriverApi:
"""IPSecVpnDriver RPC api."""
@log_helpers.log_method_call
@ -790,7 +808,7 @@ class IPsecVpnDriverApi(object):
self.client = n_rpc.get_client(target)
@log_helpers.log_method_call
def get_vpn_services_on_host(self, context, host):
def get_vpn_services_on_host(self, context: context.ContextBase, host):
"""Get list of vpnservices.
The vpnservices including related ipsec_site_connection,
@ -800,7 +818,7 @@ class IPsecVpnDriverApi(object):
return cctxt.call(context, 'get_vpn_services_on_host', host=host)
@log_helpers.log_method_call
def update_status(self, context, status):
def update_status(self, context: context.ContextBase, status):
"""Update local status.
This method call updates status attribute of
@ -823,19 +841,20 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta):
# 1.0 Initial version
target = oslo_messaging.Target(version='1.0')
def __init__(self, vpn_service, host):
def __init__(self, vpn_service, host: str):
# TODO(pc_m) Replace vpn_service with config arg, once all driver
# implementations no longer need vpn_service.
self.conf = vpn_service.conf
self.host = host
self.conn = n_rpc.Connection()
self.context = context.get_admin_context_without_session()
self.topic = topics.IPSEC_AGENT_TOPIC
node_topic = '%s.%s' % (self.topic, self.host)
self.context: context.ContextBase = \
context.get_admin_context_without_session()
self.topic: str = topics.IPSEC_AGENT_TOPIC
node_topic: str = f'{self.topic}.{self.host}'
self.processes = {}
self.routers = {}
self.process_status_cache = {}
self.processes: ty.Dict[str, BaseSwanProcess] = {}
self.routers: ty.Dict[str, ty.Any] = {}
self.process_status_cache: ty.Dict[str, ty.Dict[str, ty.Any]] = {}
self.endpoints = [self]
self.conn.create_consumer(node_topic, self.endpoints, fanout=False)
@ -846,7 +865,7 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta):
self.process_status_cache_check.start(
interval=self.conf.ipsec.ipsec_status_check_interval)
def get_namespace(self, router_id):
def get_namespace(self, router_id: str) -> ty.Optional[str]:
"""Get namespace of router.
:router_id: router_id
@ -856,7 +875,7 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta):
"""
router = self.routers.get(router_id)
if not router:
return
return None
# For DVR, use SNAT namespace
# TODO(pcm): Use router object method to tell if DVR, when available
if router.router['distributed']:
@ -941,14 +960,14 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta):
for peer_cidr in ipsec_site_connection['peer_cidrs']:
func(router_id,
'POSTROUTING',
'-s %s -d %s -m policy '
f'-s {local_cidr} -d {peer_cidr} -m policy '
'--dir out --pol ipsec '
'-j ACCEPT ' % (local_cidr, peer_cidr),
'-j ACCEPT ',
top=True)
self.iptables_apply(router_id)
@log_helpers.log_method_call
def vpnservice_updated(self, context, **kwargs):
def vpnservice_updated(self, context: context.ContextBase, **kwargs):
"""Vpnservice updated rpc handler
VPN Service Driver will call this method
@ -959,16 +978,18 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta):
self.sync(context, [router] if router else [])
@abc.abstractmethod
def create_process(self, process_id, vpnservice, namespace):
def create_process(self, process_id: str, vpnservice,
namespace) -> BaseSwanProcess:
pass
def ensure_process(self, process_id, vpnservice=None):
def ensure_process(self, process_id: str,
vpnservice=None) -> BaseSwanProcess:
"""Ensuring process.
If the process doesn't exist, it will create process
and store it in self.process
"""
process = self.processes.get(process_id)
process = self.processes.get(process_id, None)
if not process or not process.namespace:
namespace = self.get_namespace(process_id)
process = self.create_process(
@ -1022,7 +1043,8 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta):
if process_id in self.routers:
del self.routers[process_id]
def get_process_status_cache(self, process):
def get_process_status_cache(self,
process: BaseSwanProcess) -> ty.Dict[str, ty.Any]:
if not self.process_status_cache.get(process.id):
self.process_status_cache[process.id] = {
'status': None,
@ -1031,7 +1053,8 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta):
'ipsec_site_connections': {}}
return self.process_status_cache[process.id]
def is_status_updated(self, process, previous_status):
def is_status_updated(self, process: BaseSwanProcess,
previous_status: ty.Dict[str, ty.Any]) -> bool:
if process.updated_pending_status:
return True
if process.status != previous_status['status']:
@ -1039,13 +1062,15 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta):
if (process.connection_status !=
previous_status['ipsec_site_connections']):
return True
return False
def unset_updated_pending_status(self, process):
process.updated_pending_status = False
for connection_status in process.connection_status.values():
connection_status['updated_pending_status'] = False
def copy_process_status(self, process):
def copy_process_status(self,
process: BaseSwanProcess) -> ty.Dict[str, ty.Any]:
return {
'id': process.vpnservice['id'],
'status': process.status,
@ -1053,7 +1078,7 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta):
'ipsec_site_connections': copy.deepcopy(process.connection_status)
}
def update_downed_connections(self, process_id, new_status):
def update_downed_connections(self, process_id: str, new_status):
"""Update info to be reported, if connections just went down.
If there is no longer any information for a connection, because it
@ -1069,13 +1094,15 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta):
'updated_pending_status': True
}
def should_be_reported(self, context, process):
def should_be_reported(self, context: context.ContextBase,
process: BaseSwanProcess) -> bool:
if (context.is_admin or
process.vpnservice["tenant_id"] == context.tenant_id):
return True
return False
@log_helpers.log_method_call
def report_status(self, context):
def report_status(self, context: context.ContextBase):
status_changed_vpn_services = []
for process_id, process in list(self.processes.items()):
# NOTE(mnaser): It's not necessary to check status for processes
@ -1104,7 +1131,7 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta):
@log_helpers.log_method_call
@lockutils.synchronized('vpn-agent', 'neutron-')
def sync(self, context, routers):
def sync(self, context: context.ContextBase, routers):
"""Sync status with server side.
:param context: context object for RPC call
@ -1122,8 +1149,8 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta):
"""
vpnservices = self.agent_rpc.get_vpn_services_on_host(
context, self.host)
router_ids = [vpnservice['router_id'] for vpnservice in vpnservices]
sync_router_ids = [router['id'] for router in routers]
router_ids = [vpnservice["router_id"] for vpnservice in vpnservices]
sync_router_ids = [router["id"] for router in routers]
self._sync_vpn_processes(vpnservices, sync_router_ids)
self._delete_vpn_processes(sync_router_ids, router_ids)
@ -1138,8 +1165,10 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta):
for vpnservice in vpnservices:
if vpnservice['router_id'] not in self.processes or (
vpnservice['router_id'] in sync_router_ids):
process = self.ensure_process(vpnservice['router_id'],
vpnservice=vpnservice)
process: ty.Optional[BaseSwanProcess] = self.ensure_process(
vpnservice['router_id'], vpnservice=vpnservice)
if not process:
return
self._update_nat(vpnservice, self.add_nat_rule)
router = self.routers.get(vpnservice['router_id'])
if not router:
@ -1168,7 +1197,8 @@ class IPsecDriver(device_drivers.DeviceDriver, metaclass=abc.ABCMeta):
class OpenSwanDriver(IPsecDriver):
def create_process(self, process_id, vpnservice, namespace):
def create_process(self, process_id: str, vpnservice,
namespace: str) -> BaseSwanProcess:
return OpenSwanProcess(
self.conf,
process_id,

View File

@ -12,8 +12,9 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import typing as ty
import os
import os.path
from neutron.agent.linux import ip_lib
@ -26,39 +27,38 @@ class LibreSwanProcess(ipsec.OpenSwanProcess):
Libreswan needs nssdb initialised before running pluto daemon.
"""
# pylint: disable=useless-super-delegation
def __init__(self, conf, process_id, vpnservice, namespace):
def __init__(self, conf, process_id: str, vpnservice, namespace: str):
self._rootwrap_cfg = self._get_rootwrap_config()
super(LibreSwanProcess, self).__init__(conf, process_id,
vpnservice, namespace)
super().__init__(conf, process_id, vpnservice, namespace)
def _ipsec_execute(self, cmd, check_exit_code=True, extra_ok_codes=None):
def _ipsec_execute(self, cmd: ty.List[str], check_exit_code: bool = True,
extra_ok_codes: ty.Optional[ty.List[int]] = None):
"""Execute ipsec command on namespace.
This execute is wrapped by namespace wrapper.
The namespace wrapper will bind /etc and /var/run
"""
ip_wrapper = ip_lib.IPWrapper(namespace=self.namespace)
mount_paths = {'/etc': '%s/etc' % self.config_dir,
'/var/run': '%s/var/run' % self.config_dir}
mount_paths = {'/etc': f'{self.config_dir}/etc',
'/var/run': f'{self.config_dir}/var/run'}
mount_paths_str = ','.join(
"%s:%s" % (source, target)
for source, target in mount_paths.items())
f"{source}:{target}" for source, target in mount_paths.items())
ns_wrapper = self.get_ns_wrapper()
return ip_wrapper.netns.execute(
[ns_wrapper,
'--mount_paths=%s' % mount_paths_str,
('--rootwrap_config=%s' % self._rootwrap_cfg
if self._rootwrap_cfg else ''),
'--cmd=%s,%s' % (self.binary, ','.join(cmd))],
f'--mount_paths={mount_paths_str}',
'--rootwrap_config={0}'.format(
self._rootwrap_cfg if self._rootwrap_cfg else ""),
f'--cmd={self.binary},{",".join(cmd)}'],
check_exit_code=check_exit_code,
extra_ok_codes=extra_ok_codes)
def _ensure_needed_files(self):
# addconn reads from /etc/hosts and /etc/resolv.conf. As /etc would be
# bind-mounted, create these two empty files in the target directory.
with open('%s/etc/hosts' % self.config_dir, 'a'):
with open(f'{self.config_dir}/etc/hosts', 'a', encoding="utf8"):
pass
with open('%s/etc/resolv.conf' % self.config_dir, 'a'):
with open(f'{self.config_dir}/etc/resolv.conf', 'a', encoding="utf8"):
pass
def ensure_configs(self):
@ -75,17 +75,17 @@ class LibreSwanProcess(ipsec.OpenSwanProcess):
if os.path.exists(secrets_file):
self._execute(['rm', '-f', secrets_file])
super(LibreSwanProcess, self).ensure_configs()
super().ensure_configs()
# LibreSwan uses the capabilities library to restrict access to
# ipsec.secrets to users that have explicit access. Since pluto is
# running as root and the file has 0600 perms, we must set the
# owner of the file to root.
self._execute(['chown', '--from=%s' % os.getuid(), 'root:root',
self._execute(['chown', f'--from={os.getuid()}', 'root:root',
secrets_file])
# Libreswan needs to write logs to this directory.
self._execute(['chown', '--from=%s' % os.getuid(), 'root:root',
self._execute(['chown', f'--from={os.getuid()}', 'root:root',
self.log_dir])
self._ensure_needed_files()
@ -131,11 +131,12 @@ class LibreSwanProcess(ipsec.OpenSwanProcess):
['whack', '--name', conn_name, '--asynchronous', '--initiate'])
def terminate_connection(self, conn_name):
self._ipsec_execute(['whack', '--name', conn_name, '--terminate'])
self._ipsec_execute(["whack", "--name", conn_name, "--terminate"])
class LibreSwanDriver(ipsec.IPsecDriver):
def create_process(self, process_id, vpnservice, namespace):
def create_process(self, process_id: str, vpnservice,
namespace: str) -> ipsec.BaseSwanProcess:
return LibreSwanProcess(
self.conf,
process_id,

View File

@ -14,9 +14,15 @@
# License for the specific language governing permissions and limitations
# under the License.
import typing as ty
import abc
import netaddr
from neutron.agent.common import utils as agent_common_utils
from neutron.agent.linux import interface
from neutron.agent.linux import ip_lib
from neutron_lib import constants as lib_constants
from neutron_lib import context as nctx
@ -30,7 +36,7 @@ from neutron_vpnaas.services.vpn.device_drivers import strongswan_ipsec
PORT_PREFIX_INTERNAL = 'vr'
PORT_PREFIX_EXTERNAL = 'vg'
PORT_PREFIXES = {
PORT_PREFIXES: ty.Dict[str, str] = {
'internal': PORT_PREFIX_INTERNAL,
'external': PORT_PREFIX_EXTERNAL,
}
@ -38,7 +44,7 @@ PORT_PREFIXES = {
LOG = logging.getLogger(__name__)
class DeviceManager(object):
class DeviceManager:
"""Device Manager for ports in qvpn-xx namespace.
It is a veth pair, one side in qvpn and the other
side is attached to ovs.
@ -51,16 +57,17 @@ class DeviceManager(object):
self.host = host
self.plugin = plugin
self.context = context
self.driver = agent_common_utils.load_interface_driver(conf)
self.driver: interface.LinuxInterfaceDriver = \
agent_common_utils.load_interface_driver(conf)
def get_interface_name(self, port, ptype):
def get_interface_name(self, port: ty.Dict[str, str], ptype: str) -> str:
suffix = port['id']
return (PORT_PREFIXES[ptype] + suffix)[:self.driver.DEV_NAME_LEN]
def get_namespace_name(self, process_id):
def get_namespace_name(self, process_id: str):
return self.OVN_NS_PREFIX + process_id
def get_existing_process_ids(self):
def get_existing_process_ids(self) -> ty.List:
"""Return the process IDs derived from the existing VPN namespaces."""
return [ns[len(self.OVN_NS_PREFIX):]
for ns in ip_lib.list_network_namespaces()
@ -87,7 +94,8 @@ class DeviceManager(object):
device.route.delete_route(cidr, via=via, metric=100,
proto='static')
def list_routes(self, namespace, via=None):
def list_routes(self, namespace,
via=None) -> ty.List[ty.Dict[str, ty.Any]]:
device = ip_lib.IPDevice(None, namespace=namespace)
return device.route.list_routes(
lib_constants.IP_VERSION_4, proto='static', via=via)
@ -100,24 +108,26 @@ class DeviceManager(object):
for r in routes:
device.route.delete_route(r['cidr'], via=r['via'])
def _del_port(self, process_id, ptype):
def _del_port(self, process_id: str, ptype: str):
namespace = self.get_namespace_name(process_id)
prefix = PORT_PREFIXES[ptype]
device = ip_lib.IPDevice(None, namespace=namespace)
ports = device.addr.list()
ports: ty.List[ty.Dict[str, ty.Union[str, ty.Any]]] = \
device.addr.list()
for p in ports:
if not p['name'].startswith(prefix):
continue
interface_name = p['name']
self.driver.unplug(interface_name, namespace=namespace)
def del_internal_port(self, process_id):
def del_internal_port(self, process_id: str):
self._del_port(process_id, 'internal')
def del_external_port(self, process_id):
def del_external_port(self, process_id: str):
self._del_port(process_id, 'external')
def setup_external(self, process_id, network_details):
def setup_external(self, process_id: str,
network_details) -> ty.Optional[str]:
network = network_details["external_network"]
vpn_port = network_details['gw_port']
ns_name = self.get_namespace_name(process_id)
@ -143,7 +153,7 @@ class DeviceManager(object):
subnet_id = fixed_ip['subnet_id']
subnet = self.plugin.get_subnet_info(subnet_id)
net = netaddr.IPNetwork(subnet['cidr'])
ip_cidr = '%s/%s' % (fixed_ip['ip_address'], net.prefixlen)
ip_cidr = f'{fixed_ip["ip_address"]}/{net.prefixlen}'
ip_cidrs.append(ip_cidr)
subnets.append(subnet)
self.driver.init_l3(interface_name, ip_cidrs,
@ -152,7 +162,7 @@ class DeviceManager(object):
self.set_default_route(ns_name, subnet, interface_name)
return interface_name
def setup_internal(self, process_id, network_details):
def setup_internal(self, process_id, network_details) -> ty.Optional[str]:
vpn_port = network_details["transit_port"]
ns_name = self.get_namespace_name(process_id)
interface_name = self.get_interface_name(vpn_port, 'internal')
@ -172,19 +182,19 @@ class DeviceManager(object):
ip_cidrs = []
for fixed_ip in vpn_port['fixed_ips']:
ip_cidr = '%s/%s' % (fixed_ip['ip_address'], 28)
ip_cidr = f'{fixed_ip["ip_address"]}/28'
ip_cidrs.append(ip_cidr)
self.driver.init_l3(interface_name, ip_cidrs,
namespace=ns_name)
return interface_name
class NamespaceManager(object):
class NamespaceManager:
def __init__(self, use_ipv6=False):
self.ip_wrapper_root = ip_lib.IPWrapper()
self.use_ipv6 = use_ipv6
def exists(self, name):
def exists(self, name) -> bool:
return ip_lib.network_namespace_exists(name)
def create(self, name):
@ -231,9 +241,9 @@ class IPsecOvnDriverApi(ipsec.IPsecVpnDriverApi):
subnet_id=subnet_id)
class OvnIPsecDriver(ipsec.IPsecDriver):
class OvnIPsecDriver(ipsec.IPsecDriver, metaclass=abc.ABCMeta):
def __init__(self, vpn_service, host):
def __init__(self, vpn_service, host: str):
self.nsmgr = NamespaceManager()
super().__init__(vpn_service, host)
self.agent_rpc = IPsecOvnDriverApi(topics.IPSEC_DRIVER_TOPIC)
@ -242,7 +252,7 @@ class OvnIPsecDriver(ipsec.IPsecDriver):
get_router_based_iptables_manager = None
def get_namespace(self, router_id):
def get_namespace(self, router_id) -> str:
"""Get namespace for VPN services of router.
:router_id: router_id
@ -250,7 +260,7 @@ class OvnIPsecDriver(ipsec.IPsecDriver):
"""
return self.devmgr.get_namespace_name(router_id)
def _cleanup_namespace(self, router_id):
def _cleanup_namespace(self, router_id: str):
ns_name = self.devmgr.get_namespace_name(router_id)
if not self.nsmgr.exists(ns_name):
return
@ -259,7 +269,7 @@ class OvnIPsecDriver(ipsec.IPsecDriver):
self.devmgr.del_external_port(router_id)
self.nsmgr.delete(ns_name)
def _ensure_namespace(self, router_id, network_details):
def _ensure_namespace(self, router_id: str, network_details) -> str:
ns_name = self.get_namespace(router_id)
if not self.nsmgr.exists(ns_name):
self.nsmgr.create(ns_name)
@ -272,7 +282,12 @@ class OvnIPsecDriver(ipsec.IPsecDriver):
return ns_name
def destroy_process(self, process_id):
@abc.abstractmethod
def create_process(self, process_id: str,
vpnservice, namespace) -> ipsec.BaseSwanProcess:
pass
def destroy_process(self, process_id: str):
LOG.info('process %s is destroyed', process_id)
namespace = self.devmgr.get_namespace_name(process_id)
@ -316,7 +331,7 @@ class OvnIPsecDriver(ipsec.IPsecDriver):
new_local_cidrs - old_local_cidrs,
gateway_ip)
def _sync_vpn_processes(self, vpnservices, sync_router_ids):
def _sync_vpn_processes(self, vpnservices, sync_router_ids: ty.List[str]):
# Ensure the ipsec process is enabled only for
# - the vpn services which are not yet in self.processes
# - vpn services whose router id is in 'sync_router_ids'
@ -331,7 +346,7 @@ class OvnIPsecDriver(ipsec.IPsecDriver):
process = self.ensure_process(router_id, vpnservice=vpnservice)
process.update()
def _cleanup_stale_vpn_processes(self, vpn_router_ids):
def _cleanup_stale_vpn_processes(self, vpn_router_ids: ty.List[str]):
super()._cleanup_stale_vpn_processes(vpn_router_ids)
# Look for additional namespaces on this node that we don't know
# and that should be deleted
@ -340,17 +355,20 @@ class OvnIPsecDriver(ipsec.IPsecDriver):
self.destroy_process(router_id)
@lockutils.synchronized('vpn-agent', 'neutron-')
def vpnservice_removed_from_agent(self, context, router_id):
def vpnservice_removed_from_agent(self, context: nctx.Context,
router_id: str):
# must run under the same lock as sync()
self.destroy_process(router_id)
def vpnservice_added_to_agent(self, context, router_ids):
def vpnservice_added_to_agent(self, context: nctx.Context,
router_ids: ty.List[str]):
routers = [{'id': router_id} for router_id in router_ids]
self.sync(context, routers)
class OvnStrongSwanDriver(OvnIPsecDriver):
def create_process(self, process_id, vpnservice, namespace):
def create_process(self, process_id: str, vpnservice,
namespace: str) -> ipsec.BaseSwanProcess:
return OvnStrongSwanProcess(
self.conf,
process_id,
@ -359,7 +377,8 @@ class OvnStrongSwanDriver(OvnIPsecDriver):
class OvnOpenSwanDriver(OvnIPsecDriver):
def create_process(self, process_id, vpnservice, namespace):
def create_process(self, process_id: str, vpnservice,
namespace: str) -> ipsec.BaseSwanProcess:
return OvnOpenSwanProcess(
self.conf,
process_id,
@ -368,7 +387,8 @@ class OvnOpenSwanDriver(OvnIPsecDriver):
class OvnLibreSwanDriver(OvnIPsecDriver):
def create_process(self, process_id, vpnservice, namespace):
def create_process(self, process_id: str, vpnservice,
namespace: str) -> ipsec.BaseSwanProcess:
return OvnLibreSwanProcess(
self.conf,
process_id,

View File

@ -12,8 +12,8 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
import typing as ty
from oslo_config import cfg
from oslo_log import log as logging
@ -75,15 +75,14 @@ class StrongSwanProcess(ipsec.BaseSwanProcess):
STATUS_RE = r'([a-f0-9\-]+).* (ROUTED|CONNECTING|INSTALLED)'
STATUS_NOT_RUNNING_RE = 'Command:.*ipsec.*status.*Exit code: [1|3] '
def __init__(self, conf, process_id, vpnservice, namespace):
def __init__(self, conf, process_id: str, vpnservice, namespace: str):
self.DIALECT_MAP['v1'] = 'ikev1'
self.DIALECT_MAP['v2'] = 'ikev2'
self.DIALECT_MAP['sha256'] = 'sha256'
self._strongswan_piddir = self._get_strongswan_piddir()
self._rootwrap_cfg = self._get_rootwrap_config()
LOG.debug("strongswan piddir is '%s'", (self._strongswan_piddir))
super(StrongSwanProcess, self).__init__(conf, process_id,
vpnservice, namespace)
super().__init__(conf, process_id, vpnservice, namespace)
def _get_strongswan_piddir(self):
return utils.execute(
@ -103,7 +102,8 @@ class StrongSwanProcess(ipsec.BaseSwanProcess):
return connection_id, status
return None, None
def _execute(self, cmd, check_exit_code=True, extra_ok_codes=None):
def _execute(self, cmd: ty.List[str], check_exit_code: bool = True,
extra_ok_codes: ty.Optional[ty.List[int]] = None):
"""Execute command on namespace.
This execute is wrapped by namespace wrapper.
@ -113,11 +113,11 @@ class StrongSwanProcess(ipsec.BaseSwanProcess):
ns_wrapper = self.get_ns_wrapper()
return ip_wrapper.netns.execute(
[ns_wrapper,
'--mount_paths=/etc:%s/etc,%s:%s/var/run' % (
'--mount_paths=/etc:{0}/etc,{1}:{2}/var/run'.format(
self.config_dir, self._strongswan_piddir, self.config_dir),
('--rootwrap_config=%s' % self._rootwrap_cfg
if self._rootwrap_cfg else ''),
'--cmd=%s' % ','.join(cmd)],
'--rootwrap_config={0}'.format(
self._rootwrap_cfg if self._rootwrap_cfg else ""),
f'--cmd={",".join(cmd)}'],
check_exit_code=check_exit_code,
extra_ok_codes=extra_ok_codes)

View File

@ -14,6 +14,7 @@
from neutron.plugins.ml2.drivers.ovn.agent import neutron_agent
from neutron.plugins.ml2.drivers.ovn.mech_driver.ovsdb import ovsdb_monitor
from neutron.services.ovn_l3 import plugin
from neutron_lib.plugins import constants as plugin_constants
from neutron_lib.plugins import directory
@ -73,9 +74,10 @@ class ChassisVPNAgentWriteEvent(ovsdb_monitor.ChassisAgentEvent):
clear_down=True)
class OVNVPNAgentMonitor(object):
class OVNVPNAgentMonitor:
def watch_agent_events(self):
l3_plugin = directory.get_plugin(plugin_constants.L3)
l3_plugin: plugin.OVNL3RouterPlugin = \
directory.get_plugin(plugin_constants.L3)
sb_ovn = l3_plugin._sb_ovn
if sb_ovn:
idl = sb_ovn.ovsdb_connection.idl

View File

@ -51,7 +51,7 @@ OVS_OPTS = [
]
def register_opts(conf):
def register_opts(conf: cfg.ConfigOpts):
common_config.register_common_config_options()
agent_config.register_interface_driver_opts_helper(conf)
agent_config.register_interface_opts(conf)

View File

@ -13,24 +13,25 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import typing as ty
from neutron_lib.callbacks import events
from neutron_lib.callbacks import registry
from neutron_lib.callbacks import resources
from neutron_lib import context
from oslo_config import cfg
from oslo_utils import importutils
from neutron_vpnaas.api.rpc.agentnotifiers import vpn_rpc_agent_api as nfy_api
from neutron_vpnaas.db.vpn import vpn_agentschedulers_db as agent_db
from neutron_vpnaas.db.vpn.vpn_db import VPNPluginDb
from neutron_vpnaas.db.vpn import vpn_ext_gw_db
from neutron_vpnaas.services.vpn.common import constants
from neutron_vpnaas.services.vpn.ovn import agent_monitor
from neutron_vpnaas.services.vpn.plugin import VPNDriverPlugin
from neutron_vpnaas.services.vpn import plugin as vpn_plugin
from neutron_vpnaas.services.vpn.service_drivers import ovn_ipsec
class VPNOVNPlugin(VPNPluginDb,
vpn_ext_gw_db.VPNExtGWPlugin_db,
class VPNOVNPlugin(vpn_ext_gw_db.VPNExtGWPlugin_db,
agent_db.AZVPNAgentSchedulerDbMixin,
agent_monitor.OVNVPNAgentMonitor):
"""Implementation of the VPN Service Plugin.
@ -50,13 +51,14 @@ class VPNOVNPlugin(VPNPluginDb,
resources.PROCESS,
events.AFTER_INIT)
def check_router_in_use(self, context, router_id):
def check_router_in_use(self, context: context.Context, router_id):
pass
def post_fork_initialize(self, resource, event, trigger, payload=None):
self.watch_agent_events()
def vpn_router_agent_binding_changed(self, context, router_id, host):
def vpn_router_agent_binding_changed(self, context: context.Context,
router_id: str, host: str):
pass
supported_extension_aliases = ["vpnaas",
@ -66,11 +68,15 @@ class VPNOVNPlugin(VPNPluginDb,
path_prefix = "/vpn"
class VPNOVNDriverPlugin(VPNOVNPlugin, VPNDriverPlugin):
def vpn_router_agent_binding_changed(self, context, router_id, host):
class VPNOVNDriverPlugin(VPNOVNPlugin, vpn_plugin.VPNDriverPlugin):
def vpn_router_agent_binding_changed(self, context: context.Context,
router_id: str, host: str):
super().vpn_router_agent_binding_changed(context, router_id, host)
filters = {'router_id': [router_id]}
vpnservices = self.get_vpnservices(context, filters=filters)
for vpnservice in vpnservices:
driver = self._get_driver_for_vpnservice(context, vpnservice)
driver: ty.Optional[ovn_ipsec.BaseOvnIPsecVPNDriver] = \
self._get_driver_for_vpnservice( # type: ignore
context, vpnservice)
if driver:
driver.update_port_bindings(context, router_id, host)

View File

@ -1,4 +1,3 @@
# (c) Copyright 2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
@ -14,6 +13,9 @@
# License for the specific language governing permissions and limitations
# under the License.
import typing as ty
from neutron.db import flavors_db
from neutron.db import servicetype_db as st_db
from neutron.services import provider_configuration as pconf
from neutron.services import service_base
@ -26,6 +28,7 @@ from oslo_log import log as logging
from neutron_vpnaas.db.vpn import vpn_db
from neutron_vpnaas.extensions import vpn_flavors
from neutron_vpnaas.services.vpn import service_drivers
LOG = logging.getLogger(__name__)
@ -54,8 +57,11 @@ class VPNPlugin(vpn_db.VPNPluginDb):
class VPNDriverPlugin(VPNPlugin, vpn_db.VPNPluginRpcDbMixin):
"""VpnPlugin which supports VPN Service Drivers."""
#TODO(nati) handle ikepolicy and ipsecpolicy update usecase
drivers: ty.Dict[str, service_drivers.VpnDriver]
default_provider: ty.Optional[str]
def __init__(self):
super(VPNDriverPlugin, self).__init__()
super().__init__()
self.service_type_manager = st_db.ServiceTypeManager.get_instance()
add_provider_configuration(self.service_type_manager, constants.VPN)
# Load the service driver from neutron.conf.
@ -72,14 +78,16 @@ class VPNDriverPlugin(VPNPlugin, vpn_db.VPNPluginRpcDbMixin):
vpn_db.subscribe()
@property
def _flavors_plugin(self):
def _flavors_plugin(self) -> flavors_db.FlavorsDbMixin:
return directory.get_plugin(constants.FLAVORS)
def start_rpc_listeners(self):
servers = []
for driver_name, driver in self.drivers.items():
if hasattr(driver, 'start_rpc_listeners'):
servers.extend(driver.start_rpc_listeners())
start_rpc_listeners: ty.Optional[ty.Callable[..., ty.List]] = \
getattr(driver, 'start_rpc_listeners', None)
if start_rpc_listeners and callable(start_rpc_listeners):
servers.extend(start_rpc_listeners())
return servers
def _check_orphan_vpnservice_associations(self):
@ -124,7 +132,9 @@ class VPNDriverPlugin(VPNPlugin, vpn_db.VPNPluginRpcDbMixin):
context, constants.VPN,
self.default_provider, vpnservice_id)
def _get_provider_for_flavor(self, context, flavor_id):
def _get_provider_for_flavor(
self, context: ncontext.Context,
flavor_id: ty.Optional[str]) -> ty.Optional[str]:
if flavor_id:
if self._flavors_plugin is None:
raise vpn_flavors.FlavorsPluginNotLoaded()
@ -137,7 +147,7 @@ class VPNDriverPlugin(VPNPlugin, vpn_db.VPNPluginRpcDbMixin):
raise flav_exc.FlavorDisabled()
providers = self._flavors_plugin.get_flavor_next_provider(
context, fl_db['id'])
provider = providers[0].get('provider')
provider: ty.Optional[str] = providers[0].get('provider', None)
if provider not in self.drivers:
raise vpn_flavors.NoProviderFoundForFlavor(flavor_id=flavor_id)
else:
@ -147,84 +157,98 @@ class VPNDriverPlugin(VPNPlugin, vpn_db.VPNPluginRpcDbMixin):
LOG.debug("Selected provider %s", provider)
return provider
def _get_driver_for_vpnservice(self, context, vpnservice):
def _get_driver_for_vpnservice(self, context: ncontext.Context,
vpnservice) -> \
ty.Optional[service_drivers.VpnDriver]:
stm = self.service_type_manager
provider_names = stm.get_provider_names_by_resource_ids(
context, [vpnservice['id']])
provider_names: ty.Dict[str, str] = \
stm.get_provider_names_by_resource_ids(context, [vpnservice['id']])
provider = provider_names.get(vpnservice['id'])
return self.drivers[provider]
return self.drivers.get(provider) if provider else None
def _get_driver_for_ipsec_site_connection(self, context,
def _get_driver_for_ipsec_site_connection(self, context: ncontext.Context,
ipsec_site_connection):
# Only vpnservice_id is required as the vpnservice should be already
# associated with a provider after its creation.
vpnservice = {'id': ipsec_site_connection['vpnservice_id']}
return self._get_driver_for_vpnservice(context, vpnservice)
def create_ipsec_site_connection(self, context, ipsec_site_connection):
def create_ipsec_site_connection(self,
context: ncontext.Context,
ipsec_site_connection) -> ty.Optional[ty.Dict[ty.Any, ty.Any]]:
driver = self._get_driver_for_ipsec_site_connection(
context, ipsec_site_connection['ipsec_site_connection'])
if driver:
driver.validator.validate_ipsec_site_connection(
context,
ipsec_site_connection['ipsec_site_connection'])
ipsec_site_connection = super(
VPNDriverPlugin, self).create_ipsec_site_connection(
ipsec_site_connection = super().create_ipsec_site_connection(
context, ipsec_site_connection)
driver.create_ipsec_site_connection(context, ipsec_site_connection)
return ipsec_site_connection
return None
def delete_ipsec_site_connection(self, context, ipsec_conn_id):
def delete_ipsec_site_connection(self, context: ncontext.Context,
ipsec_site_conn_id: str):
ipsec_site_connection = self.get_ipsec_site_connection(
context, ipsec_conn_id)
super(VPNDriverPlugin, self).delete_ipsec_site_connection(
context, ipsec_conn_id)
context, ipsec_site_conn_id)
super().delete_ipsec_site_connection(
context, ipsec_site_conn_id)
driver = self._get_driver_for_ipsec_site_connection(
context, ipsec_site_connection)
if driver:
driver.delete_ipsec_site_connection(context, ipsec_site_connection)
def update_ipsec_site_connection(
self, context,
ipsec_conn_id, ipsec_site_connection):
self, context: ncontext.Context,
ipsec_site_conn_id: str,
ipsec_site_connection) -> ty.Optional[ty.Dict[ty.Any, ty.Any]]:
old_ipsec_site_connection = self.get_ipsec_site_connection(
context, ipsec_conn_id)
context, ipsec_site_conn_id)
driver = self._get_driver_for_ipsec_site_connection(
context, old_ipsec_site_connection)
if driver:
driver.validator.validate_ipsec_site_connection(
context,
ipsec_site_connection['ipsec_site_connection'])
ipsec_site_connection = super(
VPNDriverPlugin, self).update_ipsec_site_connection(
ipsec_site_connection = super().update_ipsec_site_connection(
context,
ipsec_conn_id,
ipsec_site_conn_id,
ipsec_site_connection)
driver.update_ipsec_site_connection(
context, old_ipsec_site_connection, ipsec_site_connection)
return ipsec_site_connection
return None
def create_vpnservice(self, context, vpnservice):
def create_vpnservice(self, context: ncontext.Context,
vpnservice: ty.Dict[str, ty.Dict[str, ty.Any]]) -> \
ty.Optional[ty.Dict[str, ty.Any]]:
provider = self._get_provider_for_flavor(
context, vpnservice['vpnservice'].get('flavor_id'))
vpnservice = super(
VPNDriverPlugin, self).create_vpnservice(context, vpnservice)
if provider:
vpnservice = super().create_vpnservice(context, vpnservice)
self.service_type_manager.add_resource_association(
context, constants.VPN, provider, vpnservice['id'])
driver = self.drivers[provider]
driver.create_vpnservice(context, vpnservice)
return vpnservice
return None
def update_vpnservice(self, context, vpnservice_id, vpnservice):
def update_vpnservice(self, context: ncontext.Context, vpnservice_id: str,
vpnservice) -> ty.Dict[str, ty.Any]:
old_vpn_service = self.get_vpnservice(context, vpnservice_id)
new_vpn_service = super(
VPNDriverPlugin, self).update_vpnservice(context, vpnservice_id,
new_vpn_service = super().update_vpnservice(context, vpnservice_id,
vpnservice)
driver = self._get_driver_for_vpnservice(context, old_vpn_service)
if driver:
driver.update_vpnservice(context, old_vpn_service, new_vpn_service)
return new_vpn_service
def delete_vpnservice(self, context, vpnservice_id):
def delete_vpnservice(self, context: ncontext.Context, vpnservice_id: str):
vpnservice = self._get_vpnservice(context, vpnservice_id)
super(VPNDriverPlugin, self).delete_vpnservice(context, vpnservice_id)
super().delete_vpnservice(context, vpnservice_id)
driver = self._get_driver_for_vpnservice(context, vpnservice)
if driver:
self.service_type_manager.del_resource_associations(
context, [vpnservice_id])
driver.delete_vpnservice(context, vpnservice)

View File

@ -14,21 +14,25 @@
# under the License.
import abc
import typing as ty
from neutron.extensions import l3
from neutron_lib import context
from neutron_lib.plugins import constants
from neutron_lib.plugins import directory
from neutron_lib import rpc as n_rpc
from oslo_log import log as logging
import oslo_messaging
from neutron_vpnaas.db.vpn import vpn_db
from neutron_vpnaas.services.vpn.service_drivers import driver_validator
LOG = logging.getLogger(__name__)
class VpnDriver(object, metaclass=abc.ABCMeta):
class VpnDriver(metaclass=abc.ABCMeta):
def __init__(self, service_plugin, validator=None):
def __init__(self, service_plugin: vpn_db.VPNPluginDb, validator=None):
self.service_plugin = service_plugin
if validator is None:
validator = driver_validator.VpnDriverValidator(self)
@ -36,7 +40,7 @@ class VpnDriver(object, metaclass=abc.ABCMeta):
self.name = ''
@property
def l3_plugin(self):
def l3_plugin(self) -> l3.RouterPluginBase:
return directory.get_plugin(constants.L3)
@property
@ -44,43 +48,49 @@ class VpnDriver(object, metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
def create_vpnservice(self, context, vpnservice):
def create_vpnservice(self, context: context.ContextBase, vpnservice):
pass
@abc.abstractmethod
def update_vpnservice(
self, context, old_vpnservice, vpnservice):
self, context: context.ContextBase, old_vpnservice, vpnservice):
pass
@abc.abstractmethod
def delete_vpnservice(self, context, vpnservice):
def delete_vpnservice(self, context: context.ContextBase, vpnservice):
pass
@abc.abstractmethod
def create_ipsec_site_connection(self, context, ipsec_site_connection):
pass
@abc.abstractmethod
def update_ipsec_site_connection(self, context, old_ipsec_site_connection,
def create_ipsec_site_connection(self, context: context.ContextBase,
ipsec_site_connection):
pass
@abc.abstractmethod
def delete_ipsec_site_connection(self, context, ipsec_site_connection):
def update_ipsec_site_connection(self, context: context.ContextBase,
old_ipsec_site_connection,
ipsec_site_connection):
pass
@abc.abstractmethod
def delete_ipsec_site_connection(self, context: context.ContextBase,
ipsec_site_connection):
pass
class BaseIPsecVpnAgentApi(object):
class BaseIPsecVpnAgentApi:
"""Base class for IPSec API to agent."""
def __init__(self, topic, default_version, driver):
def __init__(self, topic: str, default_version: str,
driver: VpnDriver):
self.topic = topic
self.driver = driver
target = oslo_messaging.Target(topic=topic, version=default_version)
self.client = n_rpc.get_client(target)
self.target = oslo_messaging.Target(topic=topic,
version=default_version)
self.client = n_rpc.get_client(self.target)
def _agent_notification(self, context, method, router_id,
version=None, **kwargs):
def _agent_notification(self, context: context.ContextBase, method,
router_id: str,
version: ty.Optional[str] = None, **kwargs):
"""Notify update for the agent.
This method will find where is the router, and
@ -103,7 +113,8 @@ class BaseIPsecVpnAgentApi(object):
cctxt = self.client.prepare(server=l3_agent.host, version=version)
cctxt.cast(context, method, **kwargs)
def vpnservice_updated(self, context, router_id, **kwargs):
def vpnservice_updated(self, context: context.ContextBase, router_id: str,
**kwargs):
"""Send update event of vpnservices."""
kwargs['router'] = {'id': router_id}
self._agent_notification(context, 'vpnservice_updated', router_id,

View File

@ -12,17 +12,23 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import typing as ty
import abc
import netaddr
import oslo_messaging
from neutron.db import agents_db
from neutron.db.models import l3agent
from neutron.db.models import servicetype
from neutron.objects import agent as nagent
from neutron_lib import constants as lib_constants
from neutron_lib import context
from neutron_lib.db import api as db_api
from neutron_lib.plugins import directory
import oslo_messaging
from neutron_vpnaas.db.vpn import vpn_db
from neutron_vpnaas.db.vpn import vpn_models
from neutron_vpnaas.services.vpn import service_drivers
@ -31,7 +37,7 @@ IPSEC = 'ipsec'
BASE_IPSEC_VERSION = '1.0'
class IPsecVpnDriverCallBack(object):
class IPsecVpnDriverCallBack:
"""Callback for IPSecVpnDriver rpc."""
# history
@ -40,12 +46,13 @@ class IPsecVpnDriverCallBack(object):
target = oslo_messaging.Target(version=BASE_IPSEC_VERSION)
def __init__(self, driver):
super(IPsecVpnDriverCallBack, self).__init__()
super().__init__()
self.driver = driver
def _get_agent_hosting_vpn_services(self, context, host):
plugin = directory.get_plugin()
agent = plugin._get_agent_by_type_and_host(
def _get_agent_hosting_vpn_services(self, context: context.Context,
host: ty.Optional[str]):
plugin: agents_db.AgentDbMixin = directory.get_plugin()
agent: ty.Optional[nagent.Agent] = plugin._get_agent_by_type_and_host(
context, lib_constants.AGENT_TYPE_L3, host)
agent_conf = plugin.get_configuration_dict(agent)
# Retrieve the agent_mode to check if this is the
@ -53,7 +60,9 @@ class IPsecVpnDriverCallBack(object):
# case of distributed the vpn service should reside
# only on a dvr_snat node.
agent_mode = agent_conf.get('agent_mode', 'legacy')
if not agent.admin_state_up or agent_mode == 'dvr':
if (not agent and
not agent.admin_state_up or # type: ignore
agent_mode == 'dvr'):
return []
query = context.session.query(vpn_models.VPNService)
query = query.join(vpn_models.IPsecSiteConnection)
@ -65,25 +74,27 @@ class IPsecVpnDriverCallBack(object):
servicetype.ProviderResourceAssociation.resource_id ==
vpn_models.VPNService.id)
query = query.filter(
l3agent.RouterL3AgentBinding.l3_agent_id == agent.id)
l3agent.RouterL3AgentBinding.l3_agent_id ==
agent.id) # type: ignore
query = query.filter(
servicetype.ProviderResourceAssociation.provider_name ==
self.driver.name)
return query
@db_api.CONTEXT_READER
def get_vpn_services_on_host(self, context, host=None):
def get_vpn_services_on_host(self, context: context.Context,
host: ty.Optional[str] = None):
"""Returns the vpnservices on the host."""
vpnservices = self._get_agent_hosting_vpn_services(
context, host)
plugin = self.driver.service_plugin
plugin: vpn_db.VPNPluginRpcDbMixin = self.driver.service_plugin
local_cidr_map = plugin._build_local_subnet_cidr_map(context)
return [self.driver.make_vpnservice_dict(vpnservice, local_cidr_map)
for vpnservice in vpnservices]
def update_status(self, context, status):
def update_status(self, context: context.Context, status):
"""Update status of vpnservices."""
plugin = self.driver.service_plugin
plugin: vpn_db.VPNPluginRpcDbMixin = self.driver.service_plugin
plugin.update_status_by_agent(context, status)
@ -94,57 +105,63 @@ class IPsecVpnAgentApi(service_drivers.BaseIPsecVpnAgentApi):
# pylint: disable=useless-super-delegation
def __init__(self, topic, default_version, driver):
super(IPsecVpnAgentApi, self).__init__(
topic, default_version, driver)
super().__init__(topic, default_version, driver)
class BaseIPsecVPNDriver(service_drivers.VpnDriver, metaclass=abc.ABCMeta):
"""Base VPN Service Driver class."""
def __init__(self, service_plugin, validator=None):
super(BaseIPsecVPNDriver, self).__init__(service_plugin, validator)
agent_rpc: service_drivers.BaseIPsecVpnAgentApi
def __init__(self, service_plugin: vpn_db.VPNPluginDb, validator=None):
super().__init__(service_plugin, validator)
self.create_rpc_conn()
@property
def service_type(self):
def service_type(self) -> str:
return IPSEC
@abc.abstractmethod
def create_rpc_conn(self):
def create_rpc_conn(self) -> None:
pass
def create_ipsec_site_connection(self, context, ipsec_site_connection):
def create_ipsec_site_connection(self, context: context.Context,
ipsec_site_connection):
router_id = self.service_plugin.get_vpnservice_router_id(
context, ipsec_site_connection['vpnservice_id'])
self.agent_rpc.vpnservice_updated(context, router_id)
def update_ipsec_site_connection(
self, context, old_ipsec_site_connection, ipsec_site_connection):
def update_ipsec_site_connection(self, context: context.Context,
old_ipsec_site_connection,
ipsec_site_connection):
router_id = self.service_plugin.get_vpnservice_router_id(
context, ipsec_site_connection['vpnservice_id'])
self.agent_rpc.vpnservice_updated(context, router_id)
def delete_ipsec_site_connection(self, context, ipsec_site_connection):
def delete_ipsec_site_connection(self, context: context.Context,
ipsec_site_connection):
router_id = self.service_plugin.get_vpnservice_router_id(
context, ipsec_site_connection['vpnservice_id'])
self.agent_rpc.vpnservice_updated(context, router_id)
def create_ikepolicy(self, context, ikepolicy):
def create_ikepolicy(self, context: context.Context, ikepolicy):
pass
def delete_ikepolicy(self, context, ikepolicy):
def delete_ikepolicy(self, context: context.Context, ikepolicy):
pass
def update_ikepolicy(self, context, old_ikepolicy, ikepolicy):
def update_ikepolicy(self, context: context.Context,
old_ikepolicy, ikepolicy):
pass
def create_ipsecpolicy(self, context, ipsecpolicy):
def create_ipsecpolicy(self, context: context.Context, ipsecpolicy):
pass
def delete_ipsecpolicy(self, context, ipsecpolicy):
def delete_ipsecpolicy(self, context: context.Context, ipsecpolicy):
pass
def update_ipsecpolicy(self, context, old_ipsec_policy, ipsecpolicy):
def update_ipsecpolicy(self, context: context.Context,
old_ipsec_policy, ipsecpolicy):
pass
def _get_gateway_ips(self, router):
@ -164,7 +181,7 @@ class BaseIPsecVPNDriver(service_drivers.VpnDriver, metaclass=abc.ABCMeta):
return v4_ip, v6_ip
@db_api.CONTEXT_WRITER
def create_vpnservice(self, context, vpnservice_dict):
def create_vpnservice(self, context: context.Context, vpnservice_dict):
"""Get the gateway IP(s) and save for later use.
For the reference implementation, this side's tunnel IP (external_ip)
@ -181,10 +198,11 @@ class BaseIPsecVPNDriver(service_drivers.VpnDriver, metaclass=abc.ABCMeta):
vpnservice_dict['id'],
v4_ip=v4_ip, v6_ip=v6_ip)
def update_vpnservice(self, context, old_vpnservice, vpnservice):
def update_vpnservice(self, context: context.Context,
old_vpnservice, vpnservice):
self.agent_rpc.vpnservice_updated(context, vpnservice['router_id'])
def delete_vpnservice(self, context, vpnservice):
def delete_vpnservice(self, context: context.Context, vpnservice):
self.agent_rpc.vpnservice_updated(context, vpnservice['router_id'])
def get_external_ip_based_on_peer(self, vpnservice, ipsec_site_con):
@ -204,7 +222,7 @@ class BaseIPsecVPNDriver(service_drivers.VpnDriver, metaclass=abc.ABCMeta):
also converting parameter name for vpn agent driver
"""
vpnservice_dict = dict(vpnservice)
vpnservice_dict: ty.Dict[str, ty.Any] = dict(vpnservice)
# Populate tenant_id for RPC compat
vpnservice_dict['tenant_id'] = vpnservice_dict['project_id']
vpnservice_dict['ipsec_site_connections'] = []

View File

@ -12,18 +12,25 @@
# License for the specific language governing permissions and limitations
# under the License.
#
import typing as ty
from neutron.extensions import l3
from neutron_lib import context
if ty.TYPE_CHECKING:
from neutron_vpnaas.services.vpn import service_drivers
class VpnDriverValidator(object):
class VpnDriverValidator:
"""Driver-specific validation routines for VPN resources."""
def __init__(self, driver):
def __init__(self, driver: 'service_drivers.VpnDriver'):
self.driver = driver
@property
def l3_plugin(self):
def l3_plugin(self) -> l3.RouterPluginBase:
return self.driver.l3_plugin
def validate_ipsec_site_connection(self, context, ipsec_sitecon):
def validate_ipsec_site_connection(self, context: context.ContextBase,
ipsec_sitecon):
"""Driver can override this for its additional validations."""
pass

View File

@ -15,6 +15,7 @@
from neutron_lib import rpc as n_rpc
from neutron_vpnaas.db.vpn import vpn_db
from neutron_vpnaas.services.vpn.common import topics
from neutron_vpnaas.services.vpn.service_drivers import base_ipsec
from neutron_vpnaas.services.vpn.service_drivers import ipsec_validator
@ -27,9 +28,8 @@ BASE_IPSEC_VERSION = '1.0'
class IPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver):
"""VPN Service Driver class for IPsec."""
def __init__(self, service_plugin):
super(IPsecVPNDriver, self).__init__(
service_plugin,
def __init__(self, service_plugin: vpn_db.VPNPluginDb):
super().__init__(service_plugin,
ipsec_validator.IpsecVpnValidator(self))
def create_rpc_conn(self):

View File

@ -12,6 +12,9 @@
# License for the specific language governing permissions and limitations
# under the License.
import typing as ty
from neutron_lib import context
from neutron_lib import exceptions as nexception
from neutron_vpnaas._i18n import _
@ -29,7 +32,8 @@ class IpsecVpnValidator(driver_validator.VpnDriverValidator):
and Libreswan.
"""
def _check_transform_protocol(self, context, transform_protocol):
def _check_transform_protocol(self, context: context.ContextBase,
transform_protocol: ty.Optional[str]):
"""Restrict selecting ah-esp as IPSec Policy transform protocol.
For those *Swan implementations, the 'ah-esp' transform protocol
@ -41,12 +45,15 @@ class IpsecVpnValidator(driver_validator.VpnDriverValidator):
key='transform_protocol',
value=transform_protocol)
def validate_ipsec_policy(self, context, ipsec_policy):
transform_protocol = ipsec_policy.get('transform_protocol')
def validate_ipsec_policy(self, context: context.ContextBase,
ipsec_policy: ty.Dict[str, ty.Union[ty.Any, str]]):
transform_protocol: ty.Optional[str] = \
ipsec_policy.get('transform_protocol', None)
self._check_transform_protocol(context, transform_protocol)
def validate_ipsec_site_connection(self, context, ipsec_sitecon):
if 'ipsecpolicy_id' in ipsec_sitecon:
def validate_ipsec_site_connection(self, context: context.ContextBase,
ipsec_sitecon):
if "ipsecpolicy_id" in ipsec_sitecon:
ipsec_policy = self.driver.service_plugin.get_ipsecpolicy(
context, ipsec_sitecon['ipsecpolicy_id'])
self.validate_ipsec_policy(context, ipsec_policy)

View File

@ -14,8 +14,14 @@
# License for the specific language governing permissions and limitations
# under the License.
import typing as ty
import abc
import netaddr
from neutron.db import extraroute_db
from neutron.plugins.ml2 import plugin
from neutron_lib.api.definitions import portbindings
from neutron_lib.callbacks import events
from neutron_lib.callbacks import registry
@ -33,14 +39,22 @@ from oslo_config import cfg
from oslo_db import exception as o_exc
from oslo_log import log as logging
from neutron_vpnaas.db.vpn import vpn_agentschedulers_db as agent_db
from neutron_vpnaas.db.vpn.vpn_ext_gw_db import RouterIsNotVPNExternal
from neutron_vpnaas.db.vpn import vpn_ext_gw_db as ext_gw
from neutron_vpnaas.db.vpn import vpn_models
from neutron_vpnaas.extensions import vpnaas
from neutron_vpnaas.services.vpn.common import constants as v_constants
from neutron_vpnaas.services.vpn.common import topics
from neutron_vpnaas.services.vpn import ovn_plugin
from neutron_vpnaas.services.vpn.service_drivers import base_ipsec
#pylint: disable=ungrouped-imports
# Additional import for typechecking. Importing these without typechecking
# would resolve in a cyclic dependency
if ty.TYPE_CHECKING:
from neutron.db import db_base_plugin_v2 as db_plugin
#pylint: enable=ungrouped-imports
LOG = logging.getLogger(__name__)
@ -63,17 +77,18 @@ class IPsecVpnOvnDriverCallBack(base_ipsec.IPsecVpnDriverCallBack):
self.admin_ctx = nctx.get_admin_context()
@property
def core_plugin(self):
def core_plugin(self) -> 'db_plugin.NeutronDbPluginV2':
return self.driver.core_plugin
@property
def service_plugin(self):
def service_plugin(self) -> ext_gw.VPNExtGWPlugin_db:
return self.driver.service_plugin
def _get_vpn_gateway(self, context, router_id):
def _get_vpn_gateway(self, context: nctx.ContextBase, router_id: str):
return self.service_plugin.get_vpn_gw_by_router_id(context, router_id)
def get_vpn_transit_network_details(self, context, router_id):
def get_vpn_transit_network_details(self, context: nctx.ContextBase,
router_id: str):
vpn_gw = self._get_vpn_gateway(context, router_id)
network_id = vpn_gw.gw_port['network_id']
external_network = self.core_plugin.get_network(context, network_id)
@ -86,13 +101,14 @@ class IPsecVpnOvnDriverCallBack(base_ipsec.IPsecVpnDriverCallBack):
}
return details
def get_subnet_info(self, context, subnet_id=None):
def get_subnet_info(self, context: nctx.ContextBase,
subnet_id: ty.Optional[str] = None):
try:
return self.core_plugin.get_subnet(context, subnet_id)
except n_exc.SubnetNotFound:
return None
def _get_agent_hosting_vpn_services(self, context, host):
def _get_agent_hosting_vpn_services(self, context: nctx.Context, host):
agent = self.service_plugin.get_vpn_agent_on_host(context, host)
if not agent:
return []
@ -114,20 +130,23 @@ class IPsecVpnOvnDriverCallBack(base_ipsec.IPsecVpnDriverCallBack):
@registry.has_registry_receivers
class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver):
def __init__(self, service_plugin):
self._l3_plugin = None
self._core_plugin = None
class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver,
metaclass=abc.ABCMeta):
def __init__(self, service_plugin: ovn_plugin.VPNOVNPlugin):
self._l3_plugin: \
ty.Optional[extraroute_db.ExtraRoute_dbonly_mixin] = None
self._core_plugin: ty.Optional[plugin.Ml2Plugin] = None
self.service_plugin = service_plugin
super().__init__(service_plugin)
@property
def l3_plugin(self):
def l3_plugin(self) -> extraroute_db.ExtraRoute_dbonly_mixin:
if self._l3_plugin is None:
self._l3_plugin = directory.get_plugin(plugin_constants.L3)
return self._l3_plugin
@property
def core_plugin(self):
def core_plugin(self) -> plugin.Ml2Plugin:
if self._core_plugin is None:
self._core_plugin = directory.get_plugin()
return self._core_plugin
@ -155,20 +174,25 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver):
raise vpnaas.RouteInUseByVPN(
destinations=", ".join(conflict_cidrs))
def get_vpn_gw_port_name(self, router_id):
@abc.abstractmethod
def create_rpc_conn(self) -> None:
pass
def get_vpn_gw_port_name(self, router_id: str) -> str:
return VPN_GW_PORT_PREFIX + router_id
def get_vpn_namespace_port_name(self, router_id):
def get_vpn_namespace_port_name(self, router_id: str) -> str:
return TRANSIT_PORT_PREFIX + router_id
def get_transit_network_name(self, router_id):
def get_transit_network_name(self, router_id: str) -> str:
return TRANSIT_NETWORK_PREFIX + router_id
def get_transit_subnet_name(self, router_id):
def get_transit_subnet_name(self, router_id: str) -> str:
return TRANSIT_SUBNET_PREFIX + router_id
def make_transit_network(self, router_id, tenant_id, agent_host,
gateway_update):
def make_transit_network(self, router_id: str, tenant_id: str,
agent_host: str,
gateway_update: ty.Dict[str, ty.Any]):
context = nctx.get_admin_context()
network_data = {
'tenant_id': HIDDEN_PROJECT_ID,
@ -213,13 +237,14 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver):
{"port": port_data})
gateway_update['transit_port_id'] = port['id']
def _del_port(self, context, port_id):
def _del_port(self, context: nctx.ContextBase, port_id: str):
try:
self.core_plugin.delete_port(context, port_id, l3_port_check=False)
except n_exc.PortNotFound:
pass
def _remove_router_interface(self, context, router_id, subnet_id):
def _remove_router_interface(self, context: nctx.ContextBase,
router_id: str, subnet_id: str):
try:
self.l3_plugin.remove_router_interface(
context, router_id, {'subnet_id': subnet_id})
@ -227,27 +252,27 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver):
n_exc.SubnetNotFound):
pass
def _del_subnet(self, context, subnet_id):
def _del_subnet(self, context: nctx.ContextBase, subnet_id: str):
try:
self.core_plugin.delete_subnet(context, subnet_id)
except n_exc.SubnetNotFound:
pass
def _del_network(self, context, network_id):
def _del_network(self, context: nctx.ContextBase, network_id: str):
try:
self.core_plugin.delete_network(context, network_id)
except n_exc.NetworkNotFound:
pass
def del_transit_network(self, gw):
def del_transit_network(self, gw: ext_gw.VPNExtGW):
context = nctx.get_admin_context()
router_id = gw['router_id']
router_id: str = gw['router_id']
port_id = gw.get('transit_port_id')
port_id: ty.Optional[str] = gw.get('transit_port_id')
if port_id:
self._del_port(context, port_id)
subnet_id = gw.get('transit_subnet_id')
subnet_id: ty.Optional[str] = gw.get('transit_subnet_id')
if subnet_id:
self._remove_router_interface(context, router_id, subnet_id)
self._del_subnet(context, subnet_id)
@ -256,7 +281,8 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver):
if network_id:
self._del_network(context, network_id)
def make_gw_port(self, router_id, network_id, agent_host, gateway_update):
def make_gw_port(self, router_id: str, network_id: str,
agent_host: str, gateway_update: ty.Dict[str, ty.Any]):
context = nctx.get_admin_context()
port_data = {'tenant_id': HIDDEN_PROJECT_ID,
'network_id': network_id,
@ -273,7 +299,7 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver):
LOG.debug('No IPs available for external network %s', network_id)
gateway_update['gw_port_id'] = gw_port['id']
def del_gw_port(self, gateway):
def del_gw_port(self, gateway: ext_gw.VPNExtGW):
context = nctx.get_admin_context()
port_id = gateway.get('gw_port_id')
if port_id:
@ -290,24 +316,26 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver):
cidrs.append(ep.endpoint)
return cidrs
def _routes_update(self, cidrs, nexthop):
def _routes_update(self, cidrs: ty.Set, nexthop):
routes = [{'destination': cidr, 'nexthop': nexthop}
for cidr in cidrs]
return {'router': {'routes': routes}}
def _update_static_routes(self, context, ipsec_site_connection):
def _update_static_routes(self, context: nctx.ContextBase,
ipsec_site_connection):
vpnservice = self.service_plugin.get_vpnservice(
context, ipsec_site_connection['vpnservice_id'])
router_id = vpnservice['router_id']
gw = self.service_plugin.get_vpn_gw_by_router_id(context, router_id)
gw: ext_gw.VPNExtGW = self.service_plugin.get_vpn_gw_by_router_id(
context, router_id)
nexthop = gw.transit_port['fixed_ips'][0]['ip_address']
router = self.l3_plugin.get_router(context, router_id)
old_routes = router.get('routes', [])
old_cidrs = set([r['destination'] for r in old_routes
if r['nexthop'] == nexthop])
old_cidrs = {r['destination'] for r in old_routes
if r['nexthop'] == nexthop}
new_cidrs = set(
self.service_plugin.get_peer_cidrs_for_router(context, router_id))
@ -330,7 +358,7 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver):
nctx.get_admin_context(),
router['id'])
if gateway is None or gateway['external_fixed_ips'] is None:
raise RouterIsNotVPNExternal(router_id=router['id'])
raise ext_gw.RouterIsNotVPNExternal(router_id=router['id'])
v4_ip = v6_ip = None
for fixed_ip in gateway['external_fixed_ips']:
@ -343,12 +371,14 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver):
v6_ip = addr
return v4_ip, v6_ip
def _update_gateway(self, context, gateway_id, **kwargs):
def _update_gateway(self, context: nctx.Context,
gateway_id: str, **kwargs):
gateway = {'gateway': kwargs}
return self.service_plugin.update_gateway(context, gateway_id, gateway)
@db_api.retry_if_session_inactive()
def _ensure_gateway(self, context, vpnservice):
def _ensure_gateway(self, context: nctx.Context, vpnservice) -> \
ty.Dict[str, ty.Any]:
gw = self.service_plugin.get_vpn_gw_dict_by_router_id(
context, vpnservice['router_id'], refresh=True)
if not gw:
@ -371,23 +401,26 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver):
return gw
@db_api.CONTEXT_WRITER
def _setup(self, context, vpnservice_dict):
def _setup(self, context: nctx.Context,
vpnservice_dict: ty.Dict[str, ty.Any]):
router_id = vpnservice_dict['router_id']
agent = self.service_plugin.schedule_router(context, router_id)
if not agent:
raise vpnaas.NoVPNAgentAvailable
agent_host = agent['host']
gateway = self._ensure_gateway(context, vpnservice_dict)
gateway: ty.Optional[ty.Dict[str, ty.Any]] = self._ensure_gateway(
context, vpnservice_dict)
# If the gateway status is ACTIVE the ports have been created already
if gateway['status'] == lib_constants.ACTIVE:
if gateway and gateway['status'] == lib_constants.ACTIVE:
return
vpnservice = self.service_plugin._get_vpnservice(context,
vpnservice_dict['id'])
vpnservice = self.service_plugin._get_vpnservice(
context, vpnservice_dict["id"])
network_id = vpnservice.router.gw_port.network_id
gateway_update = {} # keeps track of already-created IDs
# keeps track of already-created IDs
gateway_update: ty.Dict[str, ty.Any] = {}
try:
self.make_gw_port(router_id, network_id, agent_host,
gateway_update)
@ -396,16 +429,16 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver):
agent_host,
gateway_update)
except Exception:
self._update_gateway(context, gateway['id'],
self._update_gateway(context, gateway['id'], # type: ignore
status=lib_constants.ERROR,
**gateway_update)
raise
self._update_gateway(context, gateway['id'],
self._update_gateway(context, gateway['id'], # type: ignore
status=lib_constants.ACTIVE,
**gateway_update)
def _cleanup(self, context, router_id):
def _cleanup(self, context: nctx.Context, router_id: str):
gw = self.service_plugin.get_vpn_gw_dict_by_router_id(context,
router_id)
if not gw:
@ -423,7 +456,7 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver):
status=lib_constants.ERROR)
raise
def create_vpnservice(self, context, vpnservice_dict):
def create_vpnservice(self, context: nctx.Context, vpnservice_dict):
try:
self._setup(context, vpnservice_dict)
except Exception:
@ -435,7 +468,7 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver):
raise
super().create_vpnservice(context, vpnservice_dict)
def delete_vpnservice(self, context, vpnservice):
def delete_vpnservice(self, context: nctx.Context, vpnservice):
router_id = vpnservice['router_id']
super().delete_vpnservice(context, vpnservice)
services = self.service_plugin.get_vpnservices(context)
@ -443,11 +476,13 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver):
if router_id not in router_ids:
self._cleanup(context, router_id)
def create_ipsec_site_connection(self, context, ipsec_site_connection):
def create_ipsec_site_connection(self, context: nctx.Context,
ipsec_site_connection):
self._update_static_routes(context, ipsec_site_connection)
super().create_ipsec_site_connection(context, ipsec_site_connection)
def delete_ipsec_site_connection(self, context, ipsec_site_connection):
def delete_ipsec_site_connection(self, context: nctx.Context,
ipsec_site_connection):
self._update_static_routes(context, ipsec_site_connection)
super().delete_ipsec_site_connection(context, ipsec_site_connection)
@ -457,11 +492,11 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver):
super().update_ipsec_site_connection(
context, old_ipsec_site_connection, ipsec_site_connection)
def _update_port_binding(self, context, port_id, host):
def _update_port_binding(self, context: nctx.Context, port_id, host):
port_data = {'binding:host_id': host}
self.core_plugin.update_port(context, port_id, {'port': port_data})
def update_port_bindings(self, context, router_id, host):
def update_port_bindings(self, context: nctx.Context, router_id, host):
gw = self.service_plugin.get_vpn_gw_dict_by_router_id(context,
router_id)
if not gw:
@ -475,7 +510,7 @@ class BaseOvnIPsecVPNDriver(base_ipsec.BaseIPsecVPNDriver):
class IPsecOvnVpnAgentApi(base_ipsec.IPsecVpnAgentApi):
def _agent_notification(self, context, method, router_id,
def _agent_notification(self, context: nctx.Context, method, router_id,
version=None, **kwargs):
"""Notify update for the agent.
@ -508,7 +543,7 @@ class IPsecOvnVPNDriver(BaseOvnIPsecVPNDriver):
self.agent_rpc = IPsecOvnVpnAgentApi(
topics.IPSEC_AGENT_TOPIC, BASE_IPSEC_VERSION, self)
def start_rpc_listeners(self):
def start_rpc_listeners(self) -> ty.List:
self.endpoints = [IPsecVpnOvnDriverCallBack(self)]
self.conn = n_rpc.Connection()
self.conn.create_consumer(

View File

@ -13,25 +13,30 @@
# License for the specific language governing permissions and limitations
# under the License.
import typing as ty
from neutron.services import provider_configuration as provconfig
from neutron_lib.exceptions import vpn as vpn_exception
from oslo_config import cfg
from oslo_log import log as logging
from oslo_utils import importutils
from neutron_vpnaas.services.vpn.device_drivers import ipsec
LOG = logging.getLogger(__name__)
DEVICE_DRIVERS = 'device_drivers'
class VPNService(object):
class VPNService:
"""VPN Service observer."""
def __init__(self, l3_agent):
self.conf = l3_agent.conf
self.conf: cfg.ConfigOpts = l3_agent.conf
def load_device_drivers(self, host):
def load_device_drivers(self, host) -> ty.List[ipsec.IPsecDriver]:
"""Loads one or more device drivers for VPNaaS."""
drivers = []
drivers: ty.List[ipsec.IPsecDriver] = []
for device_driver in self.conf.vpnagent.vpn_device_driver:
device_driver = provconfig.get_provider_driver_class(
device_driver, DEVICE_DRIVERS)

View File

@ -178,7 +178,7 @@ def get_ovs_bridge(br_name):
Vm = collections.namedtuple('Vm', ['namespace', 'port_ip'])
class SiteInfo(object):
class SiteInfo:
"""Holds info on the router, ports, service, and connection."""

View File

@ -44,7 +44,7 @@ VPN_HOSTA = "host-1"
VPN_HOSTB = "host-2"
class VPNAgentSchedulerTestMixIn(object):
class VPNAgentSchedulerTestMixIn:
def _request_list(self, path, admin_context=True,
expected_code=exc.HTTPOk.code):
req = self._path_req(path, admin_context=admin_context)

View File

@ -69,7 +69,7 @@ class TestVpnCorePlugin(test_l3_plugin.TestL3NatIntPlugin,
self.router_scheduler = l3_agent_scheduler.ChanceScheduler()
class VPNTestMixin(object):
class VPNTestMixin:
resource_prefix_map = dict(
(k.replace('_', '-'),
"/vpn")
@ -1718,7 +1718,7 @@ class TestVpnaas(VPNPluginDbTestCase):
# tests.
# TODO(pcm): Put helpers in another module for sharing
class NeutronResourcesMixin(object):
class NeutronResourcesMixin:
def create_network(self, overrides=None):
"""Create database entry for network."""

View File

@ -57,7 +57,7 @@ class FakeSqlQueryObject(dict):
super(FakeSqlQueryObject, self).__init__(**entries)
class FakeGatewayDB(object):
class FakeGatewayDB:
def __init__(self):
self.gateways_by_router = {}
self.gateways_by_id = {}

View File

@ -51,3 +51,7 @@ oslo.policy.policies =
neutron-vpnaas = neutron_vpnaas.policies:list_rules
neutron.policies =
neutron-vpnaas = neutron_vpnaas.policies:list_rules
[mypy]
files = neutron_vpnaas/*.py,neutron_vpnaas/agent/**/*.py,neutron_vpnaas/api/**/*.py,neutron_vpnaas/cmd/**/*.py,neutron_vpnaas/db/**/*.py,neutron_vpnaas/extensions/**/*.py,neutron_vpnaas/policies/**/*.py,neutron_vpnaas/scheduler/**/*.py,neutron_vpnaas/services/**/*.py
ignore_missing_imports = true

View File

@ -11,3 +11,4 @@ stestr>=1.0.0 # Apache-2.0
# see https://review.opendev.org/c/openstack/neutron/+/848706
WebTest>=2.0.27 # MIT
mypy>=1.7.0 # MIT

View File

@ -35,7 +35,7 @@ class ASTWalker(compiler.visitor.ASTVisitor):
compiler.visitor.ASTVisitor.default(self, node, *args)
class Visitor(object):
class Visitor:
def __init__(self, filename, i18n_msg_predicates,
msg_format_checkers, debug):

14
tox.ini
View File

@ -1,14 +1,16 @@
[tox]
minversion = 4.0.0
ignore_basepython_conflict=true
requires = virtualenv>=20.17.1
envlist = py39,py38,pep8,docs
minversion = 3.18.0
[testenv]
usedevelop = True
setenv = VIRTUAL_ENV={envdir}
OS_LOG_CAPTURE={env:OS_LOG_CAPTURE:true}
OS_STDOUT_CAPTURE={env:OS_STDOUT_CAPTURE:true}
OS_STDERR_CAPTURE={env:OS_STDERR_CAPTURE:true}
PYTHONWARNINGS=default::DeprecationWarning
usedevelop = True
deps = -c{env:TOX_CONSTRAINTS_FILE:https://opendev.org/openstack/requirements/raw/branch/master/upper-constraints.txt}
-r{toxinidir}/requirements.txt
-r{toxinidir}/test-requirements.txt
@ -83,6 +85,7 @@ commands =
neutron-db-manage --subproject neutron-vpnaas --database-connection sqlite:// check_migration
{[testenv:genconfig]commands}
{[testenv:genpolicy]commands}
{[testenv:mypy]commands}
allowlist_externals =
bash
@ -153,3 +156,10 @@ commands = bash {toxinidir}/tools/generate_config_file_samples.sh
[testenv:genpolicy]
commands = oslopolicy-sample-generator --config-file=etc/oslo-policy-generator/policy.conf
[testenv:mypy]
description =
Run type checks.
deps = {[testenv]deps}
commands =
mypy --install-types --non-interactive --check-untyped-defs