diff --git a/manila/ssh_utils.py b/manila/ssh_utils.py index b013b68bcf..4a309795d1 100644 --- a/manila/ssh_utils.py +++ b/manila/ssh_utils.py @@ -12,11 +12,13 @@ """Ssh utilities.""" +from collections import deque +from contextlib import contextmanager import hashlib import logging import os +import threading -from eventlet import pools from oslo_config import cfg from oslo_log import log @@ -53,20 +55,28 @@ if paramiko is None: paramiko.pkey.PKey.get_fingerprint = get_fingerprint -class SSHPool(pools.Pool): - """A simple eventlet pool to hold ssh connections.""" +class SSHPool: + """A thread-safe SSH connection pool.""" def __init__(self, ip, port, conn_timeout, login, password=None, - privatekey=None, *args, **kwargs): + privatekey=None, min_size=1, max_size=10): self.ip = ip self.port = port self.login = login self.password = password self.conn_timeout = conn_timeout if conn_timeout else None self.path_to_private_key = privatekey - super(SSHPool, self).__init__(*args, **kwargs) + self.min_size = min_size + self.max_size = max_size - def create(self, quiet=False): # pylint: disable=method-hidden + # Concurrent connection management + self._lock = threading.RLock() + self._connections = deque() + self._current_size = 0 + self._condition = threading.Condition(self._lock) + + def create(self, quiet=False): + """Create one new SSH connection.""" ssh = paramiko.SSHClient() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) look_for_keys = True @@ -108,28 +118,100 @@ class SSHPool(pools.Pool): def get(self): """Return an item from the pool, when one is available. - This may cause the calling greenthread to block. Check if a - connection is active before returning it. For dead connections - create and return a new connection. + This method will block if no connections are available and the pool + is at maximum capacity. Check if a connection is active before + returning it. For dead connections create and return a new connection. """ - if self.free_items: - conn = self.free_items.popleft() - if conn: - if conn.get_transport().is_active(): - return conn - else: - conn.close() - return self.create() - if self.current_size < self.max_size: - created = self.create() - self.current_size += 1 - return created - return self.channel.get() + with self._condition: + # Try to get an existing connection + while True: + if self._connections: + conn = self._connections.popleft() + if conn and self._is_connection_active(conn): + return conn + else: + # Connection is dead, close it and try again + if conn: + self._close_connection(conn) + self._current_size -= 1 + continue + + # No active connections available + if self._current_size < self.max_size: + # Create new connection + conn = self.create() + if conn: + self._current_size += 1 + return conn + + # Pool is at max capacity, wait for a connection + self._condition.wait(timeout=30) + # If we timeout, try to create anyway + if (not self._connections and + self._current_size < self.max_size): + conn = self.create() + if conn: + self._current_size += 1 + return conn + + def put(self, conn): + """Return a connection to the pool.""" + if not conn: + return + + with self._condition: + if self._is_connection_active(conn): + self._connections.append(conn) + else: + self._close_connection(conn) + if self._current_size > 0: + self._current_size -= 1 + self._condition.notify() def remove(self, ssh): - """Close an ssh client and remove it from free_items.""" - ssh.close() - if ssh in self.free_items: - self.free_items.remove(ssh) - if self.current_size > 0: - self.current_size -= 1 + """Close an ssh client and remove it from the pool.""" + with self._lock: + if ssh in self._connections: + self._connections.remove(ssh) + self._close_connection(ssh) + if self._current_size > 0: + self._current_size -= 1 + + @contextmanager + def item(self): + """Context manager for getting/returning connections.""" + conn = self.get() + try: + yield conn + finally: + self.put(conn) + + def _is_connection_active(self, conn): + """Check if SSH connection is still active.""" + try: + return (conn and + conn.get_transport() and + conn.get_transport().is_active()) + except Exception: + return False + + def _close_connection(self, conn): + """Safely close an SSH connection.""" + try: + if conn: + conn.close() + except Exception: + pass # Ignore errors when closing + + # Properties for backward compatibility with eventlet.pools.Pool + @property + def current_size(self): + """Current number of connections in the pool.""" + with self._lock: + return self._current_size + + @property + def free_items(self): + """Available connections (for backward compatibility).""" + with self._lock: + return self._connections diff --git a/manila/tests/test_ssh_utils.py b/manila/tests/test_ssh_utils.py index c35e44d2ef..959d5e3040 100644 --- a/manila/tests/test_ssh_utils.py +++ b/manila/tests/test_ssh_utils.py @@ -10,12 +10,16 @@ # License for the specific language governing permissions and limitations # under the License. +import threading +import time +from unittest import mock + +from oslo_utils import uuidutils +import paramiko + from manila import exception from manila import ssh_utils from manila import test -from oslo_utils import uuidutils -import paramiko -from unittest import mock class FakeSock(object): @@ -80,6 +84,8 @@ class SSHPoolTestCase(test.TestCase): def test_create_ssh_with_password(self): fake_ssh_client = mock.Mock() + fake_transport = mock.Mock() + fake_ssh_client.get_transport.return_value = fake_transport ssh_pool = ssh_utils.SSHPool("127.0.0.1", 22, 10, "test", password="test") with mock.patch.object(paramiko, "SSHClient", @@ -90,10 +96,13 @@ class SSHPoolTestCase(test.TestCase): "127.0.0.1", port=22, username="test", password="test", key_filename=None, look_for_keys=False, timeout=10, banner_timeout=10) + fake_transport.set_keepalive.assert_called_once_with(10) def test_create_ssh_with_key(self): path_to_private_key = "/fakepath/to/privatekey" fake_ssh_client = mock.Mock() + fake_transport = mock.Mock() + fake_ssh_client.get_transport.return_value = fake_transport ssh_pool = ssh_utils.SSHPool("127.0.0.1", 22, 10, "test", privatekey="/fakepath/to/privatekey") @@ -104,9 +113,12 @@ class SSHPoolTestCase(test.TestCase): "127.0.0.1", port=22, username="test", password=None, key_filename=path_to_private_key, look_for_keys=False, timeout=10, banner_timeout=10) + fake_transport.set_keepalive.assert_called_once_with(10) def test_create_ssh_with_nothing(self): fake_ssh_client = mock.Mock() + fake_transport = mock.Mock() + fake_ssh_client.get_transport.return_value = fake_transport ssh_pool = ssh_utils.SSHPool("127.0.0.1", 22, 10, "test") with mock.patch.object(paramiko, "SSHClient", return_value=fake_ssh_client): @@ -115,6 +127,7 @@ class SSHPoolTestCase(test.TestCase): "127.0.0.1", port=22, username="test", password=None, key_filename=None, look_for_keys=True, timeout=10, banner_timeout=10) + fake_transport.set_keepalive.assert_called_once_with(10) def test_create_ssh_error_connecting(self): attrs = {'connect.side_effect': paramiko.SSHException, } @@ -156,11 +169,22 @@ class SSHPoolTestCase(test.TestCase): @mock.patch('os.path.isfile', return_value=True) def test_sshpool_remove(self, mock_isfile, mock_sshclient, mock_open): ssh_to_remove = mock.Mock() + ssh_to_remove.get_transport.return_value.is_active.return_value = True mock_sshclient.side_effect = [mock.Mock(), ssh_to_remove, mock.Mock()] sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10, "test", password="test", min_size=3, max_size=3) + # Get connections to populate the pool + conn1 = sshpool.get() + conn2 = sshpool.get() + conn3 = sshpool.get() + + # Put them back so they're in free_items + sshpool.put(conn1) + sshpool.put(conn2) + sshpool.put(conn3) + self.assertIn(ssh_to_remove, list(sshpool.free_items)) sshpool.remove(ssh_to_remove) @@ -174,11 +198,22 @@ class SSHPoolTestCase(test.TestCase): mock_sshclient, mock_open): # create an SSH Client that is not a part of sshpool. ssh_to_remove = mock.Mock() - mock_sshclient.side_effect = [mock.Mock(), mock.Mock()] + mock_conn1 = mock.Mock() + mock_conn2 = mock.Mock() + mock_conn1.get_transport.return_value.is_active.return_value = True + mock_conn2.get_transport.return_value.is_active.return_value = True + mock_sshclient.side_effect = [mock_conn1, mock_conn2] sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10, "test", password="test", min_size=2, max_size=2) + + # Get and put back connections to populate free_items + conn1 = sshpool.get() + conn2 = sshpool.get() + sshpool.put(conn1) + sshpool.put(conn2) + listBefore = list(sshpool.free_items) self.assertNotIn(ssh_to_remove, listBefore) @@ -186,3 +221,55 @@ class SSHPoolTestCase(test.TestCase): sshpool.remove(ssh_to_remove) self.assertEqual(listBefore, list(sshpool.free_items)) + + def test_sshpool_thread_safety(self): + """Test that the pool is thread-safe.""" + with mock.patch.object(paramiko, "SSHClient", + mock.Mock(return_value=FakeSSHClient())): + sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10, + "test", password="test", + min_size=1, max_size=5) + + connections_acquired = [] + errors = [] + + def acquire_connection(): + try: + with sshpool.item() as ssh: + connections_acquired.append(ssh.id) + time.sleep(0.1) # Simulate work + except Exception as e: + errors.append(str(e)) + + # Start multiple threads + threads = [] + for _ in range(10): + thread = threading.Thread(target=acquire_connection) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify no errors + self.assertEqual([], errors) + self.assertEqual(10, len(connections_acquired)) + self.assertLessEqual(sshpool.current_size, 5) + + def test_sshpool_put_get_behavior(self): + with mock.patch.object(paramiko, "SSHClient", + mock.Mock(return_value=FakeSSHClient())): + sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10, + "test", password="test", + min_size=1, max_size=3) + + conn1 = sshpool.get() + self.assertIsNotNone(conn1) + self.assertEqual(1, sshpool.current_size) + + sshpool.put(conn1) + self.assertEqual(1, len(sshpool.free_items)) + + conn2 = sshpool.get() + self.assertEqual(conn1.id, conn2.id)