Update and sync helpers.

This commit is contained in:
Adam Gandelman
2013-11-05 19:57:51 -08:00
parent 22a9416f4a
commit 9205ee9003
10 changed files with 286 additions and 46 deletions

View File

@@ -9,6 +9,7 @@ import json
import yaml
import subprocess
import UserDict
from subprocess import CalledProcessError
CRITICAL = "CRITICAL"
ERROR = "ERROR"
@@ -21,7 +22,7 @@ cache = {}
def cached(func):
''' Cache return values for multiple executions of func + args
"""Cache return values for multiple executions of func + args
For example:
@@ -32,7 +33,7 @@ def cached(func):
unit_get('test')
will cache the result of unit_get + 'test' for future calls.
'''
"""
def wrapper(*args, **kwargs):
global cache
key = str((func, args, kwargs))
@@ -46,8 +47,8 @@ def cached(func):
def flush(key):
''' Flushes any entries from function cache where the
key is found in the function+args '''
"""Flushes any entries from function cache where the
key is found in the function+args """
flush_list = []
for item in cache:
if key in item:
@@ -57,7 +58,7 @@ def flush(key):
def log(message, level=None):
"Write a message to the juju log"
"""Write a message to the juju log"""
command = ['juju-log']
if level:
command += ['-l', level]
@@ -66,7 +67,7 @@ def log(message, level=None):
class Serializable(UserDict.IterableUserDict):
"Wrapper, an object that can be serialized to yaml or json"
"""Wrapper, an object that can be serialized to yaml or json"""
def __init__(self, obj):
# wrap the object
@@ -96,11 +97,11 @@ class Serializable(UserDict.IterableUserDict):
self.data = state
def json(self):
"Serialize the object to json"
"""Serialize the object to json"""
return json.dumps(self.data)
def yaml(self):
"Serialize the object to yaml"
"""Serialize the object to yaml"""
return yaml.dump(self.data)
@@ -119,38 +120,38 @@ def execution_environment():
def in_relation_hook():
"Determine whether we're running in a relation hook"
"""Determine whether we're running in a relation hook"""
return 'JUJU_RELATION' in os.environ
def relation_type():
"The scope for the current relation hook"
"""The scope for the current relation hook"""
return os.environ.get('JUJU_RELATION', None)
def relation_id():
"The relation ID for the current relation hook"
"""The relation ID for the current relation hook"""
return os.environ.get('JUJU_RELATION_ID', None)
def local_unit():
"Local unit ID"
"""Local unit ID"""
return os.environ['JUJU_UNIT_NAME']
def remote_unit():
"The remote unit for the current relation hook"
"""The remote unit for the current relation hook"""
return os.environ['JUJU_REMOTE_UNIT']
def service_name():
"The name service group this unit belongs to"
"""The name service group this unit belongs to"""
return local_unit().split('/')[0]
@cached
def config(scope=None):
"Juju charm configuration"
"""Juju charm configuration"""
config_cmd_line = ['config-get']
if scope is not None:
config_cmd_line.append(scope)
@@ -163,6 +164,7 @@ def config(scope=None):
@cached
def relation_get(attribute=None, unit=None, rid=None):
"""Get relation information"""
_args = ['relation-get', '--format=json']
if rid:
_args.append('-r')
@@ -174,9 +176,14 @@ def relation_get(attribute=None, unit=None, rid=None):
return json.loads(subprocess.check_output(_args))
except ValueError:
return None
except CalledProcessError, e:
if e.returncode == 2:
return None
raise
def relation_set(relation_id=None, relation_settings={}, **kwargs):
"""Set relation information for the current unit"""
relation_cmd_line = ['relation-set']
if relation_id is not None:
relation_cmd_line.extend(('-r', relation_id))
@@ -192,7 +199,7 @@ def relation_set(relation_id=None, relation_settings={}, **kwargs):
@cached
def relation_ids(reltype=None):
"A list of relation_ids"
"""A list of relation_ids"""
reltype = reltype or relation_type()
relid_cmd_line = ['relation-ids', '--format=json']
if reltype is not None:
@@ -203,7 +210,7 @@ def relation_ids(reltype=None):
@cached
def related_units(relid=None):
"A list of related units"
"""A list of related units"""
relid = relid or relation_id()
units_cmd_line = ['relation-list', '--format=json']
if relid is not None:
@@ -213,7 +220,7 @@ def related_units(relid=None):
@cached
def relation_for_unit(unit=None, rid=None):
"Get the json represenation of a unit's relation"
"""Get the json represenation of a unit's relation"""
unit = unit or remote_unit()
relation = relation_get(unit=unit, rid=rid)
for key in relation:
@@ -225,7 +232,7 @@ def relation_for_unit(unit=None, rid=None):
@cached
def relations_for_id(relid=None):
"Get relations of a specific relation ID"
"""Get relations of a specific relation ID"""
relation_data = []
relid = relid or relation_ids()
for unit in related_units(relid):
@@ -237,7 +244,7 @@ def relations_for_id(relid=None):
@cached
def relations_of_type(reltype=None):
"Get relations of a specific type"
"""Get relations of a specific type"""
relation_data = []
reltype = reltype or relation_type()
for relid in relation_ids(reltype):
@@ -249,7 +256,7 @@ def relations_of_type(reltype=None):
@cached
def relation_types():
"Get a list of relation types supported by this charm"
"""Get a list of relation types supported by this charm"""
charmdir = os.environ.get('CHARM_DIR', '')
mdf = open(os.path.join(charmdir, 'metadata.yaml'))
md = yaml.safe_load(mdf)
@@ -264,6 +271,7 @@ def relation_types():
@cached
def relations():
"""Get a nested dictionary of relation data for all related units"""
rels = {}
for reltype in relation_types():
relids = {}
@@ -277,15 +285,35 @@ def relations():
return rels
@cached
def is_relation_made(relation, keys='private-address'):
'''
Determine whether a relation is established by checking for
presence of key(s). If a list of keys is provided, they
must all be present for the relation to be identified as made
'''
if isinstance(keys, str):
keys = [keys]
for r_id in relation_ids(relation):
for unit in related_units(r_id):
context = {}
for k in keys:
context[k] = relation_get(k, rid=r_id,
unit=unit)
if None not in context.values():
return True
return False
def open_port(port, protocol="TCP"):
"Open a service network port"
"""Open a service network port"""
_args = ['open-port']
_args.append('{}/{}'.format(port, protocol))
subprocess.check_call(_args)
def close_port(port, protocol="TCP"):
"Close a service network port"
"""Close a service network port"""
_args = ['close-port']
_args.append('{}/{}'.format(port, protocol))
subprocess.check_call(_args)
@@ -293,6 +321,7 @@ def close_port(port, protocol="TCP"):
@cached
def unit_get(attribute):
"""Get the unit ID for the remote unit"""
_args = ['unit-get', '--format=json', attribute]
try:
return json.loads(subprocess.check_output(_args))
@@ -301,22 +330,46 @@ def unit_get(attribute):
def unit_private_ip():
"""Get this unit's private IP address"""
return unit_get('private-address')
class UnregisteredHookError(Exception):
"""Raised when an undefined hook is called"""
pass
class Hooks(object):
"""A convenient handler for hook functions.
Example:
hooks = Hooks()
# register a hook, taking its name from the function name
@hooks.hook()
def install():
...
# register a hook, providing a custom hook name
@hooks.hook("config-changed")
def config_changed():
...
if __name__ == "__main__":
# execute a hook based on the name the program is called by
hooks.execute(sys.argv)
"""
def __init__(self):
super(Hooks, self).__init__()
self._hooks = {}
def register(self, name, function):
"""Register a hook"""
self._hooks[name] = function
def execute(self, args):
"""Execute a registered hook based on args[0]"""
hook_name = os.path.basename(args[0])
if hook_name in self._hooks:
self._hooks[hook_name]()
@@ -324,6 +377,7 @@ class Hooks(object):
raise UnregisteredHookError(hook_name)
def hook(self, *hook_names):
"""Decorator, registering them as hooks"""
def wrapper(decorated):
for hook_name in hook_names:
self.register(hook_name, decorated)
@@ -337,4 +391,5 @@ class Hooks(object):
def charm_dir():
"""Return the root directory of the current charm"""
return os.environ.get('CHARM_DIR')