Get rid of pycrypto dep
* ssh keys are generated with ssh-keygen now; * most of the code is ported from nova. Change-Id: I9cbf4bbff481c3a414b825abbe030ea27b7a2b63
This commit is contained in:
parent
d42f6091b1
commit
e4ff7dd711
|
@ -1,5 +1,5 @@
|
|||
[DEFAULT]
|
||||
modules=setup, jsonutils, xmlutils, timeutils, exception, gettextutils, log, local, notifier, importutils, context, uuidutils, version, threadgroup, db, db.sqlalchemy, excutils, cfgfilter, middleware.base, middleware.debug, rpc, service, thread_group, periodic_task, loopingcall
|
||||
modules=setup, jsonutils, xmlutils, timeutils, exception, gettextutils, log, local, notifier, importutils, context, uuidutils, version, threadgroup, db, db.sqlalchemy, excutils, cfgfilter, middleware.base, middleware.debug, rpc, service, thread_group, periodic_task, loopingcall, processutils
|
||||
base=savanna
|
||||
|
||||
# The following code from 'wsgi' is needed:
|
||||
|
|
|
@ -6,7 +6,6 @@ kombu>=2.4.8
|
|||
netaddr
|
||||
paramiko>=1.8.0
|
||||
pbr>=0.5.21,<1.0
|
||||
pycrypto>=2.6
|
||||
# NOTE(slukjanov): <1.2.3 added to support old clients with the same constraint
|
||||
requests>=1.1,<1.2.3
|
||||
python-cinderclient>=1.0.4
|
||||
|
|
|
@ -19,7 +19,7 @@ import copy
|
|||
|
||||
from savanna.db import base as db_base
|
||||
from savanna.utils import configs
|
||||
# from savanna.openstack.common.rpc import common as rpc_common
|
||||
from savanna.utils import crypto
|
||||
|
||||
|
||||
CLUSTER_DEFAULTS = {
|
||||
|
@ -121,6 +121,10 @@ class ConductorManager(db_base.Base):
|
|||
values = _apply_defaults(values, CLUSTER_DEFAULTS)
|
||||
values['tenant_id'] = context.tenant_id
|
||||
|
||||
private_key, public_key = crypto.generate_key_pair()
|
||||
values['management_private_key'] = private_key
|
||||
values['management_public_key'] = public_key
|
||||
|
||||
cluster_template_id = values.get('cluster_template_id')
|
||||
if cluster_template_id:
|
||||
c_tmpl = self.cluster_template_get(context, cluster_template_id)
|
||||
|
|
|
@ -44,7 +44,8 @@ class Cluster(object):
|
|||
see the docs for details
|
||||
default_image_id
|
||||
anti_affinity
|
||||
private_key
|
||||
management_private_key
|
||||
management_public_key
|
||||
user_keypair_id
|
||||
status
|
||||
status_description
|
||||
|
|
|
@ -196,7 +196,7 @@ class ClusterResource(Resource, objects.Cluster):
|
|||
'cluster_template': (ClusterTemplateResource, None)
|
||||
}
|
||||
|
||||
_filter_fields = ['private_key']
|
||||
_filter_fields = ['management_private_key']
|
||||
|
||||
|
||||
##EDP Resources
|
||||
|
|
|
@ -19,7 +19,6 @@ from sqlalchemy.orm import relationship
|
|||
from savanna.db.sqlalchemy import model_base as mb
|
||||
from savanna.db.sqlalchemy import types as st
|
||||
from savanna.openstack.common import uuidutils
|
||||
from savanna.utils import crypto
|
||||
|
||||
|
||||
## Helpers
|
||||
|
@ -55,7 +54,8 @@ class Cluster(mb.SavannaBase):
|
|||
default_image_id = sa.Column(sa.String(36))
|
||||
neutron_management_network = sa.Column(sa.String(36))
|
||||
anti_affinity = sa.Column(st.JsonListType())
|
||||
private_key = sa.Column(sa.Text, default=crypto.generate_private_key())
|
||||
management_private_key = sa.Column(sa.Text, nullable=False)
|
||||
management_public_key = sa.Column(sa.Text, nullable=False)
|
||||
user_keypair_id = sa.Column(sa.String(80))
|
||||
status = sa.Column(sa.String(80))
|
||||
status_description = sa.Column(sa.String(200))
|
||||
|
|
|
@ -0,0 +1,251 @@
|
|||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright 2011 OpenStack Foundation.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
System-level utilities and helper functions.
|
||||
"""
|
||||
|
||||
import logging as stdlib_logging
|
||||
import os
|
||||
import random
|
||||
import shlex
|
||||
import signal
|
||||
|
||||
from eventlet.green import subprocess
|
||||
from eventlet import greenthread
|
||||
|
||||
from savanna.openstack.common.gettextutils import _ # noqa
|
||||
from savanna.openstack.common import log as logging
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InvalidArgumentError(Exception):
|
||||
def __init__(self, message=None):
|
||||
super(InvalidArgumentError, self).__init__(message)
|
||||
|
||||
|
||||
class UnknownArgumentError(Exception):
|
||||
def __init__(self, message=None):
|
||||
super(UnknownArgumentError, self).__init__(message)
|
||||
|
||||
|
||||
class ProcessExecutionError(Exception):
|
||||
def __init__(self, stdout=None, stderr=None, exit_code=None, cmd=None,
|
||||
description=None):
|
||||
self.exit_code = exit_code
|
||||
self.stderr = stderr
|
||||
self.stdout = stdout
|
||||
self.cmd = cmd
|
||||
self.description = description
|
||||
|
||||
if description is None:
|
||||
description = "Unexpected error while running command."
|
||||
if exit_code is None:
|
||||
exit_code = '-'
|
||||
message = ("%s\nCommand: %s\nExit code: %s\nStdout: %r\nStderr: %r"
|
||||
% (description, cmd, exit_code, stdout, stderr))
|
||||
super(ProcessExecutionError, self).__init__(message)
|
||||
|
||||
|
||||
class NoRootWrapSpecified(Exception):
|
||||
def __init__(self, message=None):
|
||||
super(NoRootWrapSpecified, self).__init__(message)
|
||||
|
||||
|
||||
def _subprocess_setup():
|
||||
# Python installs a SIGPIPE handler by default. This is usually not what
|
||||
# non-Python subprocesses expect.
|
||||
signal.signal(signal.SIGPIPE, signal.SIG_DFL)
|
||||
|
||||
|
||||
def execute(*cmd, **kwargs):
|
||||
"""Helper method to shell out and execute a command through subprocess.
|
||||
|
||||
Allows optional retry.
|
||||
|
||||
:param cmd: Passed to subprocess.Popen.
|
||||
:type cmd: string
|
||||
:param process_input: Send to opened process.
|
||||
:type proces_input: string
|
||||
:param check_exit_code: Single bool, int, or list of allowed exit
|
||||
codes. Defaults to [0]. Raise
|
||||
:class:`ProcessExecutionError` unless
|
||||
program exits with one of these code.
|
||||
:type check_exit_code: boolean, int, or [int]
|
||||
:param delay_on_retry: True | False. Defaults to True. If set to True,
|
||||
wait a short amount of time before retrying.
|
||||
:type delay_on_retry: boolean
|
||||
:param attempts: How many times to retry cmd.
|
||||
:type attempts: int
|
||||
:param run_as_root: True | False. Defaults to False. If set to True,
|
||||
the command is prefixed by the command specified
|
||||
in the root_helper kwarg.
|
||||
:type run_as_root: boolean
|
||||
:param root_helper: command to prefix to commands called with
|
||||
run_as_root=True
|
||||
:type root_helper: string
|
||||
:param shell: whether or not there should be a shell used to
|
||||
execute this command. Defaults to false.
|
||||
:type shell: boolean
|
||||
:param loglevel: log level for execute commands.
|
||||
:type loglevel: int. (Should be stdlib_logging.DEBUG or
|
||||
stdlib_logging.INFO)
|
||||
:returns: (stdout, stderr) from process execution
|
||||
:raises: :class:`UnknownArgumentError` on
|
||||
receiving unknown arguments
|
||||
:raises: :class:`ProcessExecutionError`
|
||||
"""
|
||||
|
||||
process_input = kwargs.pop('process_input', None)
|
||||
check_exit_code = kwargs.pop('check_exit_code', [0])
|
||||
ignore_exit_code = False
|
||||
delay_on_retry = kwargs.pop('delay_on_retry', True)
|
||||
attempts = kwargs.pop('attempts', 1)
|
||||
run_as_root = kwargs.pop('run_as_root', False)
|
||||
root_helper = kwargs.pop('root_helper', '')
|
||||
shell = kwargs.pop('shell', False)
|
||||
loglevel = kwargs.pop('loglevel', stdlib_logging.DEBUG)
|
||||
|
||||
if isinstance(check_exit_code, bool):
|
||||
ignore_exit_code = not check_exit_code
|
||||
check_exit_code = [0]
|
||||
elif isinstance(check_exit_code, int):
|
||||
check_exit_code = [check_exit_code]
|
||||
|
||||
if kwargs:
|
||||
raise UnknownArgumentError(_('Got unknown keyword args '
|
||||
'to utils.execute: %r') % kwargs)
|
||||
|
||||
if run_as_root and hasattr(os, 'geteuid') and os.geteuid() != 0:
|
||||
if not root_helper:
|
||||
raise NoRootWrapSpecified(
|
||||
message=('Command requested root, but did not specify a root '
|
||||
'helper.'))
|
||||
cmd = shlex.split(root_helper) + list(cmd)
|
||||
|
||||
cmd = map(str, cmd)
|
||||
|
||||
while attempts > 0:
|
||||
attempts -= 1
|
||||
try:
|
||||
LOG.log(loglevel, _('Running cmd (subprocess): %s'), ' '.join(cmd))
|
||||
_PIPE = subprocess.PIPE # pylint: disable=E1101
|
||||
|
||||
if os.name == 'nt':
|
||||
preexec_fn = None
|
||||
close_fds = False
|
||||
else:
|
||||
preexec_fn = _subprocess_setup
|
||||
close_fds = True
|
||||
|
||||
obj = subprocess.Popen(cmd,
|
||||
stdin=_PIPE,
|
||||
stdout=_PIPE,
|
||||
stderr=_PIPE,
|
||||
close_fds=close_fds,
|
||||
preexec_fn=preexec_fn,
|
||||
shell=shell)
|
||||
result = None
|
||||
if process_input is not None:
|
||||
result = obj.communicate(process_input)
|
||||
else:
|
||||
result = obj.communicate()
|
||||
obj.stdin.close() # pylint: disable=E1101
|
||||
_returncode = obj.returncode # pylint: disable=E1101
|
||||
if _returncode:
|
||||
LOG.log(loglevel, _('Result was %s') % _returncode)
|
||||
if not ignore_exit_code and _returncode not in check_exit_code:
|
||||
(stdout, stderr) = result
|
||||
raise ProcessExecutionError(exit_code=_returncode,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
cmd=' '.join(cmd))
|
||||
return result
|
||||
except ProcessExecutionError:
|
||||
if not attempts:
|
||||
raise
|
||||
else:
|
||||
LOG.log(loglevel, _('%r failed. Retrying.'), cmd)
|
||||
if delay_on_retry:
|
||||
greenthread.sleep(random.randint(20, 200) / 100.0)
|
||||
finally:
|
||||
# NOTE(termie): this appears to be necessary to let the subprocess
|
||||
# call clean something up in between calls, without
|
||||
# it two execute calls in a row hangs the second one
|
||||
greenthread.sleep(0)
|
||||
|
||||
|
||||
def trycmd(*args, **kwargs):
|
||||
"""A wrapper around execute() to more easily handle warnings and errors.
|
||||
|
||||
Returns an (out, err) tuple of strings containing the output of
|
||||
the command's stdout and stderr. If 'err' is not empty then the
|
||||
command can be considered to have failed.
|
||||
|
||||
:discard_warnings True | False. Defaults to False. If set to True,
|
||||
then for succeeding commands, stderr is cleared
|
||||
|
||||
"""
|
||||
discard_warnings = kwargs.pop('discard_warnings', False)
|
||||
|
||||
try:
|
||||
out, err = execute(*args, **kwargs)
|
||||
failed = False
|
||||
except ProcessExecutionError as exn:
|
||||
out, err = '', str(exn)
|
||||
failed = True
|
||||
|
||||
if not failed and discard_warnings and err:
|
||||
# Handle commands that output to stderr but otherwise succeed
|
||||
err = ''
|
||||
|
||||
return out, err
|
||||
|
||||
|
||||
def ssh_execute(ssh, cmd, process_input=None,
|
||||
addl_env=None, check_exit_code=True):
|
||||
LOG.debug(_('Running cmd (SSH): %s'), cmd)
|
||||
if addl_env:
|
||||
raise InvalidArgumentError(_('Environment not supported over SSH'))
|
||||
|
||||
if process_input:
|
||||
# This is (probably) fixable if we need it...
|
||||
raise InvalidArgumentError(_('process_input not supported over SSH'))
|
||||
|
||||
stdin_stream, stdout_stream, stderr_stream = ssh.exec_command(cmd)
|
||||
channel = stdout_stream.channel
|
||||
|
||||
# NOTE(justinsb): This seems suspicious...
|
||||
# ...other SSH clients have buffering issues with this approach
|
||||
stdout = stdout_stream.read()
|
||||
stderr = stderr_stream.read()
|
||||
stdin_stream.close()
|
||||
|
||||
exit_status = channel.recv_exit_status()
|
||||
|
||||
# exit_status == -1 if no exit code was returned
|
||||
if exit_status != -1:
|
||||
LOG.debug(_('Result was %s') % exit_status)
|
||||
if check_exit_code and exit_status != 0:
|
||||
raise ProcessExecutionError(exit_code=exit_status,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
cmd=cmd)
|
||||
|
||||
return (stdout, stderr)
|
|
@ -22,7 +22,6 @@ from savanna.plugins import provisioning as p
|
|||
from savanna.plugins.vanilla import config_helper as c_helper
|
||||
from savanna.plugins.vanilla import run_scripts as run
|
||||
from savanna.plugins.vanilla import scaling as sc
|
||||
from savanna.utils import crypto
|
||||
from savanna.utils import files as f
|
||||
from savanna.utils import remote
|
||||
|
||||
|
@ -95,8 +94,9 @@ class VanillaProvider(p.ProvisioningPluginBase):
|
|||
|
||||
def configure_cluster(self, cluster):
|
||||
self._push_configs_to_nodes(cluster)
|
||||
self._write_hadoop_user_keys(cluster.private_key,
|
||||
utils.get_instances(cluster))
|
||||
self._write_hadoop_user_keys(utils.get_instances(cluster),
|
||||
cluster.management_private_key,
|
||||
cluster.management_public_key)
|
||||
|
||||
def start_cluster(self, cluster):
|
||||
nn_instance = utils.get_namenode(cluster)
|
||||
|
@ -192,8 +192,9 @@ class VanillaProvider(p.ProvisioningPluginBase):
|
|||
|
||||
def scale_cluster(self, cluster, instances):
|
||||
self._push_configs_to_nodes(cluster, instances=instances)
|
||||
self._write_hadoop_user_keys(cluster.private_key,
|
||||
instances)
|
||||
self._write_hadoop_user_keys(instances,
|
||||
cluster.management_private_key,
|
||||
cluster.management_public_key)
|
||||
run.refresh_nodes(remote.get_remote(
|
||||
utils.get_namenode(cluster)), "dfsadmin")
|
||||
jt = utils.get_jobtracker(cluster)
|
||||
|
@ -290,9 +291,7 @@ class VanillaProvider(p.ProvisioningPluginBase):
|
|||
ctx = context.ctx()
|
||||
conductor.cluster_update(ctx, cluster, {'info': info})
|
||||
|
||||
def _write_hadoop_user_keys(self, private_key, instances):
|
||||
public_key = crypto.private_key_to_public_key(private_key)
|
||||
|
||||
def _write_hadoop_user_keys(self, instances, private_key, public_key):
|
||||
files = {
|
||||
'id_rsa': private_key,
|
||||
'authorized_keys': public_key
|
||||
|
|
|
@ -22,7 +22,6 @@ from savanna.openstack.common import excutils
|
|||
from savanna.openstack.common import log as logging
|
||||
from savanna.service import networks
|
||||
from savanna.service import volumes
|
||||
from savanna.utils import crypto
|
||||
from savanna.utils import general as g
|
||||
from savanna.utils.openstack import nova
|
||||
|
||||
|
@ -251,8 +250,8 @@ echo "%(private_key)s" > %(user_home)s/.ssh/id_rsa
|
|||
node_group)
|
||||
|
||||
return script_template % {
|
||||
"public_key": crypto.private_key_to_public_key(cluster.private_key),
|
||||
"private_key": cluster.private_key,
|
||||
"public_key": cluster.management_public_key,
|
||||
"private_key": cluster.management_private_key,
|
||||
"user_home": user_home
|
||||
}
|
||||
|
||||
|
|
|
@ -55,6 +55,9 @@ SAMPLE_CLUSTER_DICT = {
|
|||
|
||||
|
||||
class TestResource(unittest2.TestCase):
|
||||
def setUp(self):
|
||||
self.maxDiff = None
|
||||
|
||||
def test_resource_creation(self):
|
||||
res = r.Resource(SAMPLE_DICT)
|
||||
|
||||
|
@ -103,7 +106,7 @@ class TestResource(unittest2.TestCase):
|
|||
|
||||
def test_to_dict_filtering(self):
|
||||
cluster_dict = copy.deepcopy(SAMPLE_CLUSTER_DICT)
|
||||
cluster_dict['private_key'] = 'abacaba'
|
||||
cluster_dict['management_private_key'] = 'abacaba'
|
||||
cluster_dict['node_groups'][0]['id'] = 'some_id'
|
||||
|
||||
cluster = r.ClusterResource(cluster_dict)
|
||||
|
|
|
@ -167,7 +167,7 @@ def _create_cluster_mock(node_groups, aa):
|
|||
|
||||
user_kp = mock.Mock()
|
||||
user_kp.public_key = "123"
|
||||
private_key = c.generate_private_key()
|
||||
private_key = c.generate_key_pair()[0]
|
||||
|
||||
dct = {'name': 'test_cluster',
|
||||
'plugin_name': 'mock_plugin',
|
||||
|
@ -199,10 +199,9 @@ def _generate_user_data_script(cluster):
|
|||
echo "%(public_key)s" >> %(user_home)s/.ssh/authorized_keys
|
||||
echo "%(private_key)s" > %(user_home)s/.ssh/id_rsa
|
||||
"""
|
||||
key = c.private_key_to_public_key(cluster.private_key)
|
||||
return script_template % {
|
||||
"public_key": key,
|
||||
"private_key": cluster.private_key,
|
||||
"public_key": cluster.management_public_key,
|
||||
"private_key": cluster.management_private_key,
|
||||
"user_home": "/root/"
|
||||
}
|
||||
|
||||
|
|
|
@ -19,23 +19,21 @@ from savanna.utils import crypto as c
|
|||
|
||||
|
||||
class CryptoTest(unittest2.TestCase):
|
||||
def test_generate_private_key(self):
|
||||
pk = c.generate_private_key()
|
||||
def test_generate_key_pair(self):
|
||||
kp = c.generate_key_pair()
|
||||
|
||||
self.assertIsNotNone(pk)
|
||||
self.assertIn('-----BEGIN RSA PRIVATE KEY-----', pk)
|
||||
self.assertIn('-----END RSA PRIVATE KEY-----', pk)
|
||||
self.assertIsInstance(kp, tuple)
|
||||
self.assertIsNotNone(kp[0])
|
||||
self.assertIsNotNone(kp[1])
|
||||
self.assertIn('-----BEGIN RSA PRIVATE KEY-----', kp[0])
|
||||
self.assertIn('-----END RSA PRIVATE KEY-----', kp[0])
|
||||
self.assertIn('ssh-rsa ', kp[1])
|
||||
self.assertIn('Generated by Savanna', kp[1])
|
||||
|
||||
def test_to_paramiko_private_key(self):
|
||||
pk_str = c.generate_private_key()
|
||||
pk_str = c.generate_key_pair()[0]
|
||||
pk = c.to_paramiko_private_key(pk_str)
|
||||
|
||||
self.assertIsNotNone(pk)
|
||||
self.assertEqual(2048, pk.size)
|
||||
self.assertEqual('ssh-rsa', pk.get_name())
|
||||
|
||||
def test_private_key_to_public_key(self):
|
||||
key = c.private_key_to_public_key(c.generate_private_key())
|
||||
|
||||
self.assertIsNotNone(key)
|
||||
self.assertIn('ssh-rsa', key)
|
||||
|
|
|
@ -13,16 +13,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from Crypto.PublicKey import RSA
|
||||
from Crypto import Random
|
||||
import os
|
||||
import paramiko
|
||||
import six
|
||||
|
||||
|
||||
def generate_private_key(length=2048):
|
||||
"""Generate RSA private key (str) with the specified length."""
|
||||
rsa = RSA.generate(length, Random.new().read)
|
||||
return rsa.exportKey('PEM')
|
||||
from savanna.openstack.common import processutils
|
||||
from savanna.utils import tempfiles
|
||||
|
||||
|
||||
def to_paramiko_private_key(pkey):
|
||||
|
@ -30,6 +26,32 @@ def to_paramiko_private_key(pkey):
|
|||
return paramiko.RSAKey(file_obj=six.StringIO(pkey))
|
||||
|
||||
|
||||
def private_key_to_public_key(key):
|
||||
"""Convert private key (str) to public key (str)."""
|
||||
return RSA.importKey(key).exportKey('OpenSSH')
|
||||
def generate_key_pair(key_length=2048):
|
||||
"""Create RSA key pair with specified number of bits in key.
|
||||
|
||||
Returns tuple of private and public keys.
|
||||
"""
|
||||
with tempfiles.tempdir() as tmpdir:
|
||||
keyfile = os.path.join(tmpdir, 'tempkey')
|
||||
args = [
|
||||
'ssh-keygen',
|
||||
'-q', # quiet
|
||||
'-N', '', # w/o passphrase
|
||||
'-t', 'rsa', # create key of rsa type
|
||||
'-f', keyfile, # filename of the key file
|
||||
'-C', 'Generated by Savanna' # key comment
|
||||
]
|
||||
if key_length is not None:
|
||||
args.extend(['-b', key_length])
|
||||
processutils.execute(*args)
|
||||
if not os.path.exists(keyfile):
|
||||
# TODO(slukjanov): replace with specific exception
|
||||
raise RuntimeError("Private key file hasn't been created")
|
||||
private_key = open(keyfile).read()
|
||||
public_key_path = keyfile + '.pub'
|
||||
if not os.path.exists(public_key_path):
|
||||
# TODO(slukjanov): replace with specific exception
|
||||
raise RuntimeError("Public key file hasn't been created")
|
||||
public_key = open(public_key_path).read()
|
||||
|
||||
return private_key, public_key
|
||||
|
|
|
@ -109,7 +109,7 @@ class InstanceInteropHelper(object):
|
|||
username = nova.get_node_group_image_username(self.instance.node_group)
|
||||
return setup_ssh_connection(
|
||||
self.instance.management_ip, username,
|
||||
self.instance.node_group.cluster.private_key)
|
||||
self.instance.node_group.cluster.management_private_key)
|
||||
|
||||
def execute_command(self, cmd, get_stderr=False, raise_when_error=True):
|
||||
with contextlib.closing(self.ssh_connection()) as ssh:
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) 2013 Mirantis Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def tempdir(**kwargs):
|
||||
argdict = kwargs.copy()
|
||||
if 'dir' not in argdict:
|
||||
argdict['dir'] = '/tmp/'
|
||||
tmpdir = tempfile.mkdtemp(**argdict)
|
||||
try:
|
||||
yield tmpdir
|
||||
finally:
|
||||
try:
|
||||
shutil.rmtree(tmpdir)
|
||||
except OSError:
|
||||
raise RuntimeError("error")
|
Loading…
Reference in New Issue