diff --git a/tests/base.py b/tests/base.py index 0f8b3af6eb..4dd4c64244 100755 --- a/tests/base.py +++ b/tests/base.py @@ -1837,12 +1837,20 @@ class ZuulTestCase(BaseTestCase): # Make per test copy of Configuration. self.setup_config() + self.private_key_file = os.path.join(self.test_root, 'test_id_rsa') + if not os.path.exists(self.private_key_file): + src_private_key_file = os.path.join(FIXTURE_DIR, 'test_id_rsa') + shutil.copy(src_private_key_file, self.private_key_file) + shutil.copy('{}.pub'.format(src_private_key_file), + '{}.pub'.format(self.private_key_file)) + os.chmod(self.private_key_file, 0o0600) self.config.set('zuul', 'tenant_config', os.path.join(FIXTURE_DIR, self.config.get('zuul', 'tenant_config'))) self.config.set('merger', 'git_dir', self.merger_src_root) self.config.set('executor', 'git_dir', self.executor_src_root) self.config.set('zuul', 'state_dir', self.state_root) + self.config.set('executor', 'private_key_file', self.private_key_file) self.statsd = FakeStatsd() # note, use 127.0.0.1 rather than localhost to avoid getting ipv6 diff --git a/tests/fixtures/test_id_rsa b/tests/fixtures/test_id_rsa new file mode 100644 index 0000000000..a793bd0096 --- /dev/null +++ b/tests/fixtures/test_id_rsa @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICWwIBAAKBgQCX10EQhi7hEMk1h7/fQaEj9H2DxWR0s3RXD5UI7j1Bn21tBUus +Y0tPC5wXES4VfilXg+EuOKsE6z8x8txP1wd1+d6Hq3SWXnOcqxxv2ueAy6Gc31E7 +a2IVDYvqVsAOtxsWddvMGTj98/lexQBX6Bh+wmuba/43lq5UPepwvfgNOQIDAQAB +AoGADMCHNlwOk9hVDanY82cPoXVnFSn+xc5MdwNYAOgBPQGmrwFC2bd9G6Zd9ZH7 +zNJLpo3s23Tm6ALZy9gZqJrmhWDZBOqeYtmkd0yUf5bCbUzNre8+gHJY8k9PAxVM +dPr2bq8G4PyN3yC2euTht35KLjb7hD8WiF3exgI/d8oBvgECQQDFKuWmkLtkSkGo +1KRbeBfRePbfzhGJ1yHRyO72Z1+hVXuRmtcjTfPhMikgx9dxWbpqr/RPgs7D7N8D +JpFlsiR/AkEAxSX4LOwovklPzCZ8FyfHhkydNgDyBw8y2Xe1OO0LBN51batf9rcl +rJBYFvulrD+seYNRCWBFpEi4KKZh4YESRwJAKmz+mYbPK9dmpYOMEjqXNXXH+YSH +9ZcbKd8IvHCl/Ts9qakd3fTqI2z9uJYH39Yk7MwL0Agfob0Yh78GzlE01QJACheu +g8Y3M76XCjFyKtFLgpGLfsc/nKLnjIB3U4m3BbHJuyqJyByKHjJpgAuz6IR99N6H +GH7IMefTHame2yd7YwJAUIGRD+iOO0RJvtEHUbsz6IxrQdubNOvzm/78eyBTcbsa +8996D18fJF6Q0/Gg0Cm65PNOpIthP3qxFkuuduUEUg== +-----END RSA PRIVATE KEY----- diff --git a/tests/fixtures/test_id_rsa.pub b/tests/fixtures/test_id_rsa.pub new file mode 100644 index 0000000000..bffc7265b4 --- /dev/null +++ b/tests/fixtures/test_id_rsa.pub @@ -0,0 +1 @@ +ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAAAgQCX10EQhi7hEMk1h7/fQaEj9H2DxWR0s3RXD5UI7j1Bn21tBUusY0tPC5wXES4VfilXg+EuOKsE6z8x8txP1wd1+d6Hq3SWXnOcqxxv2ueAy6Gc31E7a2IVDYvqVsAOtxsWddvMGTj98/lexQBX6Bh+wmuba/43lq5UPepwvfgNOQ== Private Key For Zuul Tests DO NOT USE diff --git a/tests/unit/test_ssh_agent.py b/tests/unit/test_ssh_agent.py new file mode 100644 index 0000000000..c9c1ebd513 --- /dev/null +++ b/tests/unit/test_ssh_agent.py @@ -0,0 +1,56 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import subprocess + +from tests.base import ZuulTestCase +from zuul.executor.server import SshAgent + + +class TestSshAgent(ZuulTestCase): + tenant_config_file = 'config/single-tenant/main.yaml' + + def test_ssh_agent(self): + # Need a private key to add + env_copy = dict(os.environ) + # DISPLAY and SSH_ASKPASS will cause interactive test runners to get a + # surprise + if 'DISPLAY' in env_copy: + del env_copy['DISPLAY'] + if 'SSH_ASKPASS' in env_copy: + del env_copy['SSH_ASKPASS'] + + agent = SshAgent() + agent.start() + env_copy.update(agent.env) + + pub_key_file = '{}.pub'.format(self.private_key_file) + pub_key = None + with open(pub_key_file) as pub_key_f: + pub_key = pub_key_f.read().split('== ')[0] + + agent.add(self.private_key_file) + keys = agent.list() + self.assertEqual(1, len(keys)) + self.assertEqual(keys[0].split('== ')[0], pub_key) + agent.remove(self.private_key_file) + keys = agent.list() + self.assertEqual([], keys) + agent.stop() + # Agent is now dead and thus this should fail + with open('/dev/null') as devnull: + self.assertRaises(subprocess.CalledProcessError, + subprocess.check_call, + ['ssh-add', self.private_key_file], + env=env_copy, + stderr=devnull) diff --git a/zuul/executor/server.py b/zuul/executor/server.py index 99d2a9c185..14bb588438 100644 --- a/zuul/executor/server.py +++ b/zuul/executor/server.py @@ -79,6 +79,78 @@ class JobDirPlaybook(object): self.path = None +class SshAgent(object): + log = logging.getLogger("zuul.ExecutorServer") + + def __init__(self): + self.env = {} + self.ssh_agent = None + + def start(self): + if self.ssh_agent: + return + with open('/dev/null', 'r+') as devnull: + ssh_agent = subprocess.Popen(['ssh-agent'], close_fds=True, + stdout=subprocess.PIPE, + stderr=devnull, + stdin=devnull) + (output, _) = ssh_agent.communicate() + output = output.decode('utf8') + for line in output.split("\n"): + if '=' in line: + line = line.split(";", 1)[0] + (key, value) = line.split('=') + self.env[key] = value + self.log.info('Started SSH Agent, {}'.format(self.env)) + + def stop(self): + if 'SSH_AGENT_PID' in self.env: + try: + os.kill(int(self.env['SSH_AGENT_PID']), signal.SIGTERM) + except OSError: + self.log.exception( + 'Problem sending SIGTERM to agent {}'.format(self.env)) + self.log.info('Sent SIGTERM to SSH Agent, {}'.format(self.env)) + self.env = {} + + def add(self, key_path): + env = os.environ.copy() + env.update(self.env) + key_path = os.path.expanduser(key_path) + self.log.debug('Adding SSH Key {}'.format(key_path)) + output = '' + try: + output = subprocess.check_output(['ssh-add', key_path], env=env, + stderr=subprocess.PIPE) + except subprocess.CalledProcessError: + self.log.error('ssh-add failed: {}'.format(output)) + raise + self.log.info('Added SSH Key {}'.format(key_path)) + + def remove(self, key_path): + env = os.environ.copy() + env.update(self.env) + key_path = os.path.expanduser(key_path) + self.log.debug('Removing SSH Key {}'.format(key_path)) + subprocess.check_output(['ssh-add', '-d', key_path], env=env, + stderr=subprocess.PIPE) + self.log.info('Removed SSH Key {}'.format(key_path)) + + def list(self): + if 'SSH_AUTH_SOCK' not in self.env: + return None + env = os.environ.copy() + env.update(self.env) + result = [] + for line in subprocess.Popen(['ssh-add', '-L'], env=env, + stdout=subprocess.PIPE).stdout: + line = line.decode('utf8') + if line.strip() == 'The agent has no identities.': + break + result.append(line.strip()) + return result + + class JobDir(object): def __init__(self, root=None, keep=False): # root @@ -168,7 +240,7 @@ class UpdateTask(object): self.event = threading.Event() def __eq__(self, other): - if (other.connection_name == self.connection_name and + if (other and other.connection_name == self.connection_name and other.project_name == self.project_name): return True return False @@ -513,6 +585,8 @@ class AnsibleJob(object): self.proc_lock = threading.Lock() self.running = False self.aborted = False + self.thread = None + self.ssh_agent = None if self.executor_server.config.has_option( 'executor', 'private_key_file'): @@ -520,8 +594,11 @@ class AnsibleJob(object): 'executor', 'private_key_file') else: self.private_key_file = '~/.ssh/id_rsa' + self.ssh_agent = SshAgent() def run(self): + self.ssh_agent.start() + self.ssh_agent.add(self.private_key_file) self.running = True self.thread = threading.Thread(target=self.execute) self.thread.start() @@ -529,7 +606,8 @@ class AnsibleJob(object): def stop(self): self.aborted = True self.abortRunningProc() - self.thread.join() + if self.thread: + self.thread.join() def execute(self): try: @@ -549,6 +627,11 @@ class AnsibleJob(object): self.executor_server.finishJob(self.job.unique) except Exception: self.log.exception("Error finalizing job thread:") + if self.ssh_agent: + try: + self.ssh_agent.stop() + except Exception: + self.log.exception("Error stopping SSH agent:") def _execute(self): self.log.debug("Job %s: beginning" % (self.job.unique,)) @@ -1032,6 +1115,7 @@ class AnsibleJob(object): def runAnsible(self, cmd, timeout, trusted=False): env_copy = os.environ.copy() + env_copy.update(self.ssh_agent.env) env_copy['LOGNAME'] = 'zuul' if trusted: