diff --git a/provides.py b/provides.py index dfac3b5..4412923 100644 --- a/provides.py +++ b/provides.py @@ -105,20 +105,27 @@ class MySQLRouterProvides(reactive.Endpoint): :returns: None, this function is called for its side effect :rtype: None """ + self._set_db_connection_info( + getattr(self.relations[relation_id], 'to_publish_app'), + db_host, password, allowed_units, prefix, wait_timeout, ssl_ca) + # NOTE(ganso): Deprecated non-app-bag data for backwards compatibility + self._set_db_connection_info( + getattr(self.relations[relation_id], 'to_publish'), + db_host, password, allowed_units, prefix, wait_timeout, ssl_ca) + def _set_db_connection_info( + self, publish_prop, db_host, password, + allowed_units=None, prefix=None, wait_timeout=None, + ssl_ca=None): # No prefix for db_host or wait_timeout - self.relations[relation_id].to_publish["db_host"] = db_host + publish_prop["db_host"] = db_host if wait_timeout: - self.relations[relation_id].to_publish["wait_timeout"] = ( - wait_timeout) + publish_prop["wait_timeout"] = wait_timeout if ssl_ca: - self.relations[relation_id].to_publish["ssl_ca"] = ssl_ca + publish_prop["ssl_ca"] = ssl_ca if not prefix: - self.relations[relation_id].to_publish["password"] = password - self.relations[relation_id].to_publish[ - "allowed_units"] = allowed_units + publish_prop["password"] = password + publish_prop["allowed_units"] = allowed_units else: - self.relations[relation_id].to_publish[ - "{}_password".format(prefix)] = password - self.relations[relation_id].to_publish[ - "{}_allowed_units".format(prefix)] = allowed_units + publish_prop["{}_password".format(prefix)] = password + publish_prop["{}_allowed_units".format(prefix)] = allowed_units diff --git a/requires.py b/requires.py index 82d64c8..43ff74a 100644 --- a/requires.py +++ b/requires.py @@ -1,57 +1,91 @@ -from charmhelpers.core import hookenv -import charms.reactive as reactive +from charmhelpers.core import unitdata +from charms.reactive import ( + Endpoint, + when, + set_flag, + clear_flag, +) -class MySQLRouterRequires(reactive.RelationBase): - scope = reactive.scopes.GLOBAL +# NOTE: fork of relations.AutoAccessors for forwards compat behaviour +class MySQLRouterAutoAccessors(type): + """ + Metaclass that converts fields referenced by ``auto_accessors`` into + accessor methods with very basic doc strings. + """ + def __new__(cls, name, parents, dct): + for field in dct.get('auto_accessors', []): + meth_name = field.replace('-', '_') + meth = cls._accessor(field) + meth.__name__ = meth_name + meth.__module__ = dct.get('__module__') + meth.__doc__ = 'Get the %s, if available, or None.' % field + dct[meth_name] = meth + return super(MySQLRouterAutoAccessors, cls).__new__( + cls, name, parents, dct + ) - # These remote data fields will be automatically mapped to accessors - # with a basic documentation string provided. - auto_accessors = [ - 'db_host', 'ssl_ca', 'ssl_cert', 'ssl_key', 'wait_timeout'] + @staticmethod + def _accessor(field): + def __accessor(self): + return self.all_joined_units.received_raw.get(field) + return __accessor - @reactive.hook('{requires:mysql-router}-relation-joined') + +class MySQLRouterRequires(Endpoint, metaclass=MySQLRouterAutoAccessors): + + key = 'reactive.conversations.db-router.global.local-data.{}' + + kv = unitdata.kv() + + auto_accessors = ['db_host', 'ssl_ca', 'ssl_cert', 'ssl_key', + 'wait_timeout'] + + @when('endpoint.{endpoint_name}.joined') def joined(self): - self.set_state('{relation_name}.connected') + set_flag(self.expand_name('{endpoint_name}.connected')) self.set_or_clear_available() - def set_or_clear_available(self): - if self.db_router_data_complete(): - self.set_state('{relation_name}.available') - else: - self.remove_state('{relation_name}.available') - if self.proxy_db_data_complete(): - self.set_state('{relation_name}.available.proxy') - else: - self.remove_state('{relation_name}.available.proxy') - if self.ssl_data_complete(): - self.set_state('{relation_name}.available.ssl') - else: - self.remove_state('{relation_name}.available.ssl') - - @reactive.hook('{requires:mysql-router}-relation-changed') + @when('endpoint.{endpoint_name}.changed') def changed(self): self.joined() - @reactive.hook('{requires:mysql-router}-relation-{broken,departed}') + def set_or_clear_available(self): + if self.db_router_data_complete(): + set_flag(self.expand_name('{endpoint_name}.available')) + else: + clear_flag(self.expand_name('{endpoint_name}.available')) + if self.proxy_db_data_complete(): + set_flag(self.expand_name('{endpoint_name}.available.proxy')) + else: + clear_flag(self.expand_name('{endpoint_name}.available.proxy')) + if self.ssl_data_complete(): + set_flag(self.expand_name('{endpoint_name}.available.ssl')) + else: + clear_flag(self.expand_name('{endpoint_name}.available.ssl')) + + @when('endpoint.{endpoint_name}.broken') + def broken(self): + self.departed() + + @when('endpoint.{endpoint_name}.departed') def departed(self): # Clear state - self.remove_state('{relation_name}.connected') - self.remove_state('{relation_name}.available') - self.remove_state('{relation_name}.proxy.available') - self.remove_state('{relation_name}.available.ssl') + clear_flag(self.expand_name('{endpoint_name}.connected')) + clear_flag(self.expand_name('{endpoint_name}.available')) + clear_flag(self.expand_name('{endpoint_name}.proxy.available')) + clear_flag(self.expand_name('{endpoint_name}.available.ssl')) # Check if this is the last unit last_unit = True - for conversation in self.conversations(): - for rel_id in conversation.relation_ids: - if len(hookenv.related_units(rel_id)) > 0: - # This is not the last unit so reevaluate state - self.joined() - self.changed() - last_unit = False + for relation in self.relations: + if len(relation.units) > 0: + # This is not the last unit so reevaluate state + self.joined() + self.changed() + last_unit = False if last_unit: # Bug #1972883 - self.set_local('prefixes', []) + self._set_local('prefixes', []) def configure_db_router(self, username, hostname, prefix): """ @@ -64,8 +98,10 @@ class MySQLRouterRequires(reactive.RelationBase): 'private-address': hostname, } self.set_prefix(prefix) - self.set_remote(**relation_info) - self.set_local(**relation_info) + for relation in self.relations: + for k, v in relation_info.items(): + relation.to_publish_raw[k] = v + self._set_local(k, v) def configure_proxy_db(self, database, username, hostname, prefix): """ @@ -78,55 +114,82 @@ class MySQLRouterRequires(reactive.RelationBase): prefix + '_hostname': hostname, } self.set_prefix(prefix) - self.set_remote(**relation_info) - self.set_local(**relation_info) + for relation in self.relations: + for k, v in relation_info.items(): + relation.to_publish_raw[k] = v + self._set_local(k, v) + + def _set_local(self, key, value): + self.kv.set(self.key.format(key), value) + + def _get_local(self, key): + return self.kv.get(self.key.format(key)) def set_prefix(self, prefix): """ Store all of the database prefixes in a list. """ - prefixes = self.get_local('prefixes') - if prefixes: - if prefix not in prefixes: - self.set_local('prefixes', prefixes + [prefix]) - else: - self.set_local('prefixes', [prefix]) + prefixes = self._get_local('prefixes') + for relation in self.relations: + if prefixes: + if prefix not in prefixes: + self._set_local('prefixes', prefixes + [prefix]) + else: + self._set_local('prefixes', [prefix]) def get_prefixes(self): """ Return the list of saved prefixes. """ - return self.get_local('prefixes') + return self._get_local('prefixes') def database(self, prefix): """ Return a configured database name. """ - return self.get_local(prefix + '_database') + return self._get_local(prefix + '_database') def username(self, prefix): """ Return a configured username. """ - return self.get_local(prefix + '_username') + return self._get_local(prefix + '_username') def hostname(self, prefix): """ Return a configured hostname. """ - return self.get_local(prefix + '_hostname') + return self._get_local(prefix + '_hostname') + + def _received_app(self, key): + value = None + for relation in self.relations: + value = relation.received_app_raw.get(key) + if value: + return value + # NOTE(ganso): backwards compatibility with non-app-bag below + if not value: + return self.all_joined_units.received_raw.get(key) def password(self, prefix): """ Return a database password. """ - return self.get_remote(prefix + '_password') + return self._received_app(prefix + '_password') def allowed_units(self, prefix): """ Return a database's allowed_units. """ - return self.get_remote(prefix + '_allowed_units') + return self._received_app(prefix + '_allowed_units') + + def _read_suffixes(self, suffixes): + data = {} + for prefix in self.get_prefixes(): + for suffix in suffixes: + key = prefix + suffix + data[key] = self._received_app(key) + return data def db_router_data_complete(self): """ @@ -137,10 +200,7 @@ class MySQLRouterRequires(reactive.RelationBase): } if self.get_prefixes(): suffixes = ['_password'] - for prefix in self.get_prefixes(): - for suffix in suffixes: - key = prefix + suffix - data[key] = self.get_remote(key) + data.update(self._read_suffixes(suffixes)) if all(data.values()): return True return False @@ -155,10 +215,7 @@ class MySQLRouterRequires(reactive.RelationBase): # The mysql-router prefix + proxied db prefixes if self.get_prefixes() and len(self.get_prefixes()) > 1: suffixes = ['_password', '_allowed_units'] - for prefix in self.get_prefixes(): - for suffix in suffixes: - key = prefix + suffix - data[key] = self.get_remote(key) + data.update(self._read_suffixes(suffixes)) if all(data.values()): return True return False diff --git a/unit_tests/test_provides.py b/unit_tests/test_provides.py index df53eba..4499895 100644 --- a/unit_tests/test_provides.py +++ b/unit_tests/test_provides.py @@ -139,6 +139,7 @@ class TestMySQLRouterProvides(test_utils.PatchHelper): mock.call("password", _pw), mock.call("allowed_units", self.fake_unit.unit_name)] self.fake_relation.to_publish.__setitem__.assert_has_calls(_calls) + self.fake_relation.to_publish_app.__setitem__.assert_has_calls(_calls) def test_set_db_connection_info_prefixed(self): _p = "prefix" @@ -154,6 +155,7 @@ class TestMySQLRouterProvides(test_utils.PatchHelper): mock.call("{}_password".format(_p), _pw), mock.call("{}_allowed_units".format(_p), self.fake_unit.unit_name)] self.fake_relation.to_publish.__setitem__.assert_has_calls(_calls) + self.fake_relation.to_publish_app.__setitem__.assert_has_calls(_calls) def test_set_db_connection_info_wait_timeout(self): _wto = 90 @@ -171,3 +173,4 @@ class TestMySQLRouterProvides(test_utils.PatchHelper): mock.call("{}_password".format(_p), _pw), mock.call("{}_allowed_units".format(_p), self.fake_unit.unit_name)] self.fake_relation.to_publish.__setitem__.assert_has_calls(_calls) + self.fake_relation.to_publish_app.__setitem__.assert_has_calls(_calls) diff --git a/unit_tests/test_requires.py b/unit_tests/test_requires.py index 4bdc8b2..b2beafd 100644 --- a/unit_tests/test_requires.py +++ b/unit_tests/test_requires.py @@ -10,8 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import unittest +import charms_openstack.test_utils as test_utils from unittest import mock import requires @@ -30,7 +29,7 @@ def mock_hook(*args, **kwargs): return inner -class TestMySQLRouterRequires(unittest.TestCase): +class TestMySQLRouterRequires(test_utils.PatchHelper): @classmethod def setUpClass(cls): @@ -61,24 +60,27 @@ class TestMySQLRouterRequires(unittest.TestCase): self._patches = {} self._patches_start = {} - self._rel_ids = ["mysql-router:3"] + self._rel_ids = ["db-router:3"] self._remote_data = {} + self._published_data = {} self._local_data = {} - self._conversation = mock.MagicMock() - self._conversation.relation_ids = self._rel_ids - self._conversation.scope = requires.reactive.scopes.GLOBAL - self._conversation.get_remote.side_effect = self.get_fake_remote_data - self._conversation.get_local.side_effect = self.get_fake_local_data - # The Relation object + self.fake_relation = mock.MagicMock() + self.fake_unit = mock.MagicMock() + self.fake_unit.unit_name = "unit/1" + self.fake_relation.relation_id = self._rel_ids[0] + self.fake_relation.units = [self.fake_unit] self.mysql_router = requires.MySQLRouterRequires( - 'mysql-router', [self._conversation]) - self.patch_mysql_router('conversations', [self._conversation]) - self.patch_mysql_router('set_remote') - self.patch_mysql_router('set_local') - self.patch_mysql_router('set_state') - self.patch_mysql_router('remove_state') + 'mysql-router', self._rel_ids) + self.mysql_router._get_local = mock.MagicMock( + side_effect=self.get_fake_local_data) + self.mysql_router.relations[0] = self.fake_relation + self.fake_relation.to_publish_raw = self._published_data + self.fake_relation.received_app_raw = self._remote_data + self.patch_mysql_router('_set_local') + self.patch_object(requires, "clear_flag") + self.patch_object(requires, "set_flag") self.patch_mysql_router('db_host', "10.5.0.21") self.patch_mysql_router('wait_timeout', 90) @@ -98,9 +100,6 @@ class TestMySQLRouterRequires(unittest.TestCase): self._patches_start[attr] = started setattr(self, attr, started) - def get_fake_remote_data(self, key, default=None): - return self._remote_data.get(key) or default - def get_fake_local_data(self, key, default=None): return self._local_data.get(key) or default @@ -122,48 +121,51 @@ class TestMySQLRouterRequires(unittest.TestCase): self.patch_mysql_router('ssl_data_complete', True) self._local_data = {"prefixes": ["myprefix"]} _calls = [ - mock.call("{relation_name}.available"), - mock.call("{relation_name}.available.proxy"), - mock.call("{relation_name}.available.ssl")] + mock.call("mysql-router.available"), + mock.call("mysql-router.available.proxy"), + mock.call("mysql-router.available.ssl")] self.mysql_router.set_or_clear_available() - self.set_state.assert_has_calls(_calls) + self.set_flag.assert_has_calls(_calls) def test_changed_not_available(self): self.patch_mysql_router('db_router_data_complete', False) self.patch_mysql_router('joined') self._local_data = {"prefixes": ["myprefix"]} self.mysql_router.set_or_clear_available() - self.set_state.assert_not_called() + self.set_flag.assert_not_called() def test_joined(self): self.patch_mysql_router('set_or_clear_available') self.mysql_router.joined() - self.set_state.assert_called_once_with('{relation_name}.connected') + self.set_flag.assert_called_once_with('mysql-router.connected') self.set_or_clear_available.assert_called_once() def test_departed(self): self.mysql_router.departed() _calls = [ - mock.call("{relation_name}.available")] - self.remove_state.assert_has_calls(_calls) + mock.call('mysql-router.connected'), + mock.call("mysql-router.available"), + mock.call('mysql-router.proxy.available'), + mock.call('mysql-router.available.ssl')] + self.clear_flag.assert_has_calls(_calls) def test_db_router_data_complete_missing_prefix(self): - self._remote_data = {"password": "1234", - "allowed_units": "unit/1"} + self._remote_data.update({"password": "1234", + "allowed_units": "unit/1"}) assert self.mysql_router.db_router_data_complete() is False def test_db_router_data_complete(self): self._local_data = {"prefixes": ["myprefix"]} - self._remote_data = {"myprefix_password": "1234", - "myprefix_allowed_units": "unit/1"} + self._remote_data.update({"myprefix_password": "1234", + "myprefix_allowed_units": "unit/1"}) assert self.mysql_router.db_router_data_complete() is True self.db_host.return_value = None assert self.mysql_router.db_router_data_complete() is False def test_db_router_data_complete_wait_timeout(self): self._local_data = {"prefixes": ["myprefix"]} - self._remote_data = {"myprefix_password": "1234", - "myprefix_allowed_units": "unit/1"} + self._remote_data.update({"myprefix_password": "1234", + "myprefix_allowed_units": "unit/1"}) # Wait timeout is an optional value and should not affect data complete self.wait_timeout.return_value = None assert self.mysql_router.db_router_data_complete() is True @@ -172,16 +174,16 @@ class TestMySQLRouterRequires(unittest.TestCase): def test_proxy_db_data_incomplete(self): self._local_data = {"prefixes": ["myprefix"]} - self._remote_data = {"myprefix_password": "1234", - "myprefix_allowed_units": "unit/1"} + self._remote_data.update({"myprefix_password": "1234", + "myprefix_allowed_units": "unit/1"}) assert self.mysql_router.proxy_db_data_complete() is False def test_proxy_db_data_complete(self): self._local_data = {"prefixes": ["myprefix", "db"]} - self._remote_data = {"myprefix_password": "1234", - "myprefix_allowed_units": "unit/1", - "db_password": "1234", - "db_allowed_units": "unit/1"} + self._remote_data.update({"myprefix_password": "1234", + "myprefix_allowed_units": "unit/1", + "db_password": "1234", + "db_allowed_units": "unit/1"}) assert self.mysql_router.proxy_db_data_complete() is True self.db_host.return_value = None assert self.mysql_router.proxy_db_data_complete() is False @@ -220,9 +222,8 @@ class TestMySQLRouterRequires(unittest.TestCase): for key, test in _tests.items(): self.assertEqual(test(_prefix), None) # Set - self._local_data = {"prefixes": [_prefix]} for key, test in _tests.items(): - self._remote_data = {"{}_{}".format(_prefix, key): _value} + self._remote_data.update({"{}_{}".format(_prefix, key): _value}) self.assertEqual(test(_prefix), _value) def test_configure_db_router(self): @@ -234,9 +235,15 @@ class TestMySQLRouterRequires(unittest.TestCase): "{}_username".format(_prefix): _user, "{}_hostname".format(_prefix): _host, "private-address": _host} + calls = [ + mock.call('prefix_username', _user), + mock.call('prefix_hostname', _host), + mock.call('private-address', _host), + ] self.mysql_router.configure_db_router(_user, _host, prefix=_prefix) - self.set_remote.assert_called_once_with(**_expected) - self.set_local.assert_called_once_with(**_expected) + self._set_local.has_calls(calls) + self.assertTrue(all(self._published_data[k] == _expected[k] + for k in _expected.keys())) self.set_prefix.assert_called_once() def test_configure_proxy_db(self): @@ -250,8 +257,15 @@ class TestMySQLRouterRequires(unittest.TestCase): "{}_username".format(_prefix): _user, "{}_hostname".format(_prefix): _host} self.mysql_router.configure_proxy_db(_db, _user, _host, prefix=_prefix) - self.set_remote.assert_called_once_with(**_expected) - self.set_local.assert_called_once_with(**_expected) + calls = [ + mock.call('prefix_database', _db), + mock.call('prefix_username', _user), + mock.call('prefix_hostname', _host) + ] + self._set_local.has_calls(calls) + + self.assertTrue(all(self._published_data[k] == _expected[k] + for k in _expected.keys())) self.set_prefix.assert_called_once() def test_get_prefix(self): @@ -264,22 +278,21 @@ class TestMySQLRouterRequires(unittest.TestCase): # First _prefix = "prefix" self.mysql_router.set_prefix(_prefix) - self.set_local.assert_called_once_with("prefixes", [_prefix]) + self._set_local.assert_called_once_with("prefixes", [_prefix]) # More than one - self.set_local.reset_mock() + self._set_local.reset_mock() self._local_data = {"prefixes": [_prefix]} _second = "secondprefix" self.mysql_router.set_prefix(_second) - self.set_local.assert_called_once_with("prefixes", [_prefix, _second]) + self._set_local.assert_called_once_with("prefixes", [_prefix, _second]) - @mock.patch.object(requires.hookenv, 'related_units') - def test_ly_departed(self, related_units): + def test_ly_departed(self): self._local_data = {"prefixes": ["myprefix"]} + self.patch_mysql_router('ssl_ca', "fake_ca") - related_units.return_value = ['unit/1'] self.mysql_router.departed() - self.assertFalse(self.set_local.called) + self.assertFalse(self._set_local.called) - related_units.return_value = [] + self.mysql_router.relations[0].units = [] self.mysql_router.departed() - self.set_local.assert_called_once_with("prefixes", []) + self._set_local.assert_called_once_with("prefixes", [])