From 2da274e2ae6cf42948d6d758bdd5c5851355b278 Mon Sep 17 00:00:00 2001 From: Tobias Henkel Date: Tue, 12 Sep 2017 22:04:23 +0200 Subject: [PATCH] Don't gather host keys for non ssh connections In case of an image with the connection type winrm we cannot scan the ssh host keys. So in case the connection type is not ssh we need to skip gathering the host keys. Change-Id: I56f308baa10d40461cf4a919bbcdc4467e85a551 --- nodepool/driver/openstack/handler.py | 15 ++++++++++----- nodepool/driver/static/provider.py | 13 +++++++------ nodepool/exceptions.py | 2 +- nodepool/nodeutils.py | 17 +++++++++++------ nodepool/tests/test_launcher.py | 1 + 5 files changed, 30 insertions(+), 18 deletions(-) diff --git a/nodepool/driver/openstack/handler.py b/nodepool/driver/openstack/handler.py index 6550ed494..1c02ab46f 100644 --- a/nodepool/driver/openstack/handler.py +++ b/nodepool/driver/openstack/handler.py @@ -194,18 +194,23 @@ class NodeLauncher(threading.Thread, stats.StatsReporter): self._node.interface_ip, self._node.public_ipv4, self._node.public_ipv6)) - # Get the SSH public keys for the new node and record in ZooKeeper + # wait and scan the new node and record in ZooKeeper host_keys = [] if self._pool.host_key_checking: try: self.log.debug( "Gathering host keys for node %s", self._node.id) - host_keys = utils.keyscan( - interface_ip, timeout=self._provider_config.boot_timeout) - if not host_keys: + # only gather host keys if the connection type is ssh + gather_host_keys = connection_type == 'ssh' + host_keys = utils.nodescan( + interface_ip, + timeout=self._provider_config.boot_timeout, + gather_hostkeys=gather_host_keys) + + if gather_host_keys and not host_keys: raise exceptions.LaunchKeyscanException( "Unable to gather host keys") - except exceptions.SSHTimeoutException: + except exceptions.ConnectionTimeoutException: self.logConsole(self._node.external_id, self._node.hostname) raise diff --git a/nodepool/driver/static/provider.py b/nodepool/driver/static/provider.py index 212f20648..42da57d81 100644 --- a/nodepool/driver/static/provider.py +++ b/nodepool/driver/static/provider.py @@ -16,7 +16,7 @@ import logging from nodepool import exceptions from nodepool.driver import Provider -from nodepool.nodeutils import keyscan +from nodepool.nodeutils import nodescan class StaticNodeError(Exception): @@ -36,11 +36,12 @@ class StaticNodeProvider(Provider): def checkHost(self, node): # Check node is reachable try: - keys = keyscan(node["name"], - port=node["ssh-port"], - timeout=node["timeout"]) - except exceptions.SSHTimeoutException: - raise StaticNodeError("%s: SSHTimeoutException" % node["name"]) + keys = nodescan(node["name"], + port=node["ssh-port"], + timeout=node["timeout"]) + except exceptions.ConnectionTimeoutException: + raise StaticNodeError( + "%s: ConnectionTimeoutException" % node["name"]) # Check node host-key if set(node["host-key"]).issubset(set(keys)): diff --git a/nodepool/exceptions.py b/nodepool/exceptions.py index c754e4943..44cabfe59 100755 --- a/nodepool/exceptions.py +++ b/nodepool/exceptions.py @@ -49,7 +49,7 @@ class TimeoutException(Exception): pass -class SSHTimeoutException(TimeoutException): +class ConnectionTimeoutException(TimeoutException): statsd_key = 'error.ssh' diff --git a/nodepool/nodeutils.py b/nodepool/nodeutils.py index 3c6de886b..39bfb0b7b 100755 --- a/nodepool/nodeutils.py +++ b/nodepool/nodeutils.py @@ -57,14 +57,17 @@ def set_node_ip(node): "Unable to find public IP of server") -def keyscan(ip, port=22, timeout=60): +def nodescan(ip, port=22, timeout=60, gather_hostkeys=True): ''' Scan the IP address for public SSH keys. Keys are returned formatted as: " " ''' if 'fake' in ip: - return ['ssh-rsa FAKEKEY'] + if gather_hostkeys: + return ['ssh-rsa FAKEKEY'] + else: + return [] addrinfo = socket.getaddrinfo(ip, port)[0] family = addrinfo[0] @@ -73,16 +76,18 @@ def keyscan(ip, port=22, timeout=60): keys = [] key = None for count in iterate_timeout( - timeout, exceptions.SSHTimeoutException, "ssh access"): + timeout, exceptions.ConnectionTimeoutException, + "connection on port %s" % port): sock = None t = None try: sock = socket.socket(family, socket.SOCK_STREAM) sock.settimeout(timeout) sock.connect(sockaddr) - t = paramiko.transport.Transport(sock) - t.start_client(timeout=timeout) - key = t.get_remote_server_key() + if gather_hostkeys: + t = paramiko.transport.Transport(sock) + t.start_client(timeout=timeout) + key = t.get_remote_server_key() break except socket.error as e: if e.errno not in [errno.ECONNREFUSED, errno.EHOSTUNREACH, None]: diff --git a/nodepool/tests/test_launcher.py b/nodepool/tests/test_launcher.py index 0b4d67b09..56a881358 100644 --- a/nodepool/tests/test_launcher.py +++ b/nodepool/tests/test_launcher.py @@ -949,6 +949,7 @@ class TestLauncher(tests.DBTestCase): self.assertEqual(len(nodes), 1) self.assertEqual('zuul', nodes[0].username) self.assertEqual('winrm', nodes[0].connection_type) + self.assertEqual(nodes[0].host_keys, []) def test_unmanaged_image_provider_name(self): """