Add utils.tempdir() context manager for easy temp dirs

Fixes bug 883323 (and others)

Users of tempfile.mkdtemp() need to make sure the directory is cleaned
up when it's done being used. Unfortunately, not all of the code does
so at all, or safely (by using a try/finally block).

Change-Id: I270109d83efec4f8b3dd954021493f4d96c6ab79
This commit is contained in:
Johannes Erdfelt 2012-02-28 05:54:48 +00:00
parent f01b9b8dd2
commit f0d5df523b
10 changed files with 249 additions and 314 deletions

View File

@ -24,9 +24,7 @@ Nova authentication management
"""
import os
import shutil
import string # pylint: disable=W0402
import tempfile
import uuid
import zipfile
@ -767,7 +765,7 @@ class AuthManager(object):
pid = Project.safe_id(project)
private_key, signed_cert = crypto.generate_x509_cert(user.id, pid)
tmpdir = tempfile.mkdtemp()
with utils.tempdir() as tmpdir:
zf = os.path.join(tmpdir, "temp.zip")
zippy = zipfile.ZipFile(zf, 'w')
if use_dmz and FLAGS.region_list:
@ -805,7 +803,6 @@ class AuthManager(object):
with open(zf, 'rb') as f:
read_buffer = f.read()
shutil.rmtree(tmpdir)
return read_buffer
def get_environment_rc(self, user, project=None, use_dmz=True):

View File

@ -65,9 +65,9 @@ class CloudPipe(object):
def get_encoded_zip(self, project_id):
# Make a payload.zip
tmpfolder = tempfile.mkdtemp()
with utils.tempdir() as tmpdir:
filename = "payload.zip"
zippath = os.path.join(tmpfolder, filename)
zippath = os.path.join(tmpdir, filename)
z = zipfile.ZipFile(zippath, "w", zipfile.ZIP_DEFLATED)
shellfile = open(FLAGS.boot_script_template, "r")
s = string.Template(shellfile.read())
@ -82,11 +82,13 @@ class CloudPipe(object):
z.writestr('autorun.sh', boot_script)
crl = os.path.join(crypto.ca_folder(project_id), 'crl.pem')
z.write(crl, 'crl.pem')
server_key = os.path.join(crypto.ca_folder(project_id), 'server.key')
server_key = os.path.join(crypto.ca_folder(project_id),
'server.key')
z.write(server_key, 'server.key')
ca_crt = os.path.join(crypto.ca_path(project_id))
z.write(ca_crt, 'ca.crt')
server_crt = os.path.join(crypto.ca_folder(project_id), 'server.crt')
server_crt = os.path.join(crypto.ca_folder(project_id),
'server.crt')
z.write(server_crt, 'server.crt')
z.close()
zippy = open(zippath, "r")
@ -95,6 +97,7 @@ class CloudPipe(object):
# hence the double encoding.
encoded = zippy.read().encode("base64").encode("base64")
zippy.close()
return encoded
def launch_vpn_instance(self, project_id, user_id):

View File

@ -175,6 +175,8 @@ def handle_flagfiles_managed(args):
# Do stuff
# Any temporary fils have been removed
'''
# NOTE(johannes): Would be nice to use utils.tempdir(), but it
# causes an import loop
tempdir = tempfile.mkdtemp(prefix='nova-conf-')
try:
yield handle_flagfiles(args, tempdir=tempdir)

View File

@ -27,9 +27,7 @@ from __future__ import absolute_import
import base64
import hashlib
import os
import shutil
import string
import tempfile
import Crypto.Cipher.AES
@ -127,7 +125,7 @@ def _generate_fingerprint(public_key_file):
def generate_fingerprint(public_key):
tmpdir = tempfile.mkdtemp()
with utils.tempdir() as tmpdir:
try:
pubfile = os.path.join(tmpdir, 'temp.pub')
with open(pubfile, 'w') as f:
@ -135,17 +133,12 @@ def generate_fingerprint(public_key):
return _generate_fingerprint(pubfile)
except exception.ProcessExecutionError:
raise exception.InvalidKeypair()
finally:
try:
shutil.rmtree(tmpdir)
except IOError, e:
LOG.debug(_('Could not remove tmpdir: %s'), str(e))
def generate_key_pair(bits=1024):
# what is the magic 65537?
tmpdir = tempfile.mkdtemp()
with utils.tempdir() as tmpdir:
keyfile = os.path.join(tmpdir, 'temp')
utils.execute('ssh-keygen', '-q', '-b', bits, '-N', '',
'-t', 'rsa', '-f', keyfile)
@ -153,11 +146,6 @@ def generate_key_pair(bits=1024):
private_key = open(keyfile).read()
public_key = open(keyfile + '.pub').read()
try:
shutil.rmtree(tmpdir)
except OSError, e:
LOG.debug(_('Could not remove tmpdir: %s'), str(e))
return (private_key, public_key, fingerprint)
@ -233,20 +221,16 @@ def _user_cert_subject(user_id, project_id):
def generate_x509_cert(user_id, project_id, bits=1024):
"""Generate and sign a cert for user in project."""
subject = _user_cert_subject(user_id, project_id)
tmpdir = tempfile.mkdtemp()
with utils.tempdir() as tmpdir:
keyfile = os.path.abspath(os.path.join(tmpdir, 'temp.key'))
csrfile = os.path.join(tmpdir, 'temp.csr')
utils.execute('openssl', 'genrsa', '-out', keyfile, str(bits))
utils.execute('openssl', 'req', '-new', '-key', keyfile, '-out', csrfile,
'-batch', '-subj', subject)
utils.execute('openssl', 'req', '-new', '-key', keyfile, '-out',
csrfile, '-batch', '-subj', subject)
private_key = open(keyfile).read()
csr = open(csrfile).read()
try:
shutil.rmtree(tmpdir)
except OSError, e:
LOG.debug(_('Could not remove tmpdir: %s'), str(e))
(serial, signed_csr) = sign_csr(csr, project_id)
fname = os.path.join(ca_folder(project_id), 'newcerts/%s.pem' % serial)
cert = {'user_id': user_id,
@ -298,17 +282,20 @@ def sign_csr(csr_text, project_id=None):
def _sign_csr(csr_text, ca_folder):
tmpfolder = tempfile.mkdtemp()
inbound = os.path.join(tmpfolder, 'inbound.csr')
outbound = os.path.join(tmpfolder, 'outbound.csr')
csrfile = open(inbound, 'w')
with utils.tempdir() as tmpdir:
inbound = os.path.join(tmpdir, 'inbound.csr')
outbound = os.path.join(tmpdir, 'outbound.csr')
with open(inbound, 'w') as csrfile:
csrfile.write(csr_text)
csrfile.close()
LOG.debug(_('Flags path: %s'), ca_folder)
start = os.getcwd()
# Change working dir to CA
if not os.path.exists(ca_folder):
os.makedirs(ca_folder)
os.chdir(ca_folder)
utils.execute('openssl', 'ca', '-batch', '-out', outbound, '-config',
'./openssl.cnf', '-infiles', inbound)
@ -316,6 +303,7 @@ def _sign_csr(csr_text, ca_folder):
'-serial', '-noout')
serial = string.strip(out.rpartition('=')[2])
os.chdir(start)
with open(outbound, 'r') as crtfile:
return (serial, crtfile.read())

View File

@ -17,8 +17,6 @@ Tests for Crypto module.
"""
import os
import shutil
import tempfile
import mox
@ -50,9 +48,8 @@ class SymmetricKeyTestCase(test.TestCase):
class X509Test(test.TestCase):
def test_can_generate_x509(self):
tmpdir = tempfile.mkdtemp()
with utils.tempdir() as tmpdir:
self.flags(ca_path=tmpdir)
try:
crypto.ensure_ca_filesystem()
_key, cert_str = crypto.generate_x509_cert('fake', 'fake')
@ -70,14 +67,10 @@ class X509Test(test.TestCase):
project_cert_file, '-verbose', signed_cert_file)
self.assertFalse(err)
finally:
shutil.rmtree(tmpdir)
def test_encrypt_decrypt_x509(self):
tmpdir = tempfile.mkdtemp()
with utils.tempdir() as tmpdir:
self.flags(ca_path=tmpdir)
project_id = "fake"
try:
crypto.ensure_ca_filesystem()
cert = crypto.fetch_ca(project_id)
public_key = os.path.join(tmpdir, "public.pem")
@ -92,8 +85,6 @@ class X509Test(test.TestCase):
process_input=text)
dec = crypto.decrypt_text(project_id, enc)
self.assertEqual(text, dec)
finally:
shutil.rmtree(tmpdir)
class RevokeCertsTest(test.TestCase):

View File

@ -17,12 +17,11 @@
# under the License.
import contextlib
import cStringIO
import hashlib
import logging
import os
import shutil
import tempfile
import time
from nova import test
@ -58,9 +57,8 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertEquals(csum, None)
def test_read_stored_checksum(self):
try:
dirname = tempfile.mkdtemp()
fname = os.path.join(dirname, 'aaa')
with utils.tempdir() as tmpdir:
fname = os.path.join(tmpdir, 'aaa')
csum_input = 'fdghkfhkgjjksfdgjksjkghsdf'
f = open('%s.sha1' % fname, 'w')
@ -71,9 +69,6 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertEquals(csum_input, csum_output)
finally:
shutil.rmtree(dirname)
def test_list_base_images(self):
listing = ['00000001',
'ephemeral_0_20_None',
@ -281,13 +276,17 @@ class ImageCacheManagerTestCase(test.TestCase):
(base_file2, True, False),
(base_file3, False, True)])
@contextlib.contextmanager
def _intercept_log_messages(self):
try:
mylog = log.getLogger()
stream = cStringIO.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(log.LegacyNovaFormatter())
mylog.logger.addHandler(handler)
return mylog, handler, stream
yield stream
finally:
mylog.logger.removeHandler(handler)
def test_verify_checksum(self):
testdata = ('OpenStack Software delivers a massively scalable cloud '
@ -295,11 +294,10 @@ class ImageCacheManagerTestCase(test.TestCase):
img = {'container_format': 'ami', 'id': '42'}
self.flags(checksum_base_images=True)
mylog, handler, stream = self._intercept_log_messages()
try:
dirname = tempfile.mkdtemp()
fname = os.path.join(dirname, 'aaa')
with self._intercept_log_messages() as stream:
with utils.tempdir() as tmpdir:
fname = os.path.join(tmpdir, 'aaa')
f = open(fname, 'w')
f.write(testdata)
@ -324,8 +322,8 @@ class ImageCacheManagerTestCase(test.TestCase):
image_cache_manager = imagecache.ImageCacheManager()
res = image_cache_manager._verify_checksum(img, fname)
self.assertFalse(res)
self.assertNotEqual(stream.getvalue().find('image verification '
'failed'), -1)
log = stream.getvalue()
self.assertNotEqual(log.find('image verification failed'), -1)
# Checksum file missing
os.remove('%s.sha1' % fname)
@ -337,15 +335,12 @@ class ImageCacheManagerTestCase(test.TestCase):
# side effect of creating the checksum
self.assertTrue(os.path.exists('%s.sha1' % fname))
finally:
shutil.rmtree(dirname)
mylog.logger.removeHandler(handler)
def _make_base_file(checksum=True):
@contextlib.contextmanager
def _make_base_file(self, checksum=True):
"""Make a base file for testing."""
dirname = tempfile.mkdtemp()
fname = os.path.join(dirname, 'aaa')
with utils.tempdir() as tmpdir:
fname = os.path.join(tmpdir, 'aaa')
base_file = open(fname, 'w')
base_file.write('data')
@ -358,11 +353,10 @@ class ImageCacheManagerTestCase(test.TestCase):
checksum_file.close()
base_file.close()
return dirname, fname
yield fname
def test_remove_base_file(self):
dirname, fname = self._make_base_file()
try:
with self._make_base_file() as fname:
image_cache_manager = imagecache.ImageCacheManager()
image_cache_manager._remove_base_file(fname)
@ -377,12 +371,8 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertFalse(os.path.exists(fname))
self.assertFalse(os.path.exists('%s.sha1' % fname))
finally:
shutil.rmtree(dirname)
def test_remove_base_file_original(self):
dirname, fname = self._make_base_file()
try:
with self._make_base_file() as fname:
image_cache_manager = imagecache.ImageCacheManager()
image_cache_manager.originals = [fname]
image_cache_manager._remove_base_file(fname)
@ -405,27 +395,19 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertFalse(os.path.exists(fname))
self.assertFalse(os.path.exists('%s.sha1' % fname))
finally:
shutil.rmtree(dirname)
def test_remove_base_file_dne(self):
# This test is solely to execute the "does not exist" code path. We
# don't expect the method being tested to do anything in this case.
dirname = tempfile.mkdtemp()
try:
fname = os.path.join(dirname, 'aaa')
with utils.tempdir() as tmpdir:
fname = os.path.join(tmpdir, 'aaa')
image_cache_manager = imagecache.ImageCacheManager()
image_cache_manager._remove_base_file(fname)
finally:
shutil.rmtree(dirname)
def test_remove_base_file_oserror(self):
dirname = tempfile.mkdtemp()
fname = os.path.join(dirname, 'aaa')
mylog, handler, stream = self._intercept_log_messages()
with self._intercept_log_messages() as stream:
with utils.tempdir() as tmpdir:
fname = os.path.join(tmpdir, 'aaa')
try:
os.mkdir(fname)
os.utime(fname, (-1, time.time() - 3601))
@ -437,19 +419,14 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertNotEqual(stream.getvalue().find('Failed to remove'),
-1)
finally:
shutil.rmtree(dirname)
mylog.logger.removeHandler(handler)
def test_handle_base_image_unused(self):
img = {'container_format': 'ami',
'id': '123',
'uuid': '1234-4567-2378'}
dirname, fname = self._make_base_file()
with self._make_base_file() as fname:
os.utime(fname, (-1, time.time() - 3601))
try:
image_cache_manager = imagecache.ImageCacheManager()
image_cache_manager.unexplained_images = [fname]
image_cache_manager._handle_base_image(img, fname)
@ -459,18 +436,14 @@ class ImageCacheManagerTestCase(test.TestCase):
[fname])
self.assertEquals(image_cache_manager.corrupt_base_files, [])
finally:
shutil.rmtree(dirname)
def test_handle_base_image_used(self):
img = {'container_format': 'ami',
'id': '123',
'uuid': '1234-4567-2378'}
dirname, fname = self._make_base_file()
with self._make_base_file() as fname:
os.utime(fname, (-1, time.time() - 3601))
try:
image_cache_manager = imagecache.ImageCacheManager()
image_cache_manager.unexplained_images = [fname]
image_cache_manager.used_images = {'123': (1, 0, ['banana-42'])}
@ -480,18 +453,14 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertEquals(image_cache_manager.removable_base_files, [])
self.assertEquals(image_cache_manager.corrupt_base_files, [])
finally:
shutil.rmtree(dirname)
def test_handle_base_image_used_remotely(self):
img = {'container_format': 'ami',
'id': '123',
'uuid': '1234-4567-2378'}
dirname, fname = self._make_base_file()
with self._make_base_file() as fname:
os.utime(fname, (-1, time.time() - 3601))
try:
image_cache_manager = imagecache.ImageCacheManager()
image_cache_manager.used_images = {'123': (0, 1, ['banana-42'])}
image_cache_manager._handle_base_image(img, None)
@ -500,9 +469,6 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertEquals(image_cache_manager.removable_base_files, [])
self.assertEquals(image_cache_manager.corrupt_base_files, [])
finally:
shutil.rmtree(dirname)
def test_handle_base_image_absent(self):
"""Ensure we warn for use of a missing base image."""
@ -510,9 +476,7 @@ class ImageCacheManagerTestCase(test.TestCase):
'id': '123',
'uuid': '1234-4567-2378'}
mylog, handler, stream = self._intercept_log_messages()
try:
with self._intercept_log_messages() as stream:
image_cache_manager = imagecache.ImageCacheManager()
image_cache_manager.used_images = {'123': (1, 0, ['banana-42'])}
image_cache_manager._handle_base_image(img, None)
@ -523,18 +487,14 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertNotEqual(stream.getvalue().find('an absent base file'),
-1)
finally:
mylog.logger.removeHandler(handler)
def test_handle_base_image_used_missing(self):
img = {'container_format': 'ami',
'id': '123',
'uuid': '1234-4567-2378'}
dirname = tempfile.mkdtemp()
fname = os.path.join(dirname, 'aaa')
with utils.tempdir() as tmpdir:
fname = os.path.join(tmpdir, 'aaa')
try:
image_cache_manager = imagecache.ImageCacheManager()
image_cache_manager.unexplained_images = [fname]
image_cache_manager.used_images = {'123': (1, 0, ['banana-42'])}
@ -544,17 +504,12 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertEquals(image_cache_manager.removable_base_files, [])
self.assertEquals(image_cache_manager.corrupt_base_files, [])
finally:
shutil.rmtree(dirname)
def test_handle_base_image_checksum_fails(self):
img = {'container_format': 'ami',
'id': '123',
'uuid': '1234-4567-2378'}
dirname, fname = self._make_base_file()
try:
with self._make_base_file() as fname:
f = open(fname, 'w')
f.write('banana')
f.close()
@ -569,9 +524,6 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertEquals(image_cache_manager.corrupt_base_files,
[fname])
finally:
shutil.rmtree(dirname)
def test_verify_base_images(self):
self.flags(instances_path='/instance_path')
self.flags(remove_unused_base_images=True)

View File

@ -1050,7 +1050,7 @@ class LibvirtConnTestCase(test.TestCase):
def test_pre_block_migration_works_correctly(self):
"""Confirms pre_block_migration works correctly."""
# Replace instances_path since this testcase creates tmpfile
tmpdir = tempfile.mkdtemp()
with utils.tempdir() as tmpdir:
self.flags(instances_path=tmpdir)
# Test data
@ -1069,7 +1069,6 @@ class LibvirtConnTestCase(test.TestCase):
self.assertTrue(os.path.exists('%s/%s/' %
(tmpdir, instance_ref.name)))
shutil.rmtree(tmpdir)
db.instance_destroy(self.context, instance_ref['id'])
@test.skip_if(missing_libvirt(), "Test requires libvirt")
@ -1926,13 +1925,10 @@ disk size: 4.4M''', ''))
libvirt_utils.mkfs('swap', '/my/swap/block/dev')
def test_ensure_tree(self):
tmpdir = tempfile.mkdtemp()
try:
with utils.tempdir() as tmpdir:
testdir = '%s/foo/bar/baz' % (tmpdir,)
libvirt_utils.ensure_tree(testdir)
self.assertTrue(os.path.isdir(testdir))
finally:
shutil.rmtree(tmpdir)
def test_write_to_file(self):
dst_fd, dst_path = tempfile.mkstemp()

View File

@ -31,9 +31,11 @@ import pyclbr
import random
import re
import shlex
import shutil
import socket
import struct
import sys
import tempfile
import time
import types
import uuid
@ -1543,3 +1545,15 @@ def temporary_chown(path, owner_uid=None):
finally:
if orig_uid != owner_uid:
execute('chown', orig_uid, path, run_as_root=True)
@contextlib.contextmanager
def tempdir(**kwargs):
tmpdir = tempfile.mkdtemp(**kwargs)
try:
yield tmpdir
finally:
try:
shutil.rmtree(tmpdir)
except OSError, e:
LOG.debug(_('Could not remove tmpdir: %s'), str(e))

View File

@ -46,7 +46,6 @@ import multiprocessing
import os
import shutil
import sys
import tempfile
import uuid
from eventlet import greenthread
@ -622,9 +621,9 @@ class LibvirtConnection(driver.ComputeDriver):
disk_path = source.get('file')
# Export the snapshot to a raw image
temp_dir = tempfile.mkdtemp()
with utils.tempdir() as tmpdir:
try:
out_path = os.path.join(temp_dir, snapshot_name)
out_path = os.path.join(tmpdir, snapshot_name)
libvirt_utils.extract_snapshot(disk_path, source_format,
snapshot_name, out_path,
image_format)
@ -636,8 +635,6 @@ class LibvirtConnection(driver.ComputeDriver):
image_file)
finally:
# Clean up
shutil.rmtree(temp_dir)
snapshot_ptr.delete(0)
@exception.wrap_exception()

View File

@ -25,7 +25,6 @@ import json
import os
import pickle
import re
import tempfile
import time
import urllib
import urlparse
@ -1750,8 +1749,7 @@ def _mounted_processing(device, key, net, metadata):
"""Callback which runs with the image VDI attached"""
# NB: Partition 1 hardcoded
dev_path = utils.make_dev_path(device, partition=1)
tmpdir = tempfile.mkdtemp()
try:
with utils.tempdir() as tmpdir:
# Mount only Linux filesystems, to avoid disturbing NTFS images
err = _mount_filesystem(dev_path, tmpdir)
if not err:
@ -1770,9 +1768,6 @@ def _mounted_processing(device, key, net, metadata):
else:
LOG.info(_('Failed to mount filesystem (expected for '
'non-linux instances): %s') % err)
finally:
# remove temporary directory
os.rmdir(tmpdir)
def _prepare_injectables(inst, networks_info):