diff --git a/tobiko/shell/ssh/_client.py b/tobiko/shell/ssh/_client.py index 8f08d5576..00f7155c9 100644 --- a/tobiko/shell/ssh/_client.py +++ b/tobiko/shell/ssh/_client.py @@ -18,8 +18,8 @@ from __future__ import absolute_import import collections import contextlib import getpass +import io import os -import socket import subprocess import time import threading @@ -579,6 +579,21 @@ def ssh_client(host, port=None, username=None, proxy_jump=None, **connect_parameters) +def load_private_keys(key_filenames: typing.List[str]): + pkeys = [] + for filename in key_filenames: + if os.path.exists(filename): + try: + with io.open(filename, 'rt') as fd: + pkey = paramiko.RSAKey.from_private_key(fd) + except Exception: + LOG.error('Unable to get RSAKey private key from file: ' + f'{filename}', exc_info=1) + else: + pkeys.append(pkey) + return pkeys + + def ssh_connect(hostname, username=None, port=None, connection_interval=None, connection_attempts=None, connection_timeout=None, proxy_command=None, proxy_client=None, key_filename=None, @@ -587,13 +602,9 @@ def ssh_connect(hostname, username=None, port=None, connection_interval=None, client.set_missing_host_key_policy(paramiko.WarningPolicy()) login = _command.ssh_login(hostname=hostname, username=username, port=port) - if key_filename: - # Ensures we try enough times to try all keys - tobiko.check_valid_type(key_filename, list) - connection_attempts = max(connection_attempts or 1, - len(key_filename), - 1) - + assert isinstance(key_filename, list) + pkeys = load_private_keys(key_filename) + auth_failed: typing.Optional[Exception] = None for attempt in tobiko.retry(count=connection_attempts, timeout=connection_timeout, interval=connection_interval, @@ -603,8 +614,8 @@ def ssh_connect(hostname, username=None, port=None, connection_interval=None, LOG.debug(f"Logging in to '{login}'...\n" f" - parameters: {parameters}\n" f" - attempt: {attempt.details}\n") - - try: + for pkey in pkeys + [None]: + succeeded = False proxy_sock = ssh_proxy_sock( hostname=hostname, port=port, @@ -613,29 +624,35 @@ def ssh_connect(hostname, username=None, port=None, connection_interval=None, timeout=connection_timeout, connection_attempts=1, connection_interval=connection_interval) - client.connect(hostname=hostname, - username=username, - port=port, - sock=proxy_sock, - key_filename=key_filename, - **parameters) - except ValueError as ex: - attempt.check_limits() - if (str(ex) == 'q must be exactly 160, 224, or 256 bits long' and - len(key_filename) > 1): - # Must try without the first key - LOG.debug(f"Retry connecting with the next key: {ex}") - key_filename = key_filename[1:] + try: + client.connect(hostname=hostname, + username=username, + port=port, + sock=proxy_sock, + pkey=pkey, + **parameters) + except paramiko.ssh_exception.AuthenticationException as ex: + if auth_failed is not None: + ex.__cause__ = auth_failed + auth_failed = ex continue + except Exception as ex: + LOG.debug(f"Error logging in to '{login}': {ex}", exc_info=1) + attempt.check_limits() + break else: - raise - except (EOFError, socket.error, socket.timeout, - paramiko.SSHException) as ex: - attempt.check_limits() - LOG.debug(f"Error logging in to '{login}': {ex}") + LOG.debug(f"Successfully logged in to '{login}'") + succeeded = True + return client, proxy_sock + finally: + if not succeeded: + try: + proxy_sock.close() + except Exception: + pass else: - LOG.debug(f"Successfully logged in to '{login}'") - return client, proxy_sock + if isinstance(auth_failed, Exception): + raise auth_failed def ssh_proxy_sock(hostname=None, port=None, command=None, client=None,