Ensure unit acl before raising 'available' state
Check the allowed_units list past back by the db before raising the 'available' state to ensure charms do not try to connect to the db before they are allowed. Change-Id: I17228ba5d82249f0a5baffab1c2d200963b55b10 Closes-Bug: #1861665
This commit is contained in:
parent
6c1b8d4192
commit
1099a6b3ec
40
requires.py
40
requires.py
|
@ -24,7 +24,7 @@ class MySQLSharedRequires(RelationBase):
|
|||
self.remove_state('{relation_name}.available.access_network')
|
||||
self.remove_state('{relation_name}.available.ssl')
|
||||
else:
|
||||
if self.base_data_complete():
|
||||
if self.base_data_complete() and self.unit_allowed_all_dbs():
|
||||
self.set_state('{relation_name}.available')
|
||||
if self.access_network_data_complete():
|
||||
self.set_state('{relation_name}.available.access_network')
|
||||
|
@ -152,6 +152,44 @@ class MySQLSharedRequires(RelationBase):
|
|||
return True
|
||||
return False
|
||||
|
||||
def unit_allowed_db(self, prefix=None):
|
||||
""""
|
||||
Check unit can access requested database.
|
||||
|
||||
:param prefix: Prefix used to distinguish multiple db requests.
|
||||
:type prefix: str
|
||||
:returns: Whether db acl has been setup.
|
||||
:rtype: bool
|
||||
"""
|
||||
allowed = False
|
||||
allowed_units = self.allowed_units(prefix=prefix) or ''
|
||||
hookenv.log("Checking {} is in {}".format(
|
||||
hookenv.local_unit(),
|
||||
allowed_units.split()))
|
||||
if allowed_units and hookenv.local_unit() in allowed_units.split():
|
||||
allowed = True
|
||||
hookenv.log("Unit allowed: {}".format(allowed))
|
||||
return allowed
|
||||
|
||||
def unit_allowed_all_dbs(self):
|
||||
""""
|
||||
Check unit can access all requested databases.
|
||||
|
||||
:returns: Whether db acl has been setup for all dbs.
|
||||
:rtype: bool
|
||||
"""
|
||||
if self.get_prefixes():
|
||||
_allowed = [self.unit_allowed_db(prefix=p)
|
||||
for p in self.get_prefixes()]
|
||||
else:
|
||||
_allowed = [self.unit_allowed_db()]
|
||||
hookenv.log("Allowed: {}".format(_allowed))
|
||||
if all(_allowed):
|
||||
hookenv.log("Returning unit_allowed_all_dbs True")
|
||||
return True
|
||||
hookenv.log("Returning unit_allowed_all_dbs False")
|
||||
return False
|
||||
|
||||
def access_network_data_complete(self):
|
||||
"""
|
||||
Check if optional access network data provided by mysql is complete.
|
||||
|
|
|
@ -116,6 +116,7 @@ class TestMySQLSharedRequires(unittest.TestCase):
|
|||
self.assertEqual(hook_patterns[k], v['args'])
|
||||
|
||||
def test_changed_available(self):
|
||||
self.patch_mysql_shared('unit_allowed_all_dbs', True)
|
||||
self.patch_mysql_shared('base_data_complete', True)
|
||||
self.patch_mysql_shared('access_network_data_complete', True)
|
||||
self.patch_mysql_shared('ssl_data_complete', True)
|
||||
|
@ -258,3 +259,33 @@ class TestMySQLSharedRequires(unittest.TestCase):
|
|||
_second = "secondprefix"
|
||||
self.mysql_shared.set_prefix(_second)
|
||||
self.set_local.assert_called_once_with("prefixes", [_prefix, _second])
|
||||
|
||||
@mock.patch.object(requires.hookenv, 'log')
|
||||
@mock.patch.object(requires.hookenv, 'local_unit')
|
||||
def test_unit_allowed_db(self, local_unit, log):
|
||||
self._remote_data = {'allowed_units': 'unit/1 unit/3'}
|
||||
local_unit.return_value = 'unit/1'
|
||||
self.assertTrue(self.mysql_shared.unit_allowed_db())
|
||||
local_unit.return_value = 'unit/2'
|
||||
self.assertFalse(self.mysql_shared.unit_allowed_db())
|
||||
|
||||
@mock.patch.object(requires.hookenv, 'log')
|
||||
@mock.patch.object(requires.hookenv, 'local_unit')
|
||||
def test_unit_allowed_db_prefix(self, local_unit, log):
|
||||
self._remote_data = {'bob_allowed_units': 'unit/1 unit/3'}
|
||||
local_unit.return_value = 'unit/1'
|
||||
self.assertTrue(self.mysql_shared.unit_allowed_db(prefix='bob'))
|
||||
self.assertFalse(self.mysql_shared.unit_allowed_db())
|
||||
self.assertFalse(self.mysql_shared.unit_allowed_db(prefix='flump'))
|
||||
|
||||
@mock.patch.object(requires.hookenv, 'log')
|
||||
@mock.patch.object(requires.hookenv, 'local_unit')
|
||||
def test_unit_allowed_all_dbs(self, local_unit, log):
|
||||
local_unit.return_value = 'unit/1'
|
||||
self._local_data = {"prefixes": ['prefix1', 'prefix2']}
|
||||
self._remote_data = {'prefix1_allowed_units': 'unit/1 unit/3'}
|
||||
self.assertFalse(self.mysql_shared.unit_allowed_all_dbs())
|
||||
self._remote_data = {
|
||||
'prefix1_allowed_units': 'unit/1 unit/3',
|
||||
'prefix2_allowed_units': 'unit/1 unit/3'}
|
||||
self.assertTrue(self.mysql_shared.unit_allowed_all_dbs())
|
||||
|
|
Loading…
Reference in New Issue