Hierarchical contracts implementation

At this time, we only support one level of hierarchy,
although one parent contract can have multiple children.

Change-Id: I09bcef301ccc95e256a0850320680f720b491520
This commit is contained in:
Ivar Lazzaro
2014-10-09 15:09:55 -07:00
parent 78e9f006b9
commit ec59849259
4 changed files with 91 additions and 5 deletions

View File

@@ -390,10 +390,22 @@ class GroupPolicyDbPlugin(gpolicy.GroupPolicyPluginBase,
if not child_id_list:
contract_db.child_contracts = []
return
if contract_db['parent_id']:
# Only one hierarchy level allowed for now
raise gpolicy.ThreeLevelContractHierarchyNotSupported(
contract_id=contract_db['id'])
with context.session.begin(subtransactions=True):
# We will first check if the new list of contracts is valid
contracts_in_db = self._validate_contract_list(
context, child_id_list)
for child in contracts_in_db:
if (child['child_contracts'] or
child['id'] == contract_db['id']):
# Only one level contract relationship supported for now
# No loops allowed
raise gpolicy.BadContractRelationship(
parent_id=contract_db['id'], child_id=child['id'])
# New list of child contracts is valid so we will first reset the
# existing list and then add each contract.
# Note that the list could be empty in which case we interpret
@@ -538,12 +550,18 @@ class GroupPolicyDbPlugin(gpolicy.GroupPolicyPluginBase,
else:
res['parent_id'] = None
ctx = context.get_admin_context()
with ctx.session.begin(subtransactions=True):
filters = {'parent_id': [ct['id']]}
child_contracts_in_db = self._get_collection_query(ctx, Contract,
filters=filters)
if 'child_contracts' in ct:
# They have been updated
res['child_contracts'] = [child_ct['id']
for child_ct in child_contracts_in_db]
for child_ct in ct['child_contracts']]
else:
with ctx.session.begin(subtransactions=True):
filters = {'parent_id': [ct['id']]}
child_contracts_in_db = self._get_collection_query(
ctx, Contract, filters=filters)
res['child_contracts'] = [child_ct['id']
for child_ct in
child_contracts_in_db]
res['policy_rules'] = [pr['policy_rule_id']
for pr in ct['policy_rules']]

View File

@@ -77,6 +77,18 @@ class ContractNotFound(nexc.NotFound):
message = _("Contract %(contract_id)s could not be found")
class BadContractRelationship(nexc.BadRequest):
message = _("Contract %(parent_id)s is an invalid parent for "
"%(child_id)s, make sure that child contract has no "
"children, or that you are not creating a relationship loop")
class ThreeLevelContractHierarchyNotSupported(nexc.BadRequest):
message = _("Can't add children to contract %(contract_id)s "
"which already has a parent. Only one level of contract "
"hierarchy supported.")
class GroupPolicyInvalidPortValue(nexc.InvalidInput):
message = _("Invalid value for port %(port)s")

View File

@@ -390,12 +390,21 @@ class ResourceMappingDriver(api.PolicyDriver):
@log.log
def update_contract_postcommit(self, context):
# Update contract rules
old_rules = set(context.original['policy_rules'])
new_rules = set(context.current['policy_rules'])
to_add = new_rules - old_rules
to_remove = old_rules - new_rules
self._remove_contract_rules(context, context.current, to_remove)
self._apply_contract_rules(context, context.current, to_add)
# Update children contraint
to_recompute = (set(context.original['child_contracts']) ^
set(context.current['child_contracts']))
self._recompute_contracts(context, to_recompute)
if to_add or to_remove:
to_recompute = (set(context.original['child_contracts']) &
set(context.current['child_contracts']))
self._recompute_contracts(context, to_recompute)
@log.log
def delete_contract_precommit(self, context):
@@ -957,12 +966,33 @@ class ResourceMappingDriver(api.PolicyDriver):
'0.0.0.0/0', unset=unset)
def _apply_contract_rules(self, context, contract, policy_rules):
if contract['parent_id']:
parent = context._plugin.get_contract(
context._plugin_context, contract['parent_id'])
policy_rules = policy_rules & set(parent['policy_rules'])
# Don't add rules unallowed by the parent
self._manage_contract_rules(context, contract, policy_rules)
def _remove_contract_rules(self, context, contract, policy_rules):
self._manage_contract_rules(context, contract, policy_rules,
unset=True)
def _recompute_contracts(self, context, children):
# Rules in child but not in parent shall be removed
# Child rules will be set after being filtered by the parent
for child in children:
child = context._plugin.get_contract(
context._plugin_context, child)
child_rules = set(child['policy_rules'])
if child['parent_id']:
parent = context._plugin.get_contract(
context._plugin_context, child['parent_id'])
parent_rules = set(parent['policy_rules'])
self._remove_contract_rules(context, child,
child_rules - parent_rules)
# Old parent may have filtered some rules, need to add them again
self._apply_contract_rules(context, child, child_rules)
def _ensure_default_security_group(self, plugin_context, tenant_id):
filters = {'name': ['gbp_default'], 'tenant_id': [tenant_id]}
default_group = self._core_plugin.get_security_groups(

View File

@@ -943,3 +943,29 @@ class TestGroupResources(GroupPolicyDbTestCase):
self.assertEqual(res.status_int, webob.exc.HTTPNoContent.code)
self.assertRaises(gpolicy.ContractNotFound,
self.plugin.get_contract, ctx, ct_id)
def test_contract_one_hierarchy_children(self):
child = self.create_contract()['contract']
parent = self.create_contract(
child_contracts = [child['id']])['contract']
self.create_contract(
child_contracts = [parent['id']],
expected_res_status=webob.exc.HTTPBadRequest.code)
def test_contract_one_hierarchy_parent(self):
child = self.create_contract()['contract']
# parent
self.create_contract(
child_contracts = [child['id']])['contract']
nephew = self.create_contract()['contract']
data = {'contract': {'child_contracts': [nephew['id']]}}
req = self.new_update_request('contracts', data, child['id'])
res = req.get_response(self.ext_api)
self.assertEqual(res.status_int, webob.exc.HTTPBadRequest.code)
def test_contract_parent_no_loop(self):
ct = self.create_contract()['contract']
data = {'contract': {'child_contracts': [ct['id']]}}
req = self.new_update_request('contracts', data, ct['id'])
res = req.get_response(self.ext_api)
self.assertEqual(res.status_int, webob.exc.HTTPBadRequest.code)