Add put_files function to copy multiple files or directories

Change-Id: I26a86f5b7579f7b12e0a35ab2e9c7bf5738bbfd4
This commit is contained in:
Federico Ressi 2022-05-24 10:49:18 +02:00
parent fc03c3052a
commit 7d3d0ce526
3 changed files with 60 additions and 0 deletions

View File

@ -52,6 +52,7 @@ shell_connection = _connection.shell_connection
is_cirros_connection = _connection.is_cirros_connection
is_local_connection = _connection.is_local_connection
put_file = _connection.put_file
put_files = _connection.put_files
get_file = _connection.get_file
make_temp_dir = _connection.make_temp_dir
make_dirs = _connection.make_dirs

View File

@ -72,6 +72,14 @@ def put_file(local_file: str,
remote_file=remote_file)
def put_files(*local_files: str,
remote_dir: str,
ssh_client: ssh.SSHClientType = None) -> bool:
return shell_connection(
ssh_client=ssh_client).put_files(*local_files,
remote_dir=remote_dir)
def make_temp_dir(ssh_client: ssh.SSHClientType = None,
auto_clean=True,
sudo: bool = None) -> str:
@ -197,6 +205,37 @@ class ShellConnection(tobiko.SharedFixture):
def put_file(self, local_file: str, remote_file: str):
raise NotImplementedError
def put_files(self,
*local_files: str,
remote_dir: str,
make_dirs=True):
# pylint: disable=redefined-outer-name
remote_dir = os.path.normpath(remote_dir)
put_files = {}
for local_file in local_files:
local_file = os.path.normpath(local_file)
if os.path.isdir(local_file):
top_dir = os.path.dirname(local_file)
for local_dir, _, files in os.walk(local_file):
for filename in files:
local_file = os.path.join(local_dir, filename)
remote_file = os.path.join(
remote_dir,
os.path.relpath(local_file, start=top_dir))
put_files[os.path.realpath(local_file)] = remote_file
else:
remote_file = os.path.join(
remote_dir, os.path.basename(local_file))
put_files[os.path.realpath(local_file)] = remote_file
remote_dirs = set()
for local_file, remote_file in sorted(put_files.items()):
if make_dirs:
remote_dir = os.path.dirname(remote_file)
if remote_dir not in remote_dirs:
self.make_dirs(remote_dir, exist_ok=True)
remote_dirs.add(remote_dir)
self.put_file(local_file, remote_file)
def get_file(self, remote_file: str, local_file: str):
raise NotImplementedError

View File

@ -166,6 +166,26 @@ class LocalShellConnectionTest(testtools.TestCase):
ssh_client=self.ssh_client).stdout
self.assertEqual(text, output)
def test_put_files(self):
local_dir = self.local_connection.make_temp_dir()
local_file = os.path.join(local_dir, 'some_file')
remote_dir = self.connection.make_temp_dir()
remote_file = os.path.join(remote_dir,
os.path.basename(local_dir),
'some_file')
text = str(uuid.uuid4())
with io.open(local_file, 'wt') as fd:
fd.write(text)
self.assertRaises(sh.ShellCommandFailed,
sh.execute, f"cat '{remote_file}'",
ssh_client=self.ssh_client)
sh.put_files(local_dir,
remote_dir=remote_dir,
ssh_client=self.ssh_client)
output = sh.execute(f"cat '{remote_file}'",
ssh_client=self.ssh_client).stdout
self.assertEqual(text, output)
class SSHShellConnectionTest(LocalShellConnectionTest):
connection_class = sh.SSHShellConnection