Merge "Drop eventlet usage in ssh pools"

This commit is contained in:
Zuul
2025-12-04 09:34:34 +00:00
committed by Gerrit Code Review
2 changed files with 201 additions and 32 deletions

View File

@@ -12,11 +12,13 @@
"""Ssh utilities."""
from collections import deque
from contextlib import contextmanager
import hashlib
import logging
import os
import threading
from eventlet import pools
from oslo_config import cfg
from oslo_log import log
@@ -53,20 +55,28 @@ if paramiko is None:
paramiko.pkey.PKey.get_fingerprint = get_fingerprint
class SSHPool(pools.Pool):
"""A simple eventlet pool to hold ssh connections."""
class SSHPool:
"""A thread-safe SSH connection pool."""
def __init__(self, ip, port, conn_timeout, login, password=None,
privatekey=None, *args, **kwargs):
privatekey=None, min_size=1, max_size=10):
self.ip = ip
self.port = port
self.login = login
self.password = password
self.conn_timeout = conn_timeout if conn_timeout else None
self.path_to_private_key = privatekey
super(SSHPool, self).__init__(*args, **kwargs)
self.min_size = min_size
self.max_size = max_size
def create(self, quiet=False): # pylint: disable=method-hidden
# Concurrent connection management
self._lock = threading.RLock()
self._connections = deque()
self._current_size = 0
self._condition = threading.Condition(self._lock)
def create(self, quiet=False):
"""Create one new SSH connection."""
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
look_for_keys = True
@@ -108,28 +118,100 @@ class SSHPool(pools.Pool):
def get(self):
"""Return an item from the pool, when one is available.
This may cause the calling greenthread to block. Check if a
connection is active before returning it. For dead connections
create and return a new connection.
This method will block if no connections are available and the pool
is at maximum capacity. Check if a connection is active before
returning it. For dead connections create and return a new connection.
"""
if self.free_items:
conn = self.free_items.popleft()
if conn:
if conn.get_transport().is_active():
return conn
else:
conn.close()
return self.create()
if self.current_size < self.max_size:
created = self.create()
self.current_size += 1
return created
return self.channel.get()
with self._condition:
# Try to get an existing connection
while True:
if self._connections:
conn = self._connections.popleft()
if conn and self._is_connection_active(conn):
return conn
else:
# Connection is dead, close it and try again
if conn:
self._close_connection(conn)
self._current_size -= 1
continue
# No active connections available
if self._current_size < self.max_size:
# Create new connection
conn = self.create()
if conn:
self._current_size += 1
return conn
# Pool is at max capacity, wait for a connection
self._condition.wait(timeout=30)
# If we timeout, try to create anyway
if (not self._connections and
self._current_size < self.max_size):
conn = self.create()
if conn:
self._current_size += 1
return conn
def put(self, conn):
"""Return a connection to the pool."""
if not conn:
return
with self._condition:
if self._is_connection_active(conn):
self._connections.append(conn)
else:
self._close_connection(conn)
if self._current_size > 0:
self._current_size -= 1
self._condition.notify()
def remove(self, ssh):
"""Close an ssh client and remove it from free_items."""
ssh.close()
if ssh in self.free_items:
self.free_items.remove(ssh)
if self.current_size > 0:
self.current_size -= 1
"""Close an ssh client and remove it from the pool."""
with self._lock:
if ssh in self._connections:
self._connections.remove(ssh)
self._close_connection(ssh)
if self._current_size > 0:
self._current_size -= 1
@contextmanager
def item(self):
"""Context manager for getting/returning connections."""
conn = self.get()
try:
yield conn
finally:
self.put(conn)
def _is_connection_active(self, conn):
"""Check if SSH connection is still active."""
try:
return (conn and
conn.get_transport() and
conn.get_transport().is_active())
except Exception:
return False
def _close_connection(self, conn):
"""Safely close an SSH connection."""
try:
if conn:
conn.close()
except Exception:
pass # Ignore errors when closing
# Properties for backward compatibility with eventlet.pools.Pool
@property
def current_size(self):
"""Current number of connections in the pool."""
with self._lock:
return self._current_size
@property
def free_items(self):
"""Available connections (for backward compatibility)."""
with self._lock:
return self._connections

View File

@@ -10,12 +10,16 @@
# License for the specific language governing permissions and limitations
# under the License.
import threading
import time
from unittest import mock
from oslo_utils import uuidutils
import paramiko
from manila import exception
from manila import ssh_utils
from manila import test
from oslo_utils import uuidutils
import paramiko
from unittest import mock
class FakeSock(object):
@@ -80,6 +84,8 @@ class SSHPoolTestCase(test.TestCase):
def test_create_ssh_with_password(self):
fake_ssh_client = mock.Mock()
fake_transport = mock.Mock()
fake_ssh_client.get_transport.return_value = fake_transport
ssh_pool = ssh_utils.SSHPool("127.0.0.1", 22, 10, "test",
password="test")
with mock.patch.object(paramiko, "SSHClient",
@@ -90,10 +96,13 @@ class SSHPoolTestCase(test.TestCase):
"127.0.0.1", port=22, username="test",
password="test", key_filename=None, look_for_keys=False,
timeout=10, banner_timeout=10)
fake_transport.set_keepalive.assert_called_once_with(10)
def test_create_ssh_with_key(self):
path_to_private_key = "/fakepath/to/privatekey"
fake_ssh_client = mock.Mock()
fake_transport = mock.Mock()
fake_ssh_client.get_transport.return_value = fake_transport
ssh_pool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
"test",
privatekey="/fakepath/to/privatekey")
@@ -104,9 +113,12 @@ class SSHPoolTestCase(test.TestCase):
"127.0.0.1", port=22, username="test", password=None,
key_filename=path_to_private_key, look_for_keys=False,
timeout=10, banner_timeout=10)
fake_transport.set_keepalive.assert_called_once_with(10)
def test_create_ssh_with_nothing(self):
fake_ssh_client = mock.Mock()
fake_transport = mock.Mock()
fake_ssh_client.get_transport.return_value = fake_transport
ssh_pool = ssh_utils.SSHPool("127.0.0.1", 22, 10, "test")
with mock.patch.object(paramiko, "SSHClient",
return_value=fake_ssh_client):
@@ -115,6 +127,7 @@ class SSHPoolTestCase(test.TestCase):
"127.0.0.1", port=22, username="test", password=None,
key_filename=None, look_for_keys=True,
timeout=10, banner_timeout=10)
fake_transport.set_keepalive.assert_called_once_with(10)
def test_create_ssh_error_connecting(self):
attrs = {'connect.side_effect': paramiko.SSHException, }
@@ -156,11 +169,22 @@ class SSHPoolTestCase(test.TestCase):
@mock.patch('os.path.isfile', return_value=True)
def test_sshpool_remove(self, mock_isfile, mock_sshclient, mock_open):
ssh_to_remove = mock.Mock()
ssh_to_remove.get_transport.return_value.is_active.return_value = True
mock_sshclient.side_effect = [mock.Mock(), ssh_to_remove, mock.Mock()]
sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
"test", password="test",
min_size=3, max_size=3)
# Get connections to populate the pool
conn1 = sshpool.get()
conn2 = sshpool.get()
conn3 = sshpool.get()
# Put them back so they're in free_items
sshpool.put(conn1)
sshpool.put(conn2)
sshpool.put(conn3)
self.assertIn(ssh_to_remove, list(sshpool.free_items))
sshpool.remove(ssh_to_remove)
@@ -174,11 +198,22 @@ class SSHPoolTestCase(test.TestCase):
mock_sshclient, mock_open):
# create an SSH Client that is not a part of sshpool.
ssh_to_remove = mock.Mock()
mock_sshclient.side_effect = [mock.Mock(), mock.Mock()]
mock_conn1 = mock.Mock()
mock_conn2 = mock.Mock()
mock_conn1.get_transport.return_value.is_active.return_value = True
mock_conn2.get_transport.return_value.is_active.return_value = True
mock_sshclient.side_effect = [mock_conn1, mock_conn2]
sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
"test", password="test",
min_size=2, max_size=2)
# Get and put back connections to populate free_items
conn1 = sshpool.get()
conn2 = sshpool.get()
sshpool.put(conn1)
sshpool.put(conn2)
listBefore = list(sshpool.free_items)
self.assertNotIn(ssh_to_remove, listBefore)
@@ -186,3 +221,55 @@ class SSHPoolTestCase(test.TestCase):
sshpool.remove(ssh_to_remove)
self.assertEqual(listBefore, list(sshpool.free_items))
def test_sshpool_thread_safety(self):
"""Test that the pool is thread-safe."""
with mock.patch.object(paramiko, "SSHClient",
mock.Mock(return_value=FakeSSHClient())):
sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
"test", password="test",
min_size=1, max_size=5)
connections_acquired = []
errors = []
def acquire_connection():
try:
with sshpool.item() as ssh:
connections_acquired.append(ssh.id)
time.sleep(0.1) # Simulate work
except Exception as e:
errors.append(str(e))
# Start multiple threads
threads = []
for _ in range(10):
thread = threading.Thread(target=acquire_connection)
threads.append(thread)
thread.start()
# Wait for all threads to complete
for thread in threads:
thread.join()
# Verify no errors
self.assertEqual([], errors)
self.assertEqual(10, len(connections_acquired))
self.assertLessEqual(sshpool.current_size, 5)
def test_sshpool_put_get_behavior(self):
with mock.patch.object(paramiko, "SSHClient",
mock.Mock(return_value=FakeSSHClient())):
sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
"test", password="test",
min_size=1, max_size=3)
conn1 = sshpool.get()
self.assertIsNotNone(conn1)
self.assertEqual(1, sshpool.current_size)
sshpool.put(conn1)
self.assertEqual(1, len(sshpool.free_items))
conn2 = sshpool.get()
self.assertEqual(conn1.id, conn2.id)