diff --git a/hooks/charmhelpers/contrib/openstack/context.py b/hooks/charmhelpers/contrib/openstack/context.py index d254de18..e4ad6c2a 100644 --- a/hooks/charmhelpers/contrib/openstack/context.py +++ b/hooks/charmhelpers/contrib/openstack/context.py @@ -147,7 +147,8 @@ class SharedDBContext(OSContextGenerator): 'database_host': rdata.get('db_host'), 'database': self.database, 'database_user': self.user, - 'database_password': rdata.get(password_setting) + 'database_password': rdata.get(password_setting), + 'database_type': 'mysql', } if context_complete(ctxt): db_ssl(rdata, ctxt, self.ssl_dir) @@ -155,6 +156,35 @@ class SharedDBContext(OSContextGenerator): return {} +class PostgresqlDBContext(OSContextGenerator): + interfaces = ['pgsql-db'] + + def __init__(self, database=None): + self.database = database + + def __call__(self): + self.database = self.database or config('database') + if self.database is None: + log('Could not generate postgresql_db context. ' + 'Missing required charm config options. ' + '(database name)') + raise OSContextError + ctxt = {} + + for rid in relation_ids(self.interfaces[0]): + for unit in related_units(rid): + ctxt = { + 'database_host': relation_get('host', rid=rid, unit=unit), + 'database': self.database, + 'database_user': relation_get('user', rid=rid, unit=unit), + 'database_password': relation_get('password', rid=rid, unit=unit), + 'database_type': 'postgresql', + } + if context_complete(ctxt): + return ctxt + return {} + + def db_ssl(rdata, ctxt, ssl_dir): if 'ssl_ca' in rdata and ssl_dir: ca_path = os.path.join(ssl_dir, 'db-client.ca') diff --git a/hooks/keystone_context.pyc b/hooks/keystone_context.pyc new file mode 100644 index 00000000..f88afe16 Binary files /dev/null and b/hooks/keystone_context.pyc differ diff --git a/hooks/keystone_hooks.py b/hooks/keystone_hooks.py index 5eb49aac..f35c25dc 100755 --- a/hooks/keystone_hooks.py +++ b/hooks/keystone_hooks.py @@ -12,7 +12,9 @@ from charmhelpers.core.hookenv import ( Hooks, UnregisteredHookError, config, + is_relation_made, log, + ERROR, relation_get, relation_ids, relation_set, @@ -102,11 +104,30 @@ def config_changed(): @hooks.hook('shared-db-relation-joined') def db_joined(): + if is_relation_made('pgsql-db'): + # error, postgresql is used + e = ('Attempting to associate a mysql database when there is already ' + 'associated a postgresql one') + log(e, level=ERROR) + raise Exception(e) + relation_set(database=config('database'), username=config('database-user'), hostname=unit_get('private-address')) +@hooks.hook('pgsql-db-relation-joined') +def pgsql_db_joined(): + if is_relation_made('shared-db'): + # raise error + e = ('Attempting to associate a postgresql database when there is already ' + 'associated a mysql one') + log(e, level=ERROR) + raise Exception(e) + + relation_set(database=config('database')) + + @hooks.hook('shared-db-relation-changed') @restart_on_change(restart_map()) def db_changed(): @@ -124,6 +145,23 @@ def db_changed(): identity_changed(relation_id=rid, remote_unit=unit) +@hooks.hook('pgsql-db-relation-changed') +@restart_on_change(restart_map()) +def pgsql_db_changed(): + if 'pgsql-db' not in CONFIGS.complete_contexts(): + log('pgsql-db relation incomplete. Peer not ready?') + else: + CONFIGS.write(KEYSTONE_CONF) + if eligible_leader(CLUSTER_RES): + migrate_database() + ensure_initial_admin(config) + # Ensure any existing service entries are updated in the + # new database backend + for rid in relation_ids('identity-service'): + for unit in related_units(rid): + identity_changed(relation_id=rid, remote_unit=unit) + + @hooks.hook('identity-service-relation-joined') def identity_joined(): """ Do nothing until we get information about requested service """ diff --git a/hooks/keystone_hooks.pyc b/hooks/keystone_hooks.pyc new file mode 100644 index 00000000..dda8b561 Binary files /dev/null and b/hooks/keystone_hooks.pyc differ diff --git a/hooks/keystone_ssl.py b/hooks/keystone_ssl.py index 1cbdfad7..45e0029d 100644 --- a/hooks/keystone_ssl.py +++ b/hooks/keystone_ssl.py @@ -1,12 +1,10 @@ #!/usr/bin/python -import base64 import os import shutil import subprocess import tarfile import tempfile -import zipfile CA_EXPIRY = '365' ORG_NAME = 'Ubuntu' diff --git a/hooks/keystone_ssl.pyc b/hooks/keystone_ssl.pyc new file mode 100644 index 00000000..bcdbbf62 Binary files /dev/null and b/hooks/keystone_ssl.pyc differ diff --git a/hooks/keystone_utils.py b/hooks/keystone_utils.py index 54f0d00e..efc97407 100644 --- a/hooks/keystone_utils.py +++ b/hooks/keystone_utils.py @@ -64,6 +64,7 @@ BASE_PACKAGES = [ 'openssl', 'python-keystoneclient', 'python-mysqldb', + 'python-psycopg2', 'pwgen', 'unison', 'uuid', @@ -98,6 +99,7 @@ BASE_RESOURCE_MAP = OrderedDict([ 'services': BASE_SERVICES, 'contexts': [keystone_context.KeystoneContext(), context.SharedDBContext(ssl_dir=KEYSTONE_CONF_DIR), + context.PostgresqlDBContext(), context.SyslogContext(), keystone_context.HAProxyContext()], }), diff --git a/hooks/keystone_utils.pyc b/hooks/keystone_utils.pyc new file mode 100644 index 00000000..ec566613 Binary files /dev/null and b/hooks/keystone_utils.pyc differ diff --git a/hooks/lib/apache_utils.py b/hooks/lib/apache_utils.py index 131f8ac0..3890582c 100644 --- a/hooks/lib/apache_utils.py +++ b/hooks/lib/apache_utils.py @@ -17,8 +17,7 @@ from lib.utils import ( config_get, install, get_host_ip, - restart - ) + restart) from lib.cluster_utils import https import os @@ -136,8 +135,7 @@ def enable_https(port_maps, namespace, cert, key, ca_cert=None): "ext": ext_port, "int": int_port, "namespace": namespace, - "private_address": get_host_ip() - } + "private_address": get_host_ip()} fsite.write(render_template(SITE_TEMPLATE, context)) @@ -160,7 +158,7 @@ def disable_https(port_maps, namespace): juju_log('INFO', 'Ensuring HTTPS disabled for {}'.format(port_maps)) if (not os.path.exists('/etc/apache2') or - not os.path.exists(os.path.join('/etc/apache2/ssl', namespace))): + not os.path.exists(os.path.join('/etc/apache2/ssl', namespace))): return http_restart = False diff --git a/hooks/lib/cluster_utils.py b/hooks/lib/cluster_utils.py index b7d00f8b..1405d6fb 100644 --- a/hooks/lib/cluster_utils.py +++ b/hooks/lib/cluster_utils.py @@ -14,8 +14,7 @@ from lib.utils import ( relation_list, relation_get, get_unit_hostname, - config_get - ) + config_get) import subprocess import os @@ -34,8 +33,7 @@ def is_clustered(): def is_leader(resource): cmd = [ "crm", "resource", - "show", resource - ] + "show", resource] try: status = subprocess.check_output(cmd) except subprocess.CalledProcessError: @@ -91,9 +89,9 @@ def https(): for r_id in relation_ids('identity-service'): for unit in relation_list(r_id): if (relation_get('https_keystone', rid=r_id, unit=unit) and - relation_get('ssl_cert', rid=r_id, unit=unit) and - relation_get('ssl_key', rid=r_id, unit=unit) and - relation_get('ca_cert', rid=r_id, unit=unit)): + relation_get('ssl_cert', rid=r_id, unit=unit) and + relation_get('ssl_key', rid=r_id, unit=unit) and + relation_get('ca_cert', rid=r_id, unit=unit)): return True return False diff --git a/hooks/lib/haproxy_utils.py b/hooks/lib/haproxy_utils.py index 721bb7f2..f14a20c1 100644 --- a/hooks/lib/haproxy_utils.py +++ b/hooks/lib/haproxy_utils.py @@ -14,8 +14,7 @@ from lib.utils import ( relation_get, unit_get, reload, - render_template - ) + render_template) import os HAPROXY_CONF = '/etc/haproxy/haproxy.cfg' @@ -44,8 +43,7 @@ def configure_haproxy(service_ports): unit=unit) context = { 'units': cluster_hosts, - 'service_ports': service_ports - } + 'service_ports': service_ports} with open(HAPROXY_CONF, 'w') as f: f.write(render_template(os.path.basename(HAPROXY_CONF), context)) diff --git a/hooks/lib/unison.py b/hooks/lib/unison.py index 06dd8b4c..fdef449d 100755 --- a/hooks/lib/unison.py +++ b/hooks/lib/unison.py @@ -73,15 +73,14 @@ def get_keypair(user): pub_key = '%s.pub' % priv_key if not os.path.isfile(pub_key): - utils.juju_log('INFO', 'Generatring missing ssh public key @ %s.' % \ + utils.juju_log('INFO', 'Generatring missing ssh public key @ %s.' % pub_key) cmd = ['ssh-keygen', '-y', '-f', priv_key] p = subprocess.check_output(cmd).strip() with open(pub_key, 'wb') as out: out.write(p) subprocess.check_call(['chown', '-R', user, ssh_dir]) - return open(priv_key, 'r').read().strip(), \ - open(pub_key, 'r').read().strip() + return open(priv_key, 'r').read().strip(), open(pub_key, 'r').read().strip() def write_authorized_keys(user, keys): @@ -149,7 +148,7 @@ def ssh_authorized_peers(peer_interface, user, group=None, ensure_local_user=Fal hosts.append(settings['private-address']) else: utils.juju_log('INFO', - 'ssh_authorized_peers(): ssh_pub_key '\ + 'ssh_authorized_peers(): ssh_pub_key ' 'missing for unit %s, skipping.' % unit) write_authorized_keys(user, keys) write_known_hosts(user, hosts) @@ -204,8 +203,7 @@ def sync_to_peers(peer_interface, user, paths=[], verbose=False): hosts.append(settings['private-address']) else: print 'unison sync_to_peers: peer (%s) has not authorized '\ - '*this* host yet, skipping.' %\ - settings['private-address'] + '*this* host yet, skipping.' % settings['private-address'] for path in paths: # removing trailing slash from directory paths, unison @@ -214,7 +212,6 @@ def sync_to_peers(peer_interface, user, paths=[], verbose=False): path = path[:(len(path) - 1)] for host in hosts: cmd = base_cmd + [path, 'ssh://%s@%s/%s' % (user, host, path)] - utils.juju_log('INFO', 'Syncing local path %s to %s@%s:%s' %\ - (path, user, host, path)) - print ' '.join(cmd) + utils.juju_log('INFO', 'Syncing local path %s to %s@%s:%s' % + (path, user, host, path)) run_as_user(user, cmd) diff --git a/hooks/lib/utils.py b/hooks/lib/utils.py index 8095e86a..018ac9e4 100644 --- a/hooks/lib/utils.py +++ b/hooks/lib/utils.py @@ -32,8 +32,7 @@ def install(*pkgs): cmd = [ 'apt-get', '-y', - 'install' - ] + 'install'] for pkg in pkgs: cmd.append(pkg) subprocess.check_call(cmd) @@ -54,16 +53,14 @@ except ImportError: def render_template(template_name, context, template_dir=TEMPLATES_DIR): - templates = jinja2.Environment( - loader=jinja2.FileSystemLoader(template_dir) - ) + templates = jinja2.Environment(loader=jinja2.FileSystemLoader(template_dir)) template = templates.get_template(template_name) return template.render(context) CLOUD_ARCHIVE = \ -""" # Ubuntu Cloud Archive -deb http://ubuntu-cloud.archive.canonical.com/ubuntu {} main -""" + """ # Ubuntu Cloud Archive + deb http://ubuntu-cloud.archive.canonical.com/ubuntu {} main + """ CLOUD_ARCHIVE_POCKETS = { 'folsom': 'precise-updates/folsom', @@ -77,8 +74,7 @@ CLOUD_ARCHIVE_POCKETS = { 'havana/proposed': 'precise-proposed/havana', 'icehouse': 'precise-updates/icehouse', 'icehouse/updates': 'precise-updates/icehouse', - 'icehouse/proposed': 'precise-proposed/icehouse', - } + 'icehouse/proposed': 'precise-proposed/icehouse'} def configure_source(): @@ -88,8 +84,7 @@ def configure_source(): if source.startswith('ppa:'): cmd = [ 'add-apt-repository', - source - ] + source] subprocess.check_call(cmd) if source.startswith('cloud:'): # CA values should be formatted as cloud:ubuntu-openstack/pocket, eg: @@ -106,8 +101,7 @@ def configure_source(): cmd = [ 'apt-key', 'adv', '--keyserver keyserver.ubuntu.com', - '--recv-keys', key - ] + '--recv-keys', key] subprocess.check_call(cmd) elif l == 1: apt_line = source @@ -116,8 +110,7 @@ def configure_source(): apt.write(apt_line + "\n") cmd = [ 'apt-get', - 'update' - ] + 'update'] subprocess.check_call(cmd) # Protocols @@ -128,8 +121,7 @@ UDP = 'UDP' def expose(port, protocol='TCP'): cmd = [ 'open-port', - '{}/{}'.format(port, protocol) - ] + '{}/{}'.format(port, protocol)] subprocess.check_call(cmd) @@ -137,8 +129,7 @@ def juju_log(severity, message): cmd = [ 'juju-log', '--log-level', severity, - message - ] + message] subprocess.check_call(cmd) @@ -162,8 +153,7 @@ def cached(func): def relation_ids(relation): cmd = [ 'relation-ids', - relation - ] + relation] result = str(subprocess.check_output(cmd)).split() if result == "": return None @@ -175,8 +165,7 @@ def relation_ids(relation): def relation_list(rid): cmd = [ 'relation-list', - '-r', rid, - ] + '-r', rid] result = str(subprocess.check_output(cmd)).split() if result == "": return None @@ -187,8 +176,7 @@ def relation_list(rid): @cached def relation_get(attribute, unit=None, rid=None): cmd = [ - 'relation-get', - ] + 'relation-get'] if rid: cmd.append('-r') cmd.append(rid) @@ -206,8 +194,7 @@ def relation_get(attribute, unit=None, rid=None): def relation_get_dict(relation_id=None, remote_unit=None): """Obtain all relation data as dict by way of JSON""" cmd = [ - 'relation-get', '--format=json' - ] + 'relation-get', '--format=json'] if relation_id: cmd.append('-r') cmd.append(relation_id) @@ -225,8 +212,7 @@ def relation_get_dict(relation_id=None, remote_unit=None): def relation_set(**kwargs): cmd = [ - 'relation-set' - ] + 'relation-set'] args = [] for k, v in kwargs.items(): if k == 'rid': @@ -243,8 +229,7 @@ def relation_set(**kwargs): def unit_get(attribute): cmd = [ 'unit-get', - attribute - ] + attribute] value = subprocess.check_output(cmd).strip() # IGNORE:E1103 if value == "": return None @@ -257,8 +242,7 @@ def config_get(attribute): cmd = [ 'config-get', '--format', - 'json', - ] + 'json'] out = subprocess.check_output(cmd).strip() # IGNORE:E1103 cfg = json.loads(out) @@ -321,8 +305,7 @@ def running(service): except subprocess.CalledProcessError: return False else: - if ("start/running" in output or - "is running" in output): + if ("start/running" in output or "is running" in output): return True else: return False diff --git a/hooks/pgsql-db-relation-changed b/hooks/pgsql-db-relation-changed new file mode 120000 index 00000000..dd3b3eff --- /dev/null +++ b/hooks/pgsql-db-relation-changed @@ -0,0 +1 @@ +keystone_hooks.py \ No newline at end of file diff --git a/hooks/pgsql-db-relation-joined b/hooks/pgsql-db-relation-joined new file mode 120000 index 00000000..dd3b3eff --- /dev/null +++ b/hooks/pgsql-db-relation-joined @@ -0,0 +1 @@ +keystone_hooks.py \ No newline at end of file diff --git a/metadata.yaml b/metadata.yaml index 42d82ab2..498b197b 100644 --- a/metadata.yaml +++ b/metadata.yaml @@ -12,6 +12,8 @@ provides: requires: shared-db: interface: mysql-shared + pgsql-db: + interface: pgsql ha: interface: hacluster scope: container diff --git a/revision b/revision index bf18240e..dcb6b5ba 100644 --- a/revision +++ b/revision @@ -1 +1 @@ -229 +230 diff --git a/templates/essex/keystone.conf b/templates/essex/keystone.conf index 9580f959..f514d9bb 100644 --- a/templates/essex/keystone.conf +++ b/templates/essex/keystone.conf @@ -14,7 +14,7 @@ verbose = {{ verbose }} [sql] {% if database_host -%} -connection = mysql://{{ database_user }}:{{ database_password }}@{{ database_host }}/{{ database }}{% if database_ssl_ca %}?ssl_ca={{ database_ssl_ca }}{% if database_ssl_cert %}&ssl_cert={{ database_ssl_cert }}&ssl_key={{ database_ssl_key }}{% endif %}{% endif %} +connection = {{ database_type }}://{{ database_user }}:{{ database_password }}@{{ database_host }}/{{ database }}{% if database_ssl_ca %}?ssl_ca={{ database_ssl_ca }}{% if database_ssl_cert %}&ssl_cert={{ database_ssl_cert }}&ssl_key={{ database_ssl_key }}{% endif %}{% endif %} {% else -%} connection = sqlite:////var/lib/keystone/keystone.db {% endif -%} diff --git a/templates/folsom/keystone.conf b/templates/folsom/keystone.conf index 8d1c560c..1daa88ec 100644 --- a/templates/folsom/keystone.conf +++ b/templates/folsom/keystone.conf @@ -14,7 +14,7 @@ verbose = {{ verbose }} [sql] {% if database_host -%} -connection = mysql://{{ database_user }}:{{ database_password }}@{{ database_host }}/{{ database }}{% if database_ssl_ca %}?ssl_ca={{ database_ssl_ca }}{% if database_ssl_cert %}&ssl_cert={{ database_ssl_cert }}&ssl_key={{ database_ssl_key }}{% endif %}{% endif %} +connection = {{ database_type }}://{{ database_user }}:{{ database_password }}@{{ database_host }}/{{ database }}{% if database_ssl_ca %}?ssl_ca={{ database_ssl_ca }}{% if database_ssl_cert %}&ssl_cert={{ database_ssl_cert }}&ssl_key={{ database_ssl_key }}{% endif %}{% endif %} {% else -%} connection = sqlite:////var/lib/keystone/keystone.db {% endif -%} diff --git a/templates/grizzly/keystone.conf b/templates/grizzly/keystone.conf index 0ffb2bfa..370f78a4 100644 --- a/templates/grizzly/keystone.conf +++ b/templates/grizzly/keystone.conf @@ -14,7 +14,7 @@ verbose = {{ verbose }} [sql] {% if database_host -%} -connection = mysql://{{ database_user }}:{{ database_password }}@{{ database_host }}/{{ database }}{% if database_ssl_ca %}?ssl_ca={{ database_ssl_ca }}{% if database_ssl_cert %}&ssl_cert={{ database_ssl_cert }}&ssl_key={{ database_ssl_key }}{% endif %}{% endif %} +connection = {{ database_type }}://{{ database_user }}:{{ database_password }}@{{ database_host }}/{{ database }}{% if database_ssl_ca %}?ssl_ca={{ database_ssl_ca }}{% if database_ssl_cert %}&ssl_cert={{ database_ssl_cert }}&ssl_key={{ database_ssl_key }}{% endif %}{% endif %} {% else -%} connection = sqlite:////var/lib/keystone/keystone.db {% endif -%} diff --git a/templates/havana/keystone.conf b/templates/havana/keystone.conf index ca28d9b0..f53310e2 100644 --- a/templates/havana/keystone.conf +++ b/templates/havana/keystone.conf @@ -14,7 +14,7 @@ verbose = {{ verbose }} [sql] {% if database_host -%} -connection = mysql://{{ database_user }}:{{ database_password }}@{{ database_host }}/{{ database }}{% if database_ssl_ca %}?ssl_ca={{ database_ssl_ca }}{% if database_ssl_cert %}&ssl_cert={{ database_ssl_cert }}&ssl_key={{ database_ssl_key }}{% endif %}{% endif %} +connection = {{ database_type }}://{{ database_user }}:{{ database_password }}@{{ database_host }}/{{ database }}{% if database_ssl_ca %}?ssl_ca={{ database_ssl_ca }}{% if database_ssl_cert %}&ssl_cert={{ database_ssl_cert }}&ssl_key={{ database_ssl_key }}{% endif %}{% endif %} {% else -%} connection = sqlite:////var/lib/keystone/keystone.db {% endif -%} diff --git a/unit_tests/__init__.py b/unit_tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/unit_tests/test_keystone_hooks.py b/unit_tests/test_keystone_hooks.py new file mode 100644 index 00000000..9e2de575 --- /dev/null +++ b/unit_tests/test_keystone_hooks.py @@ -0,0 +1,120 @@ +from mock import call, patch, MagicMock +import os + +from test_utils import CharmTestCase + +os.environ['JUJU_UNIT_NAME'] = 'keystone' +with patch('charmhelpers.core.hookenv.config') as config: + config.return_value = 'keystone' + import keystone_utils as utils + +_reg = utils.register_configs +_map = utils.restart_map + +utils.register_configs = MagicMock() +utils.restart_map = MagicMock() + +import keystone_hooks as hooks + +utils.register_configs = _reg +utils.restart_map = _map + +TO_PATCH = [ + # charmhelpers.core.hookenv + 'Hooks', + 'config', + 'is_relation_made', + 'log', + 'relation_ids', + 'relation_set', + 'relation_get', + 'unit_get', + # charmhelpers.core.host + 'apt_install', + 'apt_update', + 'restart_on_change', + # charmhelpers.contrib.openstack.utils + 'configure_installation_source', + # charmhelpers.contrib.hahelpers.cluster_utils + 'eligible_leader', + # keystone_utils + 'restart_map', + 'register_configs', + 'do_openstack_upgrade', + 'migrate_database', + # other + 'check_call', + 'execd_preinstall', + 'mkdir' +] + + +class KeystoneRelationTests(CharmTestCase): + + def setUp(self): + super(KeystoneRelationTests, self).setUp(hooks, TO_PATCH) + self.config.side_effect = self.test_config.get + + + def test_db_joined(self): + self.unit_get.return_value = 'keystone.foohost.com' + self.is_relation_made.return_value = False + hooks.db_joined() + self.relation_set.assert_called_with(database='keystone', + username='keystone', + hostname='keystone.foohost.com') + self.unit_get.assert_called_with('private-address') + + def test_postgresql_db_joined(self): + self.unit_get.return_value = 'keystone.foohost.com' + self.is_relation_made.return_value = False + hooks.pgsql_db_joined() + self.relation_set.assert_called_with(database='keystone'), + + def test_db_joined_with_postgresql(self): + self.is_relation_made.return_value = True + + with self.assertRaises(Exception) as context: + hooks.db_joined() + self.assertEqual(context.exception.message, + 'Attempting to associate a mysql database when there ' + 'is already associated a postgresql one') + + def test_postgresql_joined_with_db(self): + self.is_relation_made.return_value = True + + with self.assertRaises(Exception) as context: + hooks.pgsql_db_joined() + self.assertEqual(context.exception.message, + 'Attempting to associate a postgresql database when there ' + 'is already associated a mysql one') + + @patch.object(hooks, 'CONFIGS') + def test_db_changed_missing_relation_data(self, configs): + configs.complete_contexts = MagicMock() + configs.complete_contexts.return_value = [] + hooks.db_changed() + self.log.assert_called_with( + 'shared-db relation incomplete. Peer not ready?' + ) + + @patch.object(hooks, 'CONFIGS') + def test_postgresql_db_changed_missing_relation_data(self, configs): + configs.complete_contexts = MagicMock() + configs.complete_contexts.return_value = [] + hooks.pgsql_db_changed() + self.log.assert_called_with( + 'pgsql-db relation incomplete. Peer not ready?' + ) + + def _shared_db_test(self, configs): + configs.complete_contexts = MagicMock() + configs.complete_contexts.return_value = ['shared-db'] + configs.write = MagicMock() + hooks.db_changed() + + def _postgresql_db_test(self, configs): + configs.complete_contexts = MagicMock() + configs.complete_contexts.return_value = ['pgsql-db'] + configs.write = MagicMock() + hooks.pgsql_db_changed() diff --git a/unit_tests/test_utils.py b/unit_tests/test_utils.py new file mode 100644 index 00000000..e1e346b1 --- /dev/null +++ b/unit_tests/test_utils.py @@ -0,0 +1,119 @@ +import logging +import os +import unittest +import yaml + +from contextlib import contextmanager +from mock import patch, MagicMock + + +def load_config(): + '''Walk backwords from __file__ looking for config.yaml, + load and return the 'options' section' + ''' + config = None + f = __file__ + while config is None: + d = os.path.dirname(f) + if os.path.isfile(os.path.join(d, 'config.yaml')): + config = os.path.join(d, 'config.yaml') + break + f = d + + if not config: + logging.error('Could not find config.yaml in any parent directory ' + 'of %s. ' % file) + raise Exception + + return yaml.safe_load(open(config).read())['options'] + + +def get_default_config(): + '''Load default charm config from config.yaml return as a dict. + If no default is set in config.yaml, its value is None. + ''' + default_config = {} + config = load_config() + for k, v in config.iteritems(): + if 'default' in v: + default_config[k] = v['default'] + else: + default_config[k] = None + return default_config + + +class CharmTestCase(unittest.TestCase): + + def setUp(self, obj, patches): + super(CharmTestCase, self).setUp() + self.patches = patches + self.obj = obj + self.test_config = TestConfig() + self.test_relation = TestRelation() + self.patch_all() + + def patch(self, method): + _m = patch.object(self.obj, method) + mock = _m.start() + self.addCleanup(_m.stop) + return mock + + def patch_all(self): + for method in self.patches: + setattr(self, method, self.patch(method)) + + +class TestConfig(object): + + def __init__(self): + self.config = get_default_config() + + def get(self, attr=None): + if not attr: + return self.get_all() + try: + return self.config[attr] + except KeyError: + return None + + def get_all(self): + return self.config + + def set(self, attr, value): + if attr not in self.config: + raise KeyError + self.config[attr] = value + + +class TestRelation(object): + + def __init__(self, relation_data={}): + self.relation_data = relation_data + + def set(self, relation_data): + self.relation_data = relation_data + + def get(self, attr=None, unit=None, rid=None): + if attr is None: + return self.relation_data + elif attr in self.relation_data: + return self.relation_data[attr] + return None + + +@contextmanager +def patch_open(): + '''Patch open() to allow mocking both open() itself and the file that is + yielded. + Yields the mock for "open" and "file", respectively. + ''' + mock_open = MagicMock(spec=open) + mock_file = MagicMock(spec=file) + + @contextmanager + def stub_open(*args, **kwargs): + mock_open(*args, **kwargs) + yield mock_file + + with patch('__builtin__.open', stub_open): + yield mock_open, mock_file