tobiko/tobiko/tests/functional/shell/sh/test_connection.py

231 lines
7.3 KiB
Python

# Copyright (c) 2022 Red Hat, Inc.
#
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import
import getpass
import io
import os
import socket
import shutil
import tempfile
import typing
import uuid
import testtools
import tobiko
from tobiko.openstack import stacks
from tobiko.openstack import topology
from tobiko.shell import sh
from tobiko.shell import ssh
class LocalTempDirFixture(tobiko.SharedFixture):
    """Shared fixture owning a freshly created local temporary directory.

    The directory is created with ``tempfile.mkdtemp`` when the fixture is
    set up and removed recursively on cleanup, provided it still exists.
    """

    # Path of the temporary directory; stays None until setup_fixture runs.
    path: typing.Optional[str] = None

    def setup_fixture(self):
        """Create the temporary directory and remember its path."""
        self.path = self.create_dir()

    def cleanup_fixture(self):
        """Delete the directory (if one was created) and forget its path."""
        if self.path is None:
            return
        try:
            self.delete_dir()
        finally:
            # Removing the instance attribute lets ``path`` fall back to the
            # class-level default of None.
            del self.path

    def create_dir(self) -> str:
        """Return the path of a brand-new local temporary directory."""
        new_dir = tempfile.mkdtemp()
        return new_dir

    def delete_dir(self):
        """Recursively remove ``self.path`` when it is an existing directory."""
        target = self.path
        if not os.path.isdir(target):
            return
        shutil.rmtree(target)
class LocalShellConnectionTest(testtools.TestCase):
    """Check the shell connection returned for ``ssh_client=False``.

    Each ``test_*`` method compares an attribute of the connection (or the
    matching ``sh.connection_*`` / ``sh.is_*_connection`` helper) against an
    expected value exposed as a property. Subclasses override ``ssh_client``
    and those expectation properties to reuse the same checks for other
    kinds of connection.
    """

    # Type the connection under test is expected to be an instance of.
    connection_class = sh.LocalShellConnection

    @property
    def ssh_client(self) -> ssh.SSHClientType:
        # False selects the local (non-SSH) shell connection.
        return False

    @property
    def connection(self) -> sh.ShellConnection:
        """Shell connection under test, looked up for ``self.ssh_client``."""
        return sh.shell_connection(ssh_client=self.ssh_client)

    @property
    def local_connection(self) -> sh.LocalShellConnection:
        """Connection to the machine running the tests."""
        return sh.local_shell_connection()

    # -- expected values (overridden by subclasses) ------------------------

    @property
    def is_local(self) -> bool:
        return True

    @property
    def is_cirros(self) -> bool:
        return False

    @property
    def hostname(self) -> str:
        return socket.gethostname()

    @property
    def username(self) -> str:
        return getpass.getuser()

    @property
    def login(self) -> str:
        # Expected login has the form "<username>@<hostname>".
        return f"{self.username}@{self.hostname}"

    # -- helpers -----------------------------------------------------------

    def _cat_remote_file(self, remote_file: str):
        """Run ``cat`` on *remote_file* through ``self.ssh_client``."""
        return sh.execute(f"cat '{remote_file}'",
                          ssh_client=self.ssh_client)

    # -- connection identity -----------------------------------------------

    def test_shell_connection(self):
        shared = sh.shell_connection(ssh_client=self.ssh_client)
        self.assertIsInstance(shared, self.connection_class)
        # Looking the connection up again must give back the same object.
        self.assertIs(self.connection, shared)

    def test_is_local(self):
        expected = self.is_local
        self.assertIs(expected, self.connection.is_local)

    def test_is_local_connection(self):
        actual = sh.is_local_connection(ssh_client=self.ssh_client)
        self.assertIs(self.is_local, actual)

    def test_hostname(self):
        expected = self.hostname
        self.assertEqual(expected, self.connection.hostname)

    def test_connection_hostname(self):
        actual = sh.connection_hostname(ssh_client=self.ssh_client)
        self.assertEqual(self.hostname, actual)

    def test_login(self):
        expected = self.login
        self.assertEqual(expected, self.connection.login)

    def test_connection_login(self):
        actual = sh.connection_login(ssh_client=self.ssh_client)
        self.assertEqual(self.login, actual)

    def test_username(self):
        expected = self.username
        self.assertEqual(expected, self.connection.username)

    def test_connection_username(self):
        actual = sh.connection_username(ssh_client=self.ssh_client)
        self.assertEqual(self.username, actual)

    def test_is_cirros(self):
        expected = self.is_cirros
        self.assertIs(expected, self.connection.is_cirros)

    def test_is_cirros_connection(self):
        actual = sh.is_cirros_connection(ssh_client=self.ssh_client)
        self.assertIs(self.is_cirros, actual)

    # -- file transfer -----------------------------------------------------

    def test_get_file(self):
        local_dir = self.local_connection.make_temp_dir()
        local_file = os.path.join(local_dir, 'local_file')
        remote_dir = self.connection.make_temp_dir()
        remote_file = os.path.join(remote_dir, 'remote_file')
        payload = str(uuid.uuid4())
        sh.execute(f"echo '{payload}' > '{remote_file}'",
                   ssh_client=self.ssh_client)
        self.assertFalse(os.path.isfile(local_file))

        sh.get_file(local_file=local_file,
                    remote_file=remote_file,
                    ssh_client=self.ssh_client)

        self.assertTrue(os.path.isfile(local_file), 'file not copied')
        with io.open(local_file, 'rt') as stream:
            # ``echo`` appended a trailing newline to the written text.
            self.assertEqual(f'{payload}\n', stream.read())

    def test_put_file(self):
        local_dir = self.local_connection.make_temp_dir()
        local_file = os.path.join(local_dir, 'local_file')
        remote_dir = self.connection.make_temp_dir()
        remote_file = os.path.join(remote_dir, 'remote_file')
        payload = str(uuid.uuid4())
        with io.open(local_file, 'wt') as stream:
            stream.write(payload)
        # The remote file must not exist before the upload.
        self.assertRaises(sh.ShellCommandFailed,
                          self._cat_remote_file, remote_file)

        sh.put_file(remote_file=remote_file,
                    local_file=local_file,
                    ssh_client=self.ssh_client)

        uploaded = self._cat_remote_file(remote_file).stdout
        self.assertEqual(payload, uploaded)

    def test_put_files(self):
        local_dir = self.local_connection.make_temp_dir()
        local_file = os.path.join(local_dir, 'some_file')
        remote_dir = self.connection.make_temp_dir()
        # The test expects put_files to recreate the local directory, by its
        # base name, underneath remote_dir.
        remote_file = os.path.join(remote_dir,
                                   os.path.basename(local_dir),
                                   'some_file')
        payload = str(uuid.uuid4())
        with io.open(local_file, 'wt') as stream:
            stream.write(payload)
        self.assertRaises(sh.ShellCommandFailed,
                          self._cat_remote_file, remote_file)

        sh.put_files(local_dir,
                     remote_dir=remote_dir,
                     ssh_client=self.ssh_client)

        uploaded = self._cat_remote_file(remote_file).stdout
        self.assertEqual(payload, uploaded)
class SSHShellConnectionTest(LocalShellConnectionTest):
    """Run the inherited shell connection checks over an SSH connection."""

    connection_class = sh.SSHShellConnection
    server = tobiko.required_fixture(stacks.UbuntuMinimalServerStackFixture)

    @property
    def ssh_client(self) -> ssh.SSHClientFixture:
        """Pick the SSH client to test against.

        The order of preference is: the SSH proxy client, then the first
        OpenStack node whose ``ssh_client`` is an ``SSHClientFixture``, then
        the Ubuntu server stack's client.
        """
        proxy_client = ssh.ssh_proxy_client()
        if isinstance(proxy_client, ssh.SSHClientFixture):
            return proxy_client
        node_clients = (node.ssh_client
                        for node in topology.list_openstack_nodes()
                        if isinstance(node.ssh_client, ssh.SSHClientFixture))
        node_client = next(node_clients, None)
        if node_client is not None:
            return node_client
        return self.server.ssh_client

    @property
    def is_local(self) -> bool:
        # An SSH shell connection is never reported as local.
        return False

    @property
    def hostname(self) -> str:
        # Ask the host behind the SSH client rather than this machine.
        return sh.get_hostname(ssh_client=self.ssh_client)

    @property
    def username(self) -> str:
        # User name configured on the selected SSH client.
        client = self.ssh_client
        return client.username
class CirrosShellConnectionTest(SSHShellConnectionTest):
    """Run the inherited SSH connection checks against a CirrOS server."""

    connection_class = stacks.CirrosShellConnection
    server = tobiko.required_fixture(stacks.CirrosServerStackFixture)

    @property
    def ssh_client(self) -> ssh.SSHClientFixture:
        """Use the CirrOS server stack's client.

        The parent class's proxy/node lookup is skipped.
        """
        cirros_server = self.server
        return cirros_server.ssh_client

    @property
    def is_cirros(self) -> bool:
        # The connection under test is expected to detect CirrOS.
        return True