Files
test/framework/rest/ssh_tunnel_rest_client.py
croy f811ace75a Support for REST APIs going through a JumpHost
Change-Id: Ic087946a96cdc8e00df011a6576a3197aacd7fb5
Signed-off-by: croy <christian.roy@windriver.com>
2025-11-14 18:28:40 +00:00

150 lines
5.1 KiB
Python

import socket
from typing import Any, Optional
from urllib.parse import urlparse
import requests
from sshtunnel import SSHTunnelForwarder
from urllib3.exceptions import InsecureRequestWarning
from config.configuration_manager import ConfigurationManager
from framework.rest.rest_response import RestResponse
class SSHTunnelRestClient:
    """
    REST client that makes HTTP requests through SSH tunnel when jump host is configured.

    When the lab configuration enables a jump server, one SSH tunnel is opened per
    REST API port returned by the REST API configuration. Each tunnel forwards a
    local 127.0.0.1 port, through the jump host, to that port on the lab's floating
    IP. Request URLs whose port matches a tunneled port are rewritten to the local
    end of the tunnel. All other URLs, and every URL when no jump server is used,
    are sent unchanged.

    Call ``close()`` when done to stop the tunnels.
    """

    def __init__(self):
        # All requests below are made with verify=False, so silence the per-request warning.
        requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
        # remote API port -> started SSHTunnelForwarder
        self._tunnels = {}
        # remote API port -> local 127.0.0.1 port the tunnel listens on
        self._port_mappings = {}
        self._setup_tunnels()

    def _setup_tunnels(self):
        """
        Start one SSH tunnel per configured REST API port if a jump host is configured.

        Does nothing when the lab configuration does not use a jump server. If any
        tunnel fails to start, the tunnels already started are stopped before the
        error is re-raised, so no forwarder is left running.
        """
        config = ConfigurationManager.get_lab_config()
        if not config.is_use_jump_server():
            return

        # Resolve the jump host connection details once rather than per port.
        jump_host_config = config.get_jump_host_configuration()
        jump_address = (jump_host_config.get_host(), jump_host_config.get_ssh_port())
        credentials = jump_host_config.get_credentials()
        user_name = credentials.get_user_name()
        password = credentials.get_password()
        lab_ip = config.get_floating_ip()

        api_ports = ConfigurationManager.get_rest_api_config().get_all_ports()

        try:
            for remote_port in api_ports:
                if remote_port in self._tunnels:
                    # A duplicated port would otherwise overwrite (and leak) a running tunnel.
                    continue
                local_port = self._find_free_port()
                tunnel = SSHTunnelForwarder(
                    jump_address,
                    ssh_username=user_name,
                    ssh_password=password,
                    remote_bind_address=(lab_ip, remote_port),
                    local_bind_address=("127.0.0.1", local_port),
                )
                tunnel.start()
                # Only record tunnels that actually started, so close() stops exactly those.
                self._tunnels[remote_port] = tunnel
                self._port_mappings[remote_port] = local_port
        except BaseException:
            # Also covers KeyboardInterrupt: never leave earlier forwarders running.
            try:
                self.close()
            except Exception:
                pass  # best-effort cleanup; the original failure is the one worth surfacing
            raise

    def _find_free_port(self):
        """
        Return a local TCP port that was free at the time of the call.

        The port is probed on 127.0.0.1, the same address the tunnels bind to. The
        probe socket is closed before returning, so another process could in
        principle grab the port before the tunnel binds it.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.bind(("127.0.0.1", 0))
            return probe.getsockname()[1]

    def _modify_url_for_tunnel(self, url):
        """
        Rewrite ``url`` to target the local end of a tunnel, if one matches its port.

        The port is taken from the URL. When the URL has no explicit port, the
        default is 80 for ``http`` and 443 otherwise. Only the host and port of
        the authority are replaced. Any ``user:pass@`` prefix, the path, the query
        and the fragment are kept unchanged.

        NOTE(review): matching is by port only, regardless of the URL's host.

        Args:
            url (str): The request URL.

        Returns:
            str: The rewritten URL. The original URL is returned when no tunnels are
            active, the URL has no host, its port is invalid, or no tunnel exists
            for its port.
        """
        if not self._tunnels:
            return url

        parsed = urlparse(url)
        if not parsed.netloc:
            return url

        try:
            remote_port = parsed.port
        except ValueError:
            # Non-numeric or out-of-range port: leave the URL untouched.
            return url
        if remote_port is None:
            remote_port = 80 if parsed.scheme.lower() == "http" else 443

        local_port = self._port_mappings.get(remote_port)
        if local_port is None:
            return url

        userinfo, at_sign, _ = parsed.netloc.rpartition("@")
        new_netloc = f"{userinfo}{at_sign}127.0.0.1:{local_port}"
        # Anchor on "//" and replace once so text elsewhere in the URL that happens
        # to equal the netloc is not rewritten.
        return url.replace(f"//{parsed.netloc}", f"//{new_netloc}", 1)

    def _normalize_headers(self, headers):
        """
        Convert headers to dictionary format.

        Args:
            headers: None/empty, a dict, or a list of dicts.

        Returns:
            dict: The dict itself when a dict is given. For a list, the dict entries
            are merged in order, with later entries winning and non-dict items
            ignored. Any other input gives an empty dict.
        """
        if not headers:
            return {}
        if isinstance(headers, dict):
            return headers
        if isinstance(headers, list):
            headers_dict = {}
            for header in headers:
                if isinstance(header, dict):
                    headers_dict.update(header)
            return headers_dict
        return {}

    def get(self, url: str, headers: Optional[Any] = None, *, timeout: Optional[Any] = None) -> RestResponse:
        """
        Runs a get request with the given url and headers, tunneling through SSH if configured.

        Args:
            url (str): The URL for the request.
            headers (Optional[Any]): Headers for the request (dict or list of dicts). Defaults to None.
            timeout (Optional[Any]): Passed to ``requests`` as-is (seconds, or a
                (connect, read) tuple). Defaults to None, meaning no timeout.

        Returns:
            RestResponse: An object representing the response of the GET request.
        """
        headers_dict = self._normalize_headers(headers)
        modified_url = self._modify_url_for_tunnel(url)
        response = requests.get(modified_url, headers=headers_dict, verify=False, timeout=timeout)
        return RestResponse(response)

    def post(self, url: str, data: Any, headers: Any, *, timeout: Optional[Any] = None) -> RestResponse:
        """
        Runs a post request with the given url and headers, tunneling through SSH if configured.

        Args:
            url (str): The URL for the request.
            data (Any): The data to be sent in the body of the request.
            headers (Any): Headers for the request (dict or list of dicts).
            timeout (Optional[Any]): Passed to ``requests`` as-is (seconds, or a
                (connect, read) tuple). Defaults to None, meaning no timeout.

        Returns:
            RestResponse: An object containing the response from the request.
        """
        headers_dict = self._normalize_headers(headers)
        modified_url = self._modify_url_for_tunnel(url)
        response = requests.post(modified_url, headers=headers_dict, data=data, verify=False, timeout=timeout)
        return RestResponse(response)

    def close(self):
        """
        Close all SSH tunnels if they exist.

        Every tunnel gets a stop attempt even if an earlier one fails. The port
        mappings are cleared first, so later requests are no longer rewritten. If
        any ``stop()`` raised, the first such error is re-raised after all
        tunnels were attempted.
        """
        tunnels = list(self._tunnels.values())
        self._tunnels.clear()
        self._port_mappings.clear()

        errors = []
        for tunnel in tunnels:
            try:
                tunnel.stop()
            except Exception as error:  # keep stopping the remaining tunnels
                errors.append(error)
        if errors:
            raise errors[0]