Merge "Drop eventlet usage in ssh pools"
This commit is contained in:
@@ -12,11 +12,13 @@
|
||||
|
||||
"""Ssh utilities."""
|
||||
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
|
||||
from eventlet import pools
|
||||
from oslo_config import cfg
|
||||
from oslo_log import log
|
||||
|
||||
@@ -53,20 +55,28 @@ if paramiko is None:
|
||||
paramiko.pkey.PKey.get_fingerprint = get_fingerprint
|
||||
|
||||
|
||||
class SSHPool(pools.Pool):
|
||||
"""A simple eventlet pool to hold ssh connections."""
|
||||
class SSHPool:
|
||||
"""A thread-safe SSH connection pool."""
|
||||
|
||||
def __init__(self, ip, port, conn_timeout, login, password=None,
|
||||
privatekey=None, *args, **kwargs):
|
||||
privatekey=None, min_size=1, max_size=10):
|
||||
self.ip = ip
|
||||
self.port = port
|
||||
self.login = login
|
||||
self.password = password
|
||||
self.conn_timeout = conn_timeout if conn_timeout else None
|
||||
self.path_to_private_key = privatekey
|
||||
super(SSHPool, self).__init__(*args, **kwargs)
|
||||
self.min_size = min_size
|
||||
self.max_size = max_size
|
||||
|
||||
def create(self, quiet=False): # pylint: disable=method-hidden
|
||||
# Concurrent connection management
|
||||
self._lock = threading.RLock()
|
||||
self._connections = deque()
|
||||
self._current_size = 0
|
||||
self._condition = threading.Condition(self._lock)
|
||||
|
||||
def create(self, quiet=False):
|
||||
"""Create one new SSH connection."""
|
||||
ssh = paramiko.SSHClient()
|
||||
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
look_for_keys = True
|
||||
@@ -108,28 +118,100 @@ class SSHPool(pools.Pool):
|
||||
def get(self):
|
||||
"""Return an item from the pool, when one is available.
|
||||
|
||||
This may cause the calling greenthread to block. Check if a
|
||||
connection is active before returning it. For dead connections
|
||||
create and return a new connection.
|
||||
This method will block if no connections are available and the pool
|
||||
is at maximum capacity. Check if a connection is active before
|
||||
returning it. For dead connections create and return a new connection.
|
||||
"""
|
||||
if self.free_items:
|
||||
conn = self.free_items.popleft()
|
||||
if conn:
|
||||
if conn.get_transport().is_active():
|
||||
return conn
|
||||
else:
|
||||
conn.close()
|
||||
return self.create()
|
||||
if self.current_size < self.max_size:
|
||||
created = self.create()
|
||||
self.current_size += 1
|
||||
return created
|
||||
return self.channel.get()
|
||||
with self._condition:
|
||||
# Try to get an existing connection
|
||||
while True:
|
||||
if self._connections:
|
||||
conn = self._connections.popleft()
|
||||
if conn and self._is_connection_active(conn):
|
||||
return conn
|
||||
else:
|
||||
# Connection is dead, close it and try again
|
||||
if conn:
|
||||
self._close_connection(conn)
|
||||
self._current_size -= 1
|
||||
continue
|
||||
|
||||
# No active connections available
|
||||
if self._current_size < self.max_size:
|
||||
# Create new connection
|
||||
conn = self.create()
|
||||
if conn:
|
||||
self._current_size += 1
|
||||
return conn
|
||||
|
||||
# Pool is at max capacity, wait for a connection
|
||||
self._condition.wait(timeout=30)
|
||||
# If we timeout, try to create anyway
|
||||
if (not self._connections and
|
||||
self._current_size < self.max_size):
|
||||
conn = self.create()
|
||||
if conn:
|
||||
self._current_size += 1
|
||||
return conn
|
||||
|
||||
def put(self, conn):
|
||||
"""Return a connection to the pool."""
|
||||
if not conn:
|
||||
return
|
||||
|
||||
with self._condition:
|
||||
if self._is_connection_active(conn):
|
||||
self._connections.append(conn)
|
||||
else:
|
||||
self._close_connection(conn)
|
||||
if self._current_size > 0:
|
||||
self._current_size -= 1
|
||||
self._condition.notify()
|
||||
|
||||
def remove(self, ssh):
|
||||
"""Close an ssh client and remove it from free_items."""
|
||||
ssh.close()
|
||||
if ssh in self.free_items:
|
||||
self.free_items.remove(ssh)
|
||||
if self.current_size > 0:
|
||||
self.current_size -= 1
|
||||
"""Close an ssh client and remove it from the pool."""
|
||||
with self._lock:
|
||||
if ssh in self._connections:
|
||||
self._connections.remove(ssh)
|
||||
self._close_connection(ssh)
|
||||
if self._current_size > 0:
|
||||
self._current_size -= 1
|
||||
|
||||
@contextmanager
|
||||
def item(self):
|
||||
"""Context manager for getting/returning connections."""
|
||||
conn = self.get()
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
self.put(conn)
|
||||
|
||||
def _is_connection_active(self, conn):
|
||||
"""Check if SSH connection is still active."""
|
||||
try:
|
||||
return (conn and
|
||||
conn.get_transport() and
|
||||
conn.get_transport().is_active())
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _close_connection(self, conn):
|
||||
"""Safely close an SSH connection."""
|
||||
try:
|
||||
if conn:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass # Ignore errors when closing
|
||||
|
||||
# Properties for backward compatibility with eventlet.pools.Pool
|
||||
@property
|
||||
def current_size(self):
|
||||
"""Current number of connections in the pool."""
|
||||
with self._lock:
|
||||
return self._current_size
|
||||
|
||||
@property
|
||||
def free_items(self):
|
||||
"""Available connections (for backward compatibility)."""
|
||||
with self._lock:
|
||||
return self._connections
|
||||
|
||||
@@ -10,12 +10,16 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import threading
|
||||
import time
|
||||
from unittest import mock
|
||||
|
||||
from oslo_utils import uuidutils
|
||||
import paramiko
|
||||
|
||||
from manila import exception
|
||||
from manila import ssh_utils
|
||||
from manila import test
|
||||
from oslo_utils import uuidutils
|
||||
import paramiko
|
||||
from unittest import mock
|
||||
|
||||
|
||||
class FakeSock(object):
|
||||
@@ -80,6 +84,8 @@ class SSHPoolTestCase(test.TestCase):
|
||||
|
||||
def test_create_ssh_with_password(self):
|
||||
fake_ssh_client = mock.Mock()
|
||||
fake_transport = mock.Mock()
|
||||
fake_ssh_client.get_transport.return_value = fake_transport
|
||||
ssh_pool = ssh_utils.SSHPool("127.0.0.1", 22, 10, "test",
|
||||
password="test")
|
||||
with mock.patch.object(paramiko, "SSHClient",
|
||||
@@ -90,10 +96,13 @@ class SSHPoolTestCase(test.TestCase):
|
||||
"127.0.0.1", port=22, username="test",
|
||||
password="test", key_filename=None, look_for_keys=False,
|
||||
timeout=10, banner_timeout=10)
|
||||
fake_transport.set_keepalive.assert_called_once_with(10)
|
||||
|
||||
def test_create_ssh_with_key(self):
|
||||
path_to_private_key = "/fakepath/to/privatekey"
|
||||
fake_ssh_client = mock.Mock()
|
||||
fake_transport = mock.Mock()
|
||||
fake_ssh_client.get_transport.return_value = fake_transport
|
||||
ssh_pool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
|
||||
"test",
|
||||
privatekey="/fakepath/to/privatekey")
|
||||
@@ -104,9 +113,12 @@ class SSHPoolTestCase(test.TestCase):
|
||||
"127.0.0.1", port=22, username="test", password=None,
|
||||
key_filename=path_to_private_key, look_for_keys=False,
|
||||
timeout=10, banner_timeout=10)
|
||||
fake_transport.set_keepalive.assert_called_once_with(10)
|
||||
|
||||
def test_create_ssh_with_nothing(self):
|
||||
fake_ssh_client = mock.Mock()
|
||||
fake_transport = mock.Mock()
|
||||
fake_ssh_client.get_transport.return_value = fake_transport
|
||||
ssh_pool = ssh_utils.SSHPool("127.0.0.1", 22, 10, "test")
|
||||
with mock.patch.object(paramiko, "SSHClient",
|
||||
return_value=fake_ssh_client):
|
||||
@@ -115,6 +127,7 @@ class SSHPoolTestCase(test.TestCase):
|
||||
"127.0.0.1", port=22, username="test", password=None,
|
||||
key_filename=None, look_for_keys=True,
|
||||
timeout=10, banner_timeout=10)
|
||||
fake_transport.set_keepalive.assert_called_once_with(10)
|
||||
|
||||
def test_create_ssh_error_connecting(self):
|
||||
attrs = {'connect.side_effect': paramiko.SSHException, }
|
||||
@@ -156,11 +169,22 @@ class SSHPoolTestCase(test.TestCase):
|
||||
@mock.patch('os.path.isfile', return_value=True)
|
||||
def test_sshpool_remove(self, mock_isfile, mock_sshclient, mock_open):
|
||||
ssh_to_remove = mock.Mock()
|
||||
ssh_to_remove.get_transport.return_value.is_active.return_value = True
|
||||
mock_sshclient.side_effect = [mock.Mock(), ssh_to_remove, mock.Mock()]
|
||||
sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
|
||||
"test", password="test",
|
||||
min_size=3, max_size=3)
|
||||
|
||||
# Get connections to populate the pool
|
||||
conn1 = sshpool.get()
|
||||
conn2 = sshpool.get()
|
||||
conn3 = sshpool.get()
|
||||
|
||||
# Put them back so they're in free_items
|
||||
sshpool.put(conn1)
|
||||
sshpool.put(conn2)
|
||||
sshpool.put(conn3)
|
||||
|
||||
self.assertIn(ssh_to_remove, list(sshpool.free_items))
|
||||
|
||||
sshpool.remove(ssh_to_remove)
|
||||
@@ -174,11 +198,22 @@ class SSHPoolTestCase(test.TestCase):
|
||||
mock_sshclient, mock_open):
|
||||
# create an SSH Client that is not a part of sshpool.
|
||||
ssh_to_remove = mock.Mock()
|
||||
mock_sshclient.side_effect = [mock.Mock(), mock.Mock()]
|
||||
mock_conn1 = mock.Mock()
|
||||
mock_conn2 = mock.Mock()
|
||||
mock_conn1.get_transport.return_value.is_active.return_value = True
|
||||
mock_conn2.get_transport.return_value.is_active.return_value = True
|
||||
mock_sshclient.side_effect = [mock_conn1, mock_conn2]
|
||||
|
||||
sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
|
||||
"test", password="test",
|
||||
min_size=2, max_size=2)
|
||||
|
||||
# Get and put back connections to populate free_items
|
||||
conn1 = sshpool.get()
|
||||
conn2 = sshpool.get()
|
||||
sshpool.put(conn1)
|
||||
sshpool.put(conn2)
|
||||
|
||||
listBefore = list(sshpool.free_items)
|
||||
|
||||
self.assertNotIn(ssh_to_remove, listBefore)
|
||||
@@ -186,3 +221,55 @@ class SSHPoolTestCase(test.TestCase):
|
||||
sshpool.remove(ssh_to_remove)
|
||||
|
||||
self.assertEqual(listBefore, list(sshpool.free_items))
|
||||
|
||||
def test_sshpool_thread_safety(self):
|
||||
"""Test that the pool is thread-safe."""
|
||||
with mock.patch.object(paramiko, "SSHClient",
|
||||
mock.Mock(return_value=FakeSSHClient())):
|
||||
sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
|
||||
"test", password="test",
|
||||
min_size=1, max_size=5)
|
||||
|
||||
connections_acquired = []
|
||||
errors = []
|
||||
|
||||
def acquire_connection():
|
||||
try:
|
||||
with sshpool.item() as ssh:
|
||||
connections_acquired.append(ssh.id)
|
||||
time.sleep(0.1) # Simulate work
|
||||
except Exception as e:
|
||||
errors.append(str(e))
|
||||
|
||||
# Start multiple threads
|
||||
threads = []
|
||||
for _ in range(10):
|
||||
thread = threading.Thread(target=acquire_connection)
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Wait for all threads to complete
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Verify no errors
|
||||
self.assertEqual([], errors)
|
||||
self.assertEqual(10, len(connections_acquired))
|
||||
self.assertLessEqual(sshpool.current_size, 5)
|
||||
|
||||
def test_sshpool_put_get_behavior(self):
|
||||
with mock.patch.object(paramiko, "SSHClient",
|
||||
mock.Mock(return_value=FakeSSHClient())):
|
||||
sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
|
||||
"test", password="test",
|
||||
min_size=1, max_size=3)
|
||||
|
||||
conn1 = sshpool.get()
|
||||
self.assertIsNotNone(conn1)
|
||||
self.assertEqual(1, sshpool.current_size)
|
||||
|
||||
sshpool.put(conn1)
|
||||
self.assertEqual(1, len(sshpool.free_items))
|
||||
|
||||
conn2 = sshpool.get()
|
||||
self.assertEqual(conn1.id, conn2.id)
|
||||
|
||||
Reference in New Issue
Block a user