Don't gather host keys for non ssh connections

In case of an image with the connection type winrm we cannot scan the
ssh host keys. So in case the connection type is not ssh we
need to skip gathering the host keys.

Change-Id: I56f308baa10d40461cf4a919bbcdc4467e85a551
This commit is contained in:
Tobias Henkel 2017-09-12 22:04:23 +02:00
parent ee78684521
commit 2da274e2ae
No known key found for this signature in database
GPG Key ID: 03750DEC158E5FA2
5 changed files with 30 additions and 18 deletions

View File

@ -194,18 +194,23 @@ class NodeLauncher(threading.Thread, stats.StatsReporter):
self._node.interface_ip, self._node.public_ipv4,
self._node.public_ipv6))
# Get the SSH public keys for the new node and record in ZooKeeper
# wait and scan the new node and record in ZooKeeper
host_keys = []
if self._pool.host_key_checking:
try:
self.log.debug(
"Gathering host keys for node %s", self._node.id)
host_keys = utils.keyscan(
interface_ip, timeout=self._provider_config.boot_timeout)
if not host_keys:
# only gather host keys if the connection type is ssh
gather_host_keys = connection_type == 'ssh'
host_keys = utils.nodescan(
interface_ip,
timeout=self._provider_config.boot_timeout,
gather_hostkeys=gather_host_keys)
if gather_host_keys and not host_keys:
raise exceptions.LaunchKeyscanException(
"Unable to gather host keys")
except exceptions.SSHTimeoutException:
except exceptions.ConnectionTimeoutException:
self.logConsole(self._node.external_id, self._node.hostname)
raise

View File

@ -16,7 +16,7 @@ import logging
from nodepool import exceptions
from nodepool.driver import Provider
from nodepool.nodeutils import keyscan
from nodepool.nodeutils import nodescan
class StaticNodeError(Exception):
@ -36,11 +36,12 @@ class StaticNodeProvider(Provider):
def checkHost(self, node):
# Check node is reachable
try:
keys = keyscan(node["name"],
port=node["ssh-port"],
timeout=node["timeout"])
except exceptions.SSHTimeoutException:
raise StaticNodeError("%s: SSHTimeoutException" % node["name"])
keys = nodescan(node["name"],
port=node["ssh-port"],
timeout=node["timeout"])
except exceptions.ConnectionTimeoutException:
raise StaticNodeError(
"%s: ConnectionTimeoutException" % node["name"])
# Check node host-key
if set(node["host-key"]).issubset(set(keys)):

View File

@ -49,7 +49,7 @@ class TimeoutException(Exception):
pass
class SSHTimeoutException(TimeoutException):
class ConnectionTimeoutException(TimeoutException):
statsd_key = 'error.ssh'

View File

@ -57,14 +57,17 @@ def set_node_ip(node):
"Unable to find public IP of server")
def keyscan(ip, port=22, timeout=60):
def nodescan(ip, port=22, timeout=60, gather_hostkeys=True):
'''
Scan the IP address for public SSH keys.
Keys are returned formatted as: "<type> <base64_string>"
'''
if 'fake' in ip:
return ['ssh-rsa FAKEKEY']
if gather_hostkeys:
return ['ssh-rsa FAKEKEY']
else:
return []
addrinfo = socket.getaddrinfo(ip, port)[0]
family = addrinfo[0]
@ -73,16 +76,18 @@ def keyscan(ip, port=22, timeout=60):
keys = []
key = None
for count in iterate_timeout(
timeout, exceptions.SSHTimeoutException, "ssh access"):
timeout, exceptions.ConnectionTimeoutException,
"connection on port %s" % port):
sock = None
t = None
try:
sock = socket.socket(family, socket.SOCK_STREAM)
sock.settimeout(timeout)
sock.connect(sockaddr)
t = paramiko.transport.Transport(sock)
t.start_client(timeout=timeout)
key = t.get_remote_server_key()
if gather_hostkeys:
t = paramiko.transport.Transport(sock)
t.start_client(timeout=timeout)
key = t.get_remote_server_key()
break
except socket.error as e:
if e.errno not in [errno.ECONNREFUSED, errno.EHOSTUNREACH, None]:

View File

@ -949,6 +949,7 @@ class TestLauncher(tests.DBTestCase):
self.assertEqual(len(nodes), 1)
self.assertEqual('zuul', nodes[0].username)
self.assertEqual('winrm', nodes[0].connection_type)
self.assertEqual(nodes[0].host_keys, [])
def test_unmanaged_image_provider_name(self):
"""