232 lines
6.8 KiB
Python
232 lines
6.8 KiB
Python
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
|
# not use this file except in compliance with the License. You may obtain
|
|
# a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
# License for the specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
import contextlib
|
|
import io
|
|
import logging
|
|
import pipes
|
|
import random
|
|
import threading
|
|
|
|
import paramiko
|
|
from paramiko import channel
|
|
|
|
from octane import magic_consts
|
|
from octane.util import subprocess
|
|
|
|
# Module-level logger used for connection and reconnection messages.
LOG = logging.getLogger(__name__)

# Re-exported from octane.util.subprocess; SSHPopen accepts only None or PIPE
# for stdin/stdout/stderr.
PIPE = subprocess.PIPE
|
|
|
|
|
|
class _cache(object):
|
|
def __init__(self, new):
|
|
self.new = new
|
|
self.cache = {}
|
|
self.lock = threading.Lock()
|
|
self.invalidate = []
|
|
self.check_fn = None
|
|
|
|
def __call__(self, node):
|
|
node_id = node.data['id']
|
|
try:
|
|
obj = self.cache[node_id]
|
|
except KeyError:
|
|
obj = None
|
|
else:
|
|
if not self.check_fn or self.check_fn(node, obj):
|
|
return obj
|
|
# Now obj is either bad old obj or None
|
|
with self.lock:
|
|
try:
|
|
new_obj = self.cache[node_id]
|
|
except KeyError:
|
|
pass # Need to just create a new one
|
|
else:
|
|
if new_obj is not obj:
|
|
return new_obj # Someone already created a new one
|
|
# We're going to replace this obj, invalidate other caches
|
|
for cache in self.invalidate:
|
|
with cache.lock:
|
|
cache.cache.pop(node_id, None)
|
|
|
|
new_obj = self.new(node)
|
|
self.cache[node_id] = new_obj
|
|
return new_obj
|
|
|
|
def check(self, fn):
|
|
self.check_fn = fn
|
|
return fn
|
|
|
|
|
|
@_cache
def get_client(node):
    """Create a new SSH connection to *node*.

    Wrapped by ``_cache``: ``get_client(node)`` returns the client cached
    for ``node.data['id']`` while the registered check accepts it, and only
    runs this body when a new connection is needed.

    :param node: object whose ``data`` dict provides ``'id'`` (for logging
        and caching) and ``'ip'`` (the address to connect to).
    :returns: a connected ``paramiko.SSHClient``.
    :raises: whatever ``SSHClient.connect`` raises; the half-built client is
        closed first.
    """
    LOG.info("Creating new SSH connection to node %s", node.data['id'])
    client = paramiko.SSHClient()
    # NOTE(review): AutoAddPolicy accepts any unknown host key without
    # verification, i.e. the network path to the node is trusted.
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.connect(node.data['ip'], key_filename=magic_consts.SSH_KEYS)
    except Exception:
        # A failed connect() may already have started a transport; close
        # the client so the failed attempt does not leak it.
        client.close()
        raise
    return client
|
|
|
|
|
|
@get_client.check
def _check_client(node, client):
    """Validator for get_client's cache: is *client* still usable?

    First performs a keepalive round-trip and waits for it, so that a
    connection which silently died is noticed before ``is_active()`` is
    consulted.  Logs and returns False when the client has no transport
    or its transport is no longer active; otherwise returns True.
    """
    transport = client.get_transport()
    healthy = False
    if transport:
        # Normal keepalive packet, but wait for the result so a dead
        # socket surfaces now rather than on the next real command.
        transport.global_request('keepalive@lag.net', wait=True)
        healthy = bool(transport.is_active())
    if not healthy:
        LOG.info("SSH connection to node %s died, reconnecting",
                 node.data['id'])
    return healthy
|
|
|
|
|
|
class ChannelFile(io.IOBase, channel.ChannelFile):
    """paramiko ``ChannelFile`` that is also an ``io.IOBase`` subclass.

    SSHPopen uses it for the remote command's stdin and stdout.
    """
    # NOTE(review): io.IOBase comes first in the bases, so any method both
    # bases define resolves to the io.IOBase version — confirm intended.
|
|
|
|
|
|
class ChannelStderrFile(io.IOBase, channel.ChannelStderrFile):
    """paramiko ``ChannelStderrFile`` that is also an ``io.IOBase`` subclass.

    SSHPopen uses it for the remote command's stderr.
    """
    # NOTE(review): io.IOBase comes first in the bases, so any method both
    # bases define resolves to the io.IOBase version — confirm intended.
|
|
|
|
|
|
class _LogPipe(subprocess._BaseLogPipe):
    """``_BaseLogPipe`` whose source is an already-open file object.

    SSHPopen wraps paramiko channel files in it; the wrapped object is
    handed back by :meth:`pipe` (presumably what ``_BaseLogPipe`` reads
    and logs from — confirm against octane.util.subprocess).
    """

    def __init__(self, level, pipe, parse_levels=False):
        """Store *pipe*; *level* and *parse_levels* go to the base class."""
        super(_LogPipe, self).__init__(level, parse_levels=parse_levels)
        self._pipe = pipe

    def pipe(self):
        """Return the file object given at construction time."""
        source = self._pipe
        return source
|
|
|
|
|
|
class SSHPopen(subprocess.BasePopen):
    """Popen-like handle for a command run on a remote node over SSH.

    Used as ``popen_class`` by the popen/call/call_output wrappers in this
    module.  The command runs in a new session channel opened on the node's
    cached SSH client (``get_client``).  Each of stdin/stdout/stderr is
    either exposed as a paramiko file object (when the caller passed that
    key, whose value may only be None or PIPE) or handled here: stdin is
    closed, stdout is logged at INFO and stderr at ERROR through _LogPipe.
    """

    def __init__(self, name, cmd, popen_kwargs):
        """Open a channel on ``popen_kwargs['node']`` and start *cmd* there.

        :param name: label for this process; the node id is appended to
            ``self.name`` after the base class has initialised it.
        :param cmd: argument list; each element is quoted with
            ``pipes.quote`` and the results are joined with spaces into
            the remote command line.
        :param popen_kwargs: caller keyword arguments.  ``node`` is required
            and is popped here; ``stdin``/``stdout``/``stderr`` must be None
            or PIPE; ``parse_levels`` is passed to the stderr log pipe.
        """
        self.node = popen_kwargs.pop('node')
        for key in ['stdin', 'stdout', 'stderr']:
            # NOTE(review): assert-based validation disappears under
            # ``python -O``; consider raising ValueError instead.
            assert popen_kwargs.get(key) in [None, PIPE]
        super(SSHPopen, self).__init__(name, cmd, popen_kwargs)
        self._channel = get_client(self.node).get_transport().open_session()
        self._channel.exec_command(" ".join(map(pipes.quote, cmd)))
        # Assumes node.data['id'] is an int ('%d') — TODO confirm.
        self.name = "%s[at node-%d]" % (self.name, self.node.data['id'])

        # The checks below test key *presence* in self.popen_kwargs
        # (presumably stored by BasePopen.__init__), not its value: an
        # explicit stdin=None still gets a writable ChannelFile.
        if 'stdin' not in self.popen_kwargs:
            self.close_stdin()
        else:
            self.stdin = ChannelFile(self._channel, 'wb')

        stdout = ChannelFile(self._channel, 'rb')
        if 'stdout' not in self.popen_kwargs:
            # Not requested by the caller: forward remote stdout to the log.
            self._pipe_stdout = _LogPipe(logging.INFO, stdout)
            self._pipe_stdout.start(self.name + " stdout")
        else:
            self._pipe_stdout = None
            self.stdout = stdout

        stderr = ChannelStderrFile(self._channel, 'rb')
        if 'stderr' not in self.popen_kwargs:
            # Not requested by the caller: forward remote stderr to the log.
            self._pipe_stderr = _LogPipe(
                logging.ERROR, stderr,
                parse_levels=popen_kwargs.get('parse_levels', False),
            )
            self._pipe_stderr.start(self.name + " stderr")
        else:
            self._pipe_stderr = None
            self.stderr = stderr

    def poll(self):
        """Return the remote exit status if it is available, else None."""
        if self._channel.exit_status_ready():
            return self._channel.recv_exit_status()
        else:
            return None

    def wait(self):
        """Wait for the remote command to exit and return its exit status."""
        return self._channel.recv_exit_status()

    def terminate(self):
        """Close the SSH channel.

        NOTE(review): no signal is sent explicitly; whether the remote
        process is stopped depends on the server side — confirm.
        """
        self._channel.close()

    def close_stdin(self):
        """Shut down the write side of the channel (EOF on remote stdin)."""
        self._channel.shutdown_write()

    def communicate(self):
        """Close stdin (if set) and read stdout/stderr to EOF.

        :returns: ``(stdout, stderr)``; each element is None when the
            corresponding ``self.stdout``/``self.stderr`` is falsy
            (presumably: that stream was not requested as PIPE).  Takes no
            input and does not wait for the exit status.
        """
        if self.stdin:
            self.close_stdin()
        # NOTE(review): stdout is drained completely before stderr is read;
        # a command writing a lot to stderr might stall on a full channel
        # window — confirm.
        if self.stdout:
            stdout = self.stdout.read()
        else:
            stdout = None
        if self.stderr:
            stderr = self.stderr.read()
        else:
            stderr = None
        return stdout, stderr
|
|
|
|
|
|
def popen(cmd, **kwargs):
    """Start *cmd* on a remote node via octane's ``subprocess.popen``.

    ``popen_class`` is fixed to SSHPopen; pass the target as ``node=...``
    (SSHPopen pops it from its popen kwargs).  Returns whatever
    ``subprocess.popen`` returns.
    """
    spawn = subprocess.popen
    return spawn(cmd, popen_class=SSHPopen, **kwargs)
|
|
|
|
|
|
def call(cmd, **kwargs):
    """Run *cmd* on a remote node via octane's ``subprocess.call``.

    ``popen_class`` is fixed to SSHPopen; pass the target as ``node=...``.
    Returns whatever ``subprocess.call`` returns.
    """
    runner = subprocess.call
    return runner(cmd, popen_class=SSHPopen, **kwargs)
|
|
|
|
|
|
def call_output(cmd, **kwargs):
    """Run *cmd* on a remote node via octane's ``subprocess.call_output``.

    ``popen_class`` is fixed to SSHPopen; pass the target as ``node=...``.
    Returns whatever ``subprocess.call_output`` returns.
    """
    runner = subprocess.call_output
    return runner(cmd, popen_class=SSHPopen, **kwargs)
|
|
|
|
|
|
@_cache
def _get_sftp(node):
    """Open an SFTP session on the transport of *node*'s SSH client.

    Cached per node id by ``_cache``.
    """
    client = get_client(node)
    return paramiko.SFTPClient.from_transport(client.get_transport())


# Rebuilding a node's SSH client drops that node's cached SFTP session,
# since it was opened on the old client's transport.
get_client.invalidate.append(_get_sftp)
|
|
|
|
|
|
def sftp(node):
    """Return an SFTP client for *node* on a verified SSH connection.

    ``get_client`` runs its liveness check first.  If the SSH client had to
    be rebuilt, the cached SFTP entry for this node has been dropped (it is
    registered in ``get_client.invalidate``), so ``_get_sftp`` opens a new
    session on the fresh connection.
    """
    get_client(node)  # revalidate (and reconnect if needed) before SFTP use
    session = _get_sftp(node)
    return session
|
|
|
|
|
|
def _open_temp_file(sftp, filename, attempts=2):
    """Exclusively create a randomly named sibling of *filename* over SFTP.

    The name is ``<filename>.octane.<8 hex digits>``.  Opening with mode
    ``'wx'`` raises IOError if the file cannot be created exclusively (e.g.
    the name is taken); a new random name is then tried, up to *attempts*
    names in total, after which the last IOError propagates.

    :returns: ``(temp_filename, file_object)``
    """
    for attempt in range(attempts):
        # 32 random bits, rendered as exactly 8 hex digits by '%08x'.
        temp_filename = '%s.octane.%08x' % (filename,
                                            random.randrange(1 << 32))
        try:
            return temp_filename, sftp.open(temp_filename, 'wx')
        except IOError:  # we're unlucky, try another name (or fail)
            if attempt == attempts - 1:
                raise


@contextlib.contextmanager
def update_file(sftp, filename):
    """Rewrite a remote file over SFTP, replacing it only on success.

    Opens *filename* for reading and a freshly created temporary sibling for
    writing, and yields ``(old, new)``.  When the ``with`` body finishes
    normally, the temporary file gets the original's mode and ownership and
    is renamed over *filename*.  The original is moved aside to
    ``<filename>.octane.bak`` first and deleted afterwards.  It is moved back
    if the final rename fails.

    If the body raises ``subprocess.DontUpdateException`` the temporary file
    is removed, the exception is swallowed and *filename* is left untouched.
    Any other exception also removes the temporary file and is re-raised.

    :param sftp: an SFTP client (e.g. from ``sftp(node)``).
    :param filename: remote path of the file to update.
    """
    old = sftp.open(filename, 'r')
    try:
        temp_filename, new = _open_temp_file(sftp, filename)
    except Exception:
        # Don't leak the already-open original if no temp file was created.
        old.close()
        raise
    # ``with old, new`` replaces contextlib.nested (deprecated in 2.7,
    # removed in Python 3); both files are closed when the block exits.
    with old, new:
        try:
            yield old, new
        except subprocess.DontUpdateException:
            # Caller chose not to change the file: discard the temp copy.
            sftp.unlink(temp_filename)
            return
        except Exception:
            sftp.unlink(temp_filename)
            raise
        # Preserve the original's permissions and ownership.
        stat = old.stat()
        new.chmod(stat.st_mode)
        new.chown(stat.st_uid, stat.st_gid)

    # Move the original aside first so it can be restored if the swap fails.
    bak_filename = filename + '.octane.bak'
    sftp.rename(filename, bak_filename)
    try:
        sftp.rename(temp_filename, filename)
    except Exception:
        # Put the original back rather than leaving *filename* missing.
        sftp.rename(bak_filename, filename)
        raise
    sftp.unlink(bak_filename)
|
|
|
|
|
|
@contextlib.contextmanager
def tempdir(node):
    """Create a temporary directory on *node* and remove it afterwards.

    Runs ``mktemp -d`` remotely and yields the resulting path.  When the
    ``with`` block exits, whether normally or by exception, ``rm -rf``
    is run on that path.

    :param node: target node, passed through as ``node=`` to call/call_output.
    """
    out = call_output(['mktemp', '-d'], node=node)
    # Strip the trailing newline (and any surrounding whitespace) instead of
    # blindly chopping the last character: if the output ever lacks the
    # newline, a truncated path would otherwise be handed to ``rm -rf``.
    dirname = out.strip()
    try:
        yield dirname
    finally:
        call(['rm', '-rf', dirname], node=node)
|