Merge "utils: Allow discovery of private key in ~/.ssh"

This commit is contained in:
Jenkins 2015-01-27 17:15:18 +00:00 committed by Gerrit Code Review
commit 1a9f738f4d
3 changed files with 49 additions and 27 deletions

View File

@ -477,3 +477,7 @@ class GaneshaCommandFailure(ProcessExecutionError):
class InvalidSqliteDB(Invalid): class InvalidSqliteDB(Invalid):
message = _("Invalid Sqlite database.") message = _("Invalid Sqlite database.")
class SSHException(ManilaException):
message = _("Exception in SSH protocol negotiation or logic.")

View File

@ -415,7 +415,7 @@ class FakeSSHClient(object):
pass pass
def connect(self, ip, port=22, username=None, password=None, def connect(self, ip, port=22, username=None, password=None,
pkey=None, timeout=10): key_filename=None, look_for_keys=None, timeout=10):
pass pass
def get_transport(self): def get_transport(self):
@ -473,28 +473,44 @@ class SSHPoolTestCase(test.TestCase):
fake_ssh_client.connect.assert_called_once_with( fake_ssh_client.connect.assert_called_once_with(
"127.0.0.1", port=22, username="test", "127.0.0.1", port=22, username="test",
password="test", pkey=None, timeout=10) password="test", key_filename=None, look_for_keys=False,
timeout=10)
def test_create_ssh_with_key(self): def test_create_ssh_with_key(self):
key = os.path.expanduser("fake_key") path_to_private_key = "/fakepath/to/privatekey"
fake_ssh_client = mock.Mock() fake_ssh_client = mock.Mock()
ssh_pool = utils.SSHPool("127.0.0.1", 22, 10, "test", ssh_pool = utils.SSHPool("127.0.0.1", 22, 10, "test",
privatekey="fake_key") privatekey="/fakepath/to/privatekey")
with mock.patch.object(paramiko, "SSHClient", with mock.patch.object(paramiko, "SSHClient",
return_value=fake_ssh_client): return_value=fake_ssh_client):
with mock.patch.object(paramiko.RSAKey, "from_private_key_file", ssh_pool.create()
return_value=key) as from_private_key_mock: fake_ssh_client.connect.assert_called_once_with(
"127.0.0.1", port=22, username="test", password=None,
ssh_pool.create() key_filename=path_to_private_key, look_for_keys=False,
from_private_key_mock.assert_called_once_with(key) timeout=10)
fake_ssh_client.connect.assert_called_once_with(
"127.0.0.1", port=22, username="test",
password=None, pkey=key, timeout=10)
def test_create_ssh_with_nothing(self): def test_create_ssh_with_nothing(self):
fake_ssh_client = mock.Mock()
ssh_pool = utils.SSHPool("127.0.0.1", 22, 10, "test") ssh_pool = utils.SSHPool("127.0.0.1", 22, 10, "test")
with mock.patch.object(paramiko, "SSHClient"): with mock.patch.object(paramiko, "SSHClient",
self.assertRaises(paramiko.SSHException, ssh_pool.create) return_value=fake_ssh_client):
ssh_pool.create()
fake_ssh_client.connect.assert_called_once_with(
"127.0.0.1", port=22, username="test", password=None,
key_filename=None, look_for_keys=True,
timeout=10)
def test_create_ssh_error_connecting(self):
attrs = {'connect.side_effect': paramiko.SSHException, }
fake_ssh_client = mock.Mock(**attrs)
ssh_pool = utils.SSHPool("127.0.0.1", 22, 10, "test")
with mock.patch.object(paramiko, "SSHClient",
return_value=fake_ssh_client):
self.assertRaises(exception.SSHException, ssh_pool.create)
fake_ssh_client.connect.assert_called_once_with(
"127.0.0.1", port=22, username="test", password=None,
key_filename=None, look_for_keys=True,
timeout=10)
def test_closed_reopend_ssh_connections(self): def test_closed_reopend_ssh_connections(self):
with mock.patch.object(paramiko, "SSHClient", with mock.patch.object(paramiko, "SSHClient",

View File

@ -80,26 +80,27 @@ class SSHPool(pools.Pool):
self.login = login self.login = login
self.password = password self.password = password
self.conn_timeout = conn_timeout if conn_timeout else None self.conn_timeout = conn_timeout if conn_timeout else None
self.privatekey = privatekey self.path_to_private_key = privatekey
super(SSHPool, self).__init__(*args, **kwargs) super(SSHPool, self).__init__(*args, **kwargs)
def create(self): def create(self):
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
look_for_keys = True
if self.path_to_private_key:
self.path_to_private_key = os.path.expanduser(
self.path_to_private_key)
look_for_keys = False
elif self.password:
look_for_keys = False
try: try:
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
if self.privatekey:
pkfile = os.path.expanduser(self.privatekey)
self.privatekey = paramiko.RSAKey.from_private_key_file(pkfile)
elif not self.password:
msg = _("Specify a password or private_key")
raise exception.ManilaException(msg)
ssh.connect(self.ip, ssh.connect(self.ip,
port=self.port, port=self.port,
username=self.login, username=self.login,
password=self.password, password=self.password,
pkey=self.privatekey, key_filename=self.path_to_private_key,
look_for_keys=look_for_keys,
timeout=self.conn_timeout) timeout=self.conn_timeout)
# Paramiko by default sets the socket timeout to 0.1 seconds, # Paramiko by default sets the socket timeout to 0.1 seconds,
# ignoring what we set thru the sshclient. This doesn't help for # ignoring what we set thru the sshclient. This doesn't help for
# keeping long lived connections. Hence we have to bypass it, by # keeping long lived connections. Hence we have to bypass it, by
@ -113,9 +114,10 @@ class SSHPool(pools.Pool):
transport.set_keepalive(self.conn_timeout) transport.set_keepalive(self.conn_timeout)
return ssh return ssh
except Exception as e: except Exception as e:
msg = _("Error connecting via ssh: %s") % e msg = _("Check whether private key or password are correctly "
"set. Error connecting via ssh: %s") % e
LOG.error(msg) LOG.error(msg)
raise paramiko.SSHException(msg) raise exception.SSHException(msg)
def get(self): def get(self):
"""Return an item from the pool, when one is available. """Return an item from the pool, when one is available.