Don't gather host keys for non ssh connections

In case of an image with the connection type winrm we cannot scan the
ssh host keys. So in case the connection type is not ssh we
need to skip gathering the host keys.

Change-Id: I56f308baa10d40461cf4a919bbcdc4467e85a551
This commit is contained in:
Tobias Henkel 2017-09-12 22:04:23 +02:00
parent ee78684521
commit 2da274e2ae
No known key found for this signature in database
GPG Key ID: 03750DEC158E5FA2
5 changed files with 30 additions and 18 deletions

View File

@ -194,18 +194,23 @@ class NodeLauncher(threading.Thread, stats.StatsReporter):
self._node.interface_ip, self._node.public_ipv4, self._node.interface_ip, self._node.public_ipv4,
self._node.public_ipv6)) self._node.public_ipv6))
# Get the SSH public keys for the new node and record in ZooKeeper # wait and scan the new node and record in ZooKeeper
host_keys = [] host_keys = []
if self._pool.host_key_checking: if self._pool.host_key_checking:
try: try:
self.log.debug( self.log.debug(
"Gathering host keys for node %s", self._node.id) "Gathering host keys for node %s", self._node.id)
host_keys = utils.keyscan( # only gather host keys if the connection type is ssh
interface_ip, timeout=self._provider_config.boot_timeout) gather_host_keys = connection_type == 'ssh'
if not host_keys: host_keys = utils.nodescan(
interface_ip,
timeout=self._provider_config.boot_timeout,
gather_hostkeys=gather_host_keys)
if gather_host_keys and not host_keys:
raise exceptions.LaunchKeyscanException( raise exceptions.LaunchKeyscanException(
"Unable to gather host keys") "Unable to gather host keys")
except exceptions.SSHTimeoutException: except exceptions.ConnectionTimeoutException:
self.logConsole(self._node.external_id, self._node.hostname) self.logConsole(self._node.external_id, self._node.hostname)
raise raise

View File

@ -16,7 +16,7 @@ import logging
from nodepool import exceptions from nodepool import exceptions
from nodepool.driver import Provider from nodepool.driver import Provider
from nodepool.nodeutils import keyscan from nodepool.nodeutils import nodescan
class StaticNodeError(Exception): class StaticNodeError(Exception):
@ -36,11 +36,12 @@ class StaticNodeProvider(Provider):
def checkHost(self, node): def checkHost(self, node):
# Check node is reachable # Check node is reachable
try: try:
keys = keyscan(node["name"], keys = nodescan(node["name"],
port=node["ssh-port"], port=node["ssh-port"],
timeout=node["timeout"]) timeout=node["timeout"])
except exceptions.SSHTimeoutException: except exceptions.ConnectionTimeoutException:
raise StaticNodeError("%s: SSHTimeoutException" % node["name"]) raise StaticNodeError(
"%s: ConnectionTimeoutException" % node["name"])
# Check node host-key # Check node host-key
if set(node["host-key"]).issubset(set(keys)): if set(node["host-key"]).issubset(set(keys)):

View File

@ -49,7 +49,7 @@ class TimeoutException(Exception):
pass pass
class SSHTimeoutException(TimeoutException): class ConnectionTimeoutException(TimeoutException):
statsd_key = 'error.ssh' statsd_key = 'error.ssh'

View File

@ -57,14 +57,17 @@ def set_node_ip(node):
"Unable to find public IP of server") "Unable to find public IP of server")
def keyscan(ip, port=22, timeout=60): def nodescan(ip, port=22, timeout=60, gather_hostkeys=True):
''' '''
Scan the IP address for public SSH keys. Scan the IP address for public SSH keys.
Keys are returned formatted as: "<type> <base64_string>" Keys are returned formatted as: "<type> <base64_string>"
''' '''
if 'fake' in ip: if 'fake' in ip:
return ['ssh-rsa FAKEKEY'] if gather_hostkeys:
return ['ssh-rsa FAKEKEY']
else:
return []
addrinfo = socket.getaddrinfo(ip, port)[0] addrinfo = socket.getaddrinfo(ip, port)[0]
family = addrinfo[0] family = addrinfo[0]
@ -73,16 +76,18 @@ def keyscan(ip, port=22, timeout=60):
keys = [] keys = []
key = None key = None
for count in iterate_timeout( for count in iterate_timeout(
timeout, exceptions.SSHTimeoutException, "ssh access"): timeout, exceptions.ConnectionTimeoutException,
"connection on port %s" % port):
sock = None sock = None
t = None t = None
try: try:
sock = socket.socket(family, socket.SOCK_STREAM) sock = socket.socket(family, socket.SOCK_STREAM)
sock.settimeout(timeout) sock.settimeout(timeout)
sock.connect(sockaddr) sock.connect(sockaddr)
t = paramiko.transport.Transport(sock) if gather_hostkeys:
t.start_client(timeout=timeout) t = paramiko.transport.Transport(sock)
key = t.get_remote_server_key() t.start_client(timeout=timeout)
key = t.get_remote_server_key()
break break
except socket.error as e: except socket.error as e:
if e.errno not in [errno.ECONNREFUSED, errno.EHOSTUNREACH, None]: if e.errno not in [errno.ECONNREFUSED, errno.EHOSTUNREACH, None]:

View File

@ -949,6 +949,7 @@ class TestLauncher(tests.DBTestCase):
self.assertEqual(len(nodes), 1) self.assertEqual(len(nodes), 1)
self.assertEqual('zuul', nodes[0].username) self.assertEqual('zuul', nodes[0].username)
self.assertEqual('winrm', nodes[0].connection_type) self.assertEqual('winrm', nodes[0].connection_type)
self.assertEqual(nodes[0].host_keys, [])
def test_unmanaged_image_provider_name(self): def test_unmanaged_image_provider_name(self):
""" """