Don't gather host keys for non ssh connections
In case of an image with the connection type winrm we cannot scan the ssh host keys. So in case the connection type is not ssh we need to skip gathering the host keys. Change-Id: I56f308baa10d40461cf4a919bbcdc4467e85a551
This commit is contained in:
parent
ee78684521
commit
2da274e2ae
|
@ -194,18 +194,23 @@ class NodeLauncher(threading.Thread, stats.StatsReporter):
|
||||||
self._node.interface_ip, self._node.public_ipv4,
|
self._node.interface_ip, self._node.public_ipv4,
|
||||||
self._node.public_ipv6))
|
self._node.public_ipv6))
|
||||||
|
|
||||||
# Get the SSH public keys for the new node and record in ZooKeeper
|
# wait and scan the new node and record in ZooKeeper
|
||||||
host_keys = []
|
host_keys = []
|
||||||
if self._pool.host_key_checking:
|
if self._pool.host_key_checking:
|
||||||
try:
|
try:
|
||||||
self.log.debug(
|
self.log.debug(
|
||||||
"Gathering host keys for node %s", self._node.id)
|
"Gathering host keys for node %s", self._node.id)
|
||||||
host_keys = utils.keyscan(
|
# only gather host keys if the connection type is ssh
|
||||||
interface_ip, timeout=self._provider_config.boot_timeout)
|
gather_host_keys = connection_type == 'ssh'
|
||||||
if not host_keys:
|
host_keys = utils.nodescan(
|
||||||
|
interface_ip,
|
||||||
|
timeout=self._provider_config.boot_timeout,
|
||||||
|
gather_hostkeys=gather_host_keys)
|
||||||
|
|
||||||
|
if gather_host_keys and not host_keys:
|
||||||
raise exceptions.LaunchKeyscanException(
|
raise exceptions.LaunchKeyscanException(
|
||||||
"Unable to gather host keys")
|
"Unable to gather host keys")
|
||||||
except exceptions.SSHTimeoutException:
|
except exceptions.ConnectionTimeoutException:
|
||||||
self.logConsole(self._node.external_id, self._node.hostname)
|
self.logConsole(self._node.external_id, self._node.hostname)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ import logging
|
||||||
|
|
||||||
from nodepool import exceptions
|
from nodepool import exceptions
|
||||||
from nodepool.driver import Provider
|
from nodepool.driver import Provider
|
||||||
from nodepool.nodeutils import keyscan
|
from nodepool.nodeutils import nodescan
|
||||||
|
|
||||||
|
|
||||||
class StaticNodeError(Exception):
|
class StaticNodeError(Exception):
|
||||||
|
@ -36,11 +36,12 @@ class StaticNodeProvider(Provider):
|
||||||
def checkHost(self, node):
|
def checkHost(self, node):
|
||||||
# Check node is reachable
|
# Check node is reachable
|
||||||
try:
|
try:
|
||||||
keys = keyscan(node["name"],
|
keys = nodescan(node["name"],
|
||||||
port=node["ssh-port"],
|
port=node["ssh-port"],
|
||||||
timeout=node["timeout"])
|
timeout=node["timeout"])
|
||||||
except exceptions.SSHTimeoutException:
|
except exceptions.ConnectionTimeoutException:
|
||||||
raise StaticNodeError("%s: SSHTimeoutException" % node["name"])
|
raise StaticNodeError(
|
||||||
|
"%s: ConnectionTimeoutException" % node["name"])
|
||||||
|
|
||||||
# Check node host-key
|
# Check node host-key
|
||||||
if set(node["host-key"]).issubset(set(keys)):
|
if set(node["host-key"]).issubset(set(keys)):
|
||||||
|
|
|
@ -49,7 +49,7 @@ class TimeoutException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SSHTimeoutException(TimeoutException):
|
class ConnectionTimeoutException(TimeoutException):
|
||||||
statsd_key = 'error.ssh'
|
statsd_key = 'error.ssh'
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -57,14 +57,17 @@ def set_node_ip(node):
|
||||||
"Unable to find public IP of server")
|
"Unable to find public IP of server")
|
||||||
|
|
||||||
|
|
||||||
def keyscan(ip, port=22, timeout=60):
|
def nodescan(ip, port=22, timeout=60, gather_hostkeys=True):
|
||||||
'''
|
'''
|
||||||
Scan the IP address for public SSH keys.
|
Scan the IP address for public SSH keys.
|
||||||
|
|
||||||
Keys are returned formatted as: "<type> <base64_string>"
|
Keys are returned formatted as: "<type> <base64_string>"
|
||||||
'''
|
'''
|
||||||
if 'fake' in ip:
|
if 'fake' in ip:
|
||||||
return ['ssh-rsa FAKEKEY']
|
if gather_hostkeys:
|
||||||
|
return ['ssh-rsa FAKEKEY']
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
addrinfo = socket.getaddrinfo(ip, port)[0]
|
addrinfo = socket.getaddrinfo(ip, port)[0]
|
||||||
family = addrinfo[0]
|
family = addrinfo[0]
|
||||||
|
@ -73,16 +76,18 @@ def keyscan(ip, port=22, timeout=60):
|
||||||
keys = []
|
keys = []
|
||||||
key = None
|
key = None
|
||||||
for count in iterate_timeout(
|
for count in iterate_timeout(
|
||||||
timeout, exceptions.SSHTimeoutException, "ssh access"):
|
timeout, exceptions.ConnectionTimeoutException,
|
||||||
|
"connection on port %s" % port):
|
||||||
sock = None
|
sock = None
|
||||||
t = None
|
t = None
|
||||||
try:
|
try:
|
||||||
sock = socket.socket(family, socket.SOCK_STREAM)
|
sock = socket.socket(family, socket.SOCK_STREAM)
|
||||||
sock.settimeout(timeout)
|
sock.settimeout(timeout)
|
||||||
sock.connect(sockaddr)
|
sock.connect(sockaddr)
|
||||||
t = paramiko.transport.Transport(sock)
|
if gather_hostkeys:
|
||||||
t.start_client(timeout=timeout)
|
t = paramiko.transport.Transport(sock)
|
||||||
key = t.get_remote_server_key()
|
t.start_client(timeout=timeout)
|
||||||
|
key = t.get_remote_server_key()
|
||||||
break
|
break
|
||||||
except socket.error as e:
|
except socket.error as e:
|
||||||
if e.errno not in [errno.ECONNREFUSED, errno.EHOSTUNREACH, None]:
|
if e.errno not in [errno.ECONNREFUSED, errno.EHOSTUNREACH, None]:
|
||||||
|
|
|
@ -949,6 +949,7 @@ class TestLauncher(tests.DBTestCase):
|
||||||
self.assertEqual(len(nodes), 1)
|
self.assertEqual(len(nodes), 1)
|
||||||
self.assertEqual('zuul', nodes[0].username)
|
self.assertEqual('zuul', nodes[0].username)
|
||||||
self.assertEqual('winrm', nodes[0].connection_type)
|
self.assertEqual('winrm', nodes[0].connection_type)
|
||||||
|
self.assertEqual(nodes[0].host_keys, [])
|
||||||
|
|
||||||
def test_unmanaged_image_provider_name(self):
|
def test_unmanaged_image_provider_name(self):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue