Update shell command join and quote methods

Change-Id: I57f913f5230e82909d59ce4412e04b1bff101e44
This commit is contained in:
Federico Ressi 2021-06-29 10:15:33 +02:00
parent 76d1859695
commit 3e3761e7c2
2 changed files with 57 additions and 42 deletions

View File

@ -15,8 +15,9 @@
# under the License.
from __future__ import absolute_import
import re
import shlex
import typing # noqa
import typing
ShellCommandType = typing.Union['ShellCommand', str, typing.Iterable[str]]
@ -28,63 +29,49 @@ class ShellCommand(tuple):
return f"ShellCommand({str(self)!r})"
def __str__(self) -> str:
return join_command(self)
return join(self)
def __add__(self, other: ShellCommandType) -> 'ShellCommand':
return shell_command(tuple(self) + shell_command(other))
def shell_command(command: ShellCommandType) -> ShellCommand:
def shell_command(command: ShellCommandType,
**shlex_params) -> ShellCommand:
if isinstance(command, ShellCommand):
return command
elif isinstance(command, str):
return ShellCommand(split_command(command))
return split(command, **shlex_params)
else:
return ShellCommand(str(a) for a in command)
NEED_QUOTE_CHARS = {' ', '\t', '\n', '\r', "'", '"'}
_find_unsafe = re.compile(r'[^\w@&%+=:,.;<>/\-()\[\]|*]', re.ASCII).search
_is_quoted = re.compile(r'(^\'.*\'$)|(^".*"$)', re.ASCII).search
def join_command(sequence: typing.Iterable[str]) -> str:
result: typing.List[str] = []
for arg in sequence:
bs_buf: typing.List[str] = []
def quote(s: str):
"""Return a shell-escaped version of the string *s*."""
if not s:
return "''"
# Add a space to separate this argument from the others
if result:
result.append(' ')
if _is_quoted(s):
return s
needquote = (" " in arg) or ("\t" in arg) or not arg
if needquote:
result.append("'")
if _find_unsafe(s) is None:
return s
for c in arg:
if c == '\\':
# Don't know if we need to double yet.
bs_buf.append(c)
elif c == '"':
# Double backslashes.
result.append('\\' * len(bs_buf)*2)
bs_buf = []
result.append('\\"')
else:
# Normal char
if bs_buf:
result.extend(bs_buf)
bs_buf = []
result.append(c)
# Add remaining backslashes, if any.
if bs_buf:
result.extend(bs_buf)
if needquote:
result.extend(bs_buf)
result.append("'")
return ''.join(result)
# use single quotes, and put single quotes into double quotes
# the string $'b is then quoted as '$'"'"'b'
return "'" + s.replace("'", "'\"'\"'") + "'"
def split_command(command: str) -> typing.Sequence[str]:
return shlex.split(command)
def join(sequence: typing.Iterable[str]) -> str:
return ' '.join(quote(s)
for s in sequence)
def split(command: str, posix=True, **shlex_params) -> ShellCommand:
lex = shlex.shlex(command, posix=posix, **shlex_params)
lex.whitespace_split = True
return ShellCommand(lex)

View File

@ -20,6 +20,9 @@ from tobiko.shell import sh
from tobiko.tests import unit
SPECIAL_CHARS = r'@&%+=:,.;<>/-()[]*|'
class ShellCommandTest(unit.TobikoUnitTest):
def test_from_str(self):
@ -51,6 +54,25 @@ class ShellCommandTest(unit.TobikoUnitTest):
result = sh.shell_command(other)
self.assertIs(other, result)
def test_from_special_chars(self):
command = sh.shell_command(SPECIAL_CHARS)
self.assertEqual((SPECIAL_CHARS,), command)
self.assertEqual(SPECIAL_CHARS, str(command))
def test_from_journalctl_command(self):
command_line = (
'journalctl', '--no-pager', '--unit',
'devstack@q-svc', '--since', '30 minutes ago',
'--output', 'short-iso', '--grep',
"'Nova.+event.+response.*09e69236-2a3b-4077-bd50-0c80946bf5b3'")
command = sh.shell_command(command_line)
self.assertEqual(command_line, command)
self.assertEqual(
"journalctl --no-pager --unit devstack@q-svc "
"--since '30 minutes ago' --output short-iso --grep "
"'Nova.+event.+response.*09e69236-2a3b-4077-bd50-0c80946bf5b3'",
str(command))
def test_add_str(self):
base = sh.shell_command('ssh pippo@clubhouse.mouse')
result = base + 'ls -lh *.py'
@ -95,3 +117,9 @@ class ShellCommandTest(unit.TobikoUnitTest):
self.assertIsInstance(result, sh.ShellCommand)
self.assertEqual(('sh', '-c', "echo Hello!"), result)
self.assertEqual("sh -c 'echo Hello!'", str(result))
def test_add_special_chars(self):
base = sh.shell_command('echo')
result = base + SPECIAL_CHARS
self.assertEqual(('echo', SPECIAL_CHARS), result)
self.assertEqual('echo ' + SPECIAL_CHARS, str(result))