Provide file extension when when looking for files

* Allow an extension to be passed to find_config files, defaulting to '.conf'

Change-Id: I022a3b28d9067158e9ed0da741a5e72cb73af167
This commit is contained in:
Brian Waldon 2012-04-25 16:11:45 -07:00
parent 6e4948b952
commit 9768498a53
2 changed files with 16 additions and 5 deletions

View File

@ -319,11 +319,12 @@ class ConfigFileValueError(Error):
pass pass
def find_config_files(project=None, prog=None): def find_config_files(project=None, prog=None, extension='.conf'):
"""Return a list of default configuration files. """Return a list of default configuration files.
:param project: an optional project name :param project: an optional project name
:param prog: the program name, defaulting to the basename of sys.argv[0] :param prog: the program name, defaulting to the basename of sys.argv[0]
:param extension: the type of the config file
We default to two config files: [${project}.conf, ${prog}.conf] We default to two config files: [${project}.conf, ${prog}.conf]
@ -356,16 +357,16 @@ def find_config_files(project=None, prog=None):
] ]
cfg_dirs = filter(bool, cfg_dirs) cfg_dirs = filter(bool, cfg_dirs)
def search_dirs(dirs, basename): def search_dirs(dirs, basename, extension):
for d in dirs: for d in dirs:
path = os.path.join(d, basename) path = os.path.join(d, '%s%s' % (basename, extension))
if os.path.exists(path): if os.path.exists(path):
return path return path
config_files = [] config_files = []
if project: if project:
config_files.append(search_dirs(cfg_dirs, '%s.conf' % project)) config_files.append(search_dirs(cfg_dirs, project, extension))
config_files.append(search_dirs(cfg_dirs, '%s.conf' % prog)) config_files.append(search_dirs(cfg_dirs, prog, extension))
return filter(bool, config_files) return filter(bool, config_files)

View File

@ -133,6 +133,16 @@ class FindConfigFilesTestCase(BaseTestCase):
self.assertEquals(find_config_files(project='blaa'), config_files) self.assertEquals(find_config_files(project='blaa'), config_files)
def test_find_config_files_with_extension(self):
config_files = ['/etc/foo.json']
self.stubs.Set(sys, 'argv', ['foo'])
self.stubs.Set(os.path, 'exists', lambda p: p in config_files)
self.assertEquals(find_config_files(project='blaa'), [])
self.assertEquals(find_config_files(project='blaa', extension='.json'),
config_files)
class CliOptsTestCase(BaseTestCase): class CliOptsTestCase(BaseTestCase):