Ensure unit acl before raising 'available' state

Check the allowed_units list past back by the db before raising the
'available' state to ensure charms do not try to connect to the db
before they are allowed.

Change-Id: I17228ba5d82249f0a5baffab1c2d200963b55b10
Closes-Bug: #1861665
This commit is contained in:
Liam Young 2020-02-03 09:41:09 +00:00
parent 6c1b8d4192
commit 1099a6b3ec
2 changed files with 70 additions and 1 deletions

View File

@ -24,7 +24,7 @@ class MySQLSharedRequires(RelationBase):
self.remove_state('{relation_name}.available.access_network')
self.remove_state('{relation_name}.available.ssl')
else:
if self.base_data_complete():
if self.base_data_complete() and self.unit_allowed_all_dbs():
self.set_state('{relation_name}.available')
if self.access_network_data_complete():
self.set_state('{relation_name}.available.access_network')
@ -152,6 +152,44 @@ class MySQLSharedRequires(RelationBase):
return True
return False
def unit_allowed_db(self, prefix=None):
""""
Check unit can access requested database.
:param prefix: Prefix used to distinguish multiple db requests.
:type prefix: str
:returns: Whether db acl has been setup.
:rtype: bool
"""
allowed = False
allowed_units = self.allowed_units(prefix=prefix) or ''
hookenv.log("Checking {} is in {}".format(
hookenv.local_unit(),
allowed_units.split()))
if allowed_units and hookenv.local_unit() in allowed_units.split():
allowed = True
hookenv.log("Unit allowed: {}".format(allowed))
return allowed
def unit_allowed_all_dbs(self):
""""
Check unit can access all requested databases.
:returns: Whether db acl has been setup for all dbs.
:rtype: bool
"""
if self.get_prefixes():
_allowed = [self.unit_allowed_db(prefix=p)
for p in self.get_prefixes()]
else:
_allowed = [self.unit_allowed_db()]
hookenv.log("Allowed: {}".format(_allowed))
if all(_allowed):
hookenv.log("Returning unit_allowed_all_dbs True")
return True
hookenv.log("Returning unit_allowed_all_dbs False")
return False
def access_network_data_complete(self):
"""
Check if optional access network data provided by mysql is complete.

View File

@ -116,6 +116,7 @@ class TestMySQLSharedRequires(unittest.TestCase):
self.assertEqual(hook_patterns[k], v['args'])
def test_changed_available(self):
self.patch_mysql_shared('unit_allowed_all_dbs', True)
self.patch_mysql_shared('base_data_complete', True)
self.patch_mysql_shared('access_network_data_complete', True)
self.patch_mysql_shared('ssl_data_complete', True)
@ -258,3 +259,33 @@ class TestMySQLSharedRequires(unittest.TestCase):
_second = "secondprefix"
self.mysql_shared.set_prefix(_second)
self.set_local.assert_called_once_with("prefixes", [_prefix, _second])
@mock.patch.object(requires.hookenv, 'log')
@mock.patch.object(requires.hookenv, 'local_unit')
def test_unit_allowed_db(self, local_unit, log):
self._remote_data = {'allowed_units': 'unit/1 unit/3'}
local_unit.return_value = 'unit/1'
self.assertTrue(self.mysql_shared.unit_allowed_db())
local_unit.return_value = 'unit/2'
self.assertFalse(self.mysql_shared.unit_allowed_db())
@mock.patch.object(requires.hookenv, 'log')
@mock.patch.object(requires.hookenv, 'local_unit')
def test_unit_allowed_db_prefix(self, local_unit, log):
self._remote_data = {'bob_allowed_units': 'unit/1 unit/3'}
local_unit.return_value = 'unit/1'
self.assertTrue(self.mysql_shared.unit_allowed_db(prefix='bob'))
self.assertFalse(self.mysql_shared.unit_allowed_db())
self.assertFalse(self.mysql_shared.unit_allowed_db(prefix='flump'))
@mock.patch.object(requires.hookenv, 'log')
@mock.patch.object(requires.hookenv, 'local_unit')
def test_unit_allowed_all_dbs(self, local_unit, log):
local_unit.return_value = 'unit/1'
self._local_data = {"prefixes": ['prefix1', 'prefix2']}
self._remote_data = {'prefix1_allowed_units': 'unit/1 unit/3'}
self.assertFalse(self.mysql_shared.unit_allowed_all_dbs())
self._remote_data = {
'prefix1_allowed_units': 'unit/1 unit/3',
'prefix2_allowed_units': 'unit/1 unit/3'}
self.assertTrue(self.mysql_shared.unit_allowed_all_dbs())