Record SSH public keys for new nodes in ZK

Change-Id: I3ad63196d584d8dc93a8bcdd9b211f8f6a65bf2f
Story: 2000897
This commit is contained in:
David Shrewsbury 2017-03-13 15:02:44 -04:00
parent 4bc703883e
commit 88042886be
5 changed files with 54 additions and 1 deletions

View File

@ -69,6 +69,10 @@ class LaunchAuthException(Exception):
statsd_key = 'error.auth' statsd_key = 'error.auth'
class LaunchKeyscanException(Exception):
statsd_key = 'error.keyscan'
class StatsReporter(object): class StatsReporter(object):
''' '''
Class adding statsd reporting functionality. Class adding statsd reporting functionality.
@ -356,6 +360,14 @@ class NodeLauncher(threading.Thread, StatsReporter):
if not host: if not host:
raise LaunchAuthException("Unable to connect via ssh") raise LaunchAuthException("Unable to connect via ssh")
# Get the SSH public keys for the new node and record in ZooKeeper
self.log.debug("Gathering host keys for node %s", self._node.id)
host_keys = utils.keyscan(preferred_ip)
if not host_keys:
raise LaunchKeyscanException("Unable to gather host keys")
self._node.host_keys = host_keys
self._zk.storeNode(self._node)
self._writeNodepoolInfo(host, preferred_ip, self._node) self._writeNodepoolInfo(host, preferred_ip, self._node)
if self._label.ready_script: if self._label.ready_script:
self._runReadyScript(host, hostname, self._label.ready_script) self._runReadyScript(host, hostname, self._label.ready_script)

View File

@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import base64
import errno import errno
import time import time
import socket import socket
@ -73,3 +74,34 @@ def ssh_connect(ip, username, connect_kwargs={}, timeout=60):
if "access okay" in out: if "access okay" in out:
return client return client
return None return None
def keyscan(ip):
'''
Scan the IP address for public SSH keys.
Keys are returned formatted as: "<type> <base64_string>"
'''
if 'fake' in ip:
return ['ssh-rsa FAKEKEY']
keys = []
key = None
try:
t = paramiko.transport.Transport('%s:%s' % (ip, "22"))
t.start_client()
key = t.get_remote_server_key()
t.close()
except Exception as e:
log.exception("ssh-keyscan failure: %s", e)
# Paramiko, at this time, seems to return only the ssh-rsa key, so
# only the single key is placed into the list.
if key:
keys.append(
"%s %s" % (key.get_name(),
base64.encodestring(str(key)).replace('\n', ''))
)
return keys

View File

@ -202,6 +202,7 @@ class TestNodepool(tests.DBTestCase):
self.assertEqual(len(nodes), 1) self.assertEqual(len(nodes), 1)
self.assertEqual(nodes[0].provider, 'fake-provider') self.assertEqual(nodes[0].provider, 'fake-provider')
self.assertEqual(nodes[0].type, 'fake-label') self.assertEqual(nodes[0].type, 'fake-label')
self.assertNotEqual(nodes[0].host_keys, [])
def test_disabled_label(self): def test_disabled_label(self):
"""Test that a node is not created with min-ready=0""" """Test that a node is not created with min-ready=0"""

View File

@ -772,6 +772,7 @@ class TestZKModel(tests.BaseTestCase):
o.external_id = 'ABCD' o.external_id = 'ABCD'
o.hostname = 'xyz' o.hostname = 'xyz'
o.comment = 'comment' o.comment = 'comment'
o.host_keys = ['key1', 'key2']
d = o.toDict() d = o.toDict()
self.assertNotIn('id', d) self.assertNotIn('id', d)
@ -790,6 +791,7 @@ class TestZKModel(tests.BaseTestCase):
self.assertEqual(d['external_id'], o.external_id) self.assertEqual(d['external_id'], o.external_id)
self.assertEqual(d['hostname'], o.hostname) self.assertEqual(d['hostname'], o.hostname)
self.assertEqual(d['comment'], o.comment) self.assertEqual(d['comment'], o.comment)
self.assertEqual(d['host_keys'], o.host_keys)
def test_Node_fromDict(self): def test_Node_fromDict(self):
now = int(time.time()) now = int(time.time())
@ -810,6 +812,7 @@ class TestZKModel(tests.BaseTestCase):
'external_id': 'ABCD', 'external_id': 'ABCD',
'hostname': 'xyz', 'hostname': 'xyz',
'comment': 'comment', 'comment': 'comment',
'host_keys': ['key1', 'key2'],
} }
o = zk.Node.fromDict(d, node_id) o = zk.Node.fromDict(d, node_id)
@ -829,3 +832,4 @@ class TestZKModel(tests.BaseTestCase):
self.assertEqual(o.external_id, d['external_id']) self.assertEqual(o.external_id, d['external_id'])
self.assertEqual(o.hostname , d['hostname']) self.assertEqual(o.hostname , d['hostname'])
self.assertEqual(o.comment , d['comment']) self.assertEqual(o.comment , d['comment'])
self.assertEqual(o.host_keys , d['host_keys'])

View File

@ -416,6 +416,7 @@ class Node(BaseModel):
self.external_id = None self.external_id = None
self.hostname = None self.hostname = None
self.comment = None self.comment = None
self.host_keys = []
def __repr__(self): def __repr__(self):
d = self.toDict() d = self.toDict()
@ -440,7 +441,8 @@ class Node(BaseModel):
self.created_time == other.created_time and self.created_time == other.created_time and
self.external_id == other.external_id and self.external_id == other.external_id and
self.hostname == other.hostname and self.hostname == other.hostname and
self.comment == other.comment) self.comment == other.comment,
self.host_keys == other.host_keys)
else: else:
return False return False
@ -462,6 +464,7 @@ class Node(BaseModel):
d['external_id'] = self.external_id d['external_id'] = self.external_id
d['hostname'] = self.hostname d['hostname'] = self.hostname
d['comment'] = self.comment d['comment'] = self.comment
d['host_keys'] = self.host_keys
return d return d
@staticmethod @staticmethod
@ -489,6 +492,7 @@ class Node(BaseModel):
o.external_id = d.get('external_id') o.external_id = d.get('external_id')
o.hostname = d.get('hostname') o.hostname = d.get('hostname')
o.comment = d.get('comment') o.comment = d.get('comment')
o.host_keys = d.get('host_keys', [])
return o return o