Instead of having the cloud pass large references to its constructor, this has been reduced to actual objects.

Added a get template filename helper which can be used to locate template files for various handlers/transforms.
Ensured that the config that we give back out is copied, so that it can't be modified by any 'malicous' handlers/transforms.
Added helper method cycle_logging that can resetup logging, this is mainly used by the rsyslog transform.
This commit is contained in:
Joshua Harlow
2012-06-15 17:45:52 -07:00
parent 3b6745531e
commit 707c10341d

View File

@@ -20,6 +20,9 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import copy
import os
from cloudinit import distros
from cloudinit import helpers
from cloudinit import log as logging
@@ -28,18 +31,40 @@ LOG = logging.getLogger(__name__)
class Cloud(object):
def __init__(self, datasource, paths, cfg):
def __init__(self, datasource, paths, cfg, distro, runners):
self.datasource = datasource
self.paths = paths
self.cfg = cfg
self.distro = distros.fetch(cfg, self)
self.runners = helpers.Runners(paths)
self.distro = distro
self._cfg = cfg
self._runners = runners
# If a transform manipulates logging or logging services
# it is typically useful to cause the logging to be
# setup again.
def cycle_logging(self):
logging.setupLogging(self.cfg)
@property
def cfg(self):
# Ensure that not indirectly modified
return copy.deepcopy(self._cfg)
def run(self, name, functor, args, freq=None, clear_on_fail=False):
return self.runners.run(name, functor, args, freq, clear_on_fail)
return self._runners.run(name, functor, args, freq, clear_on_fail)
def get_template_filename(self, name):
fn = self.paths.template_tpl % (name)
if not os.path.isfile(fn):
LOG.warn("No template found at %s for template named %s", fn, name)
return None
return fn
# The rest of thes are just useful proxies
def get_userdata(self):
return self.datasource.get_userdata()
def get_instance_id(self):
return self.datasource.get_instance_id()
def get_public_ssh_keys(self):
return self.datasource.get_public_ssh_keys()
@@ -47,7 +72,7 @@ class Cloud(object):
def get_locale(self):
return self.datasource.get_locale()
def get_mirror(self):
def get_local_mirror(self):
return self.datasource.get_local_mirror()
def get_hostname(self, fqdn=False):