diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index de1d6d035056..0c57150d95c5 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -358,7 +358,7 @@ def _sync_fixed_ips(context, project_id, user_id): def _sync_security_groups(context, project_id, user_id): return dict(security_groups=_security_group_count_by_project_and_user( - context, project_id, user_id, context.session)) + context, project_id, user_id)) def _sync_server_groups(context, project_id, user_id): @@ -1617,11 +1617,7 @@ def instance_create(context, values): values - dict containing column values. """ - # NOTE(rpodolyaka): create the default security group, if it doesn't exist. - # This must be done in a separate transaction, so that this one is not - # aborted in case a concurrent one succeeds first and the unique constraint - # for security group names is violated by a concurrent INSERT - security_group_ensure_default(context, context.session) + security_group_ensure_default(context) values = values.copy() values['metadata'] = _metadata_refs( @@ -1650,15 +1646,14 @@ def instance_create(context, values): def _get_sec_group_models(security_groups): models = [] - default_group = _security_group_ensure_default(context, - context.session) + default_group = _security_group_ensure_default(context) if 'default' in security_groups: models.append(default_group) # Generate a new list, so we don't modify the original security_groups = [x for x in security_groups if x != 'default'] if security_groups: models.extend(_security_group_get_by_names(context, - context.session, context.project_id, security_groups)) + context.project_id, security_groups)) return models if 'hostname' in values: @@ -4036,14 +4031,18 @@ def block_device_mapping_destroy_by_instance_and_device(context, instance_uuid, ################### -def _security_group_create(context, values, session=None): + +@require_context +@main_context_manager.writer +def security_group_create(context, values): security_group_ref = models.SecurityGroup() # FIXME(devcamcar): Unless I do this, rules fails with lazy load exception # once save() is called. This will get cleaned up in next orm pass. security_group_ref.rules security_group_ref.update(values) try: - security_group_ref.save(session=session) + with main_context_manager.writer.savepoint.using(context): + security_group_ref.save(context.session) except db_exc.DBDuplicateEntry: raise exception.SecurityGroupExists( project_id=values['project_id'], @@ -4051,21 +4050,21 @@ def _security_group_create(context, values, session=None): return security_group_ref -def _security_group_get_query(context, session=None, read_deleted=None, +def _security_group_get_query(context, read_deleted=None, project_only=False, join_rules=True): - query = model_query(context, models.SecurityGroup, session=session, + query = model_query(context, models.SecurityGroup, read_deleted=read_deleted, project_only=project_only) if join_rules: query = query.options(joinedload_all('rules.grantee_group')) return query -def _security_group_get_by_names(context, session, project_id, group_names): +def _security_group_get_by_names(context, project_id, group_names): """Get security group models for a project by a list of names. Raise SecurityGroupNotFoundForProject for a name not found. """ - query = _security_group_get_query(context, session=session, - read_deleted="no", join_rules=False).\ + query = _security_group_get_query(context, read_deleted="no", + join_rules=False).\ filter_by(project_id=project_id).\ filter(models.SecurityGroup.name.in_(group_names)) sg_models = query.all() @@ -4081,11 +4080,13 @@ def _security_group_get_by_names(context, session, project_id, group_names): @require_context +@main_context_manager.reader def security_group_get_all(context): return _security_group_get_query(context).all() @require_context +@main_context_manager.reader def security_group_get(context, security_group_id, columns_to_join=None): query = _security_group_get_query(context, project_only=True).\ filter_by(id=security_group_id) @@ -4105,6 +4106,7 @@ def security_group_get(context, security_group_id, columns_to_join=None): @require_context +@main_context_manager.reader def security_group_get_by_name(context, project_id, group_name, columns_to_join=None): query = _security_group_get_query(context, @@ -4127,6 +4129,7 @@ def security_group_get_by_name(context, project_id, group_name, @require_context +@main_context_manager.reader def security_group_get_by_project(context, project_id): return _security_group_get_query(context, read_deleted="no").\ filter_by(project_id=project_id).\ @@ -4134,6 +4137,7 @@ def security_group_get_by_project(context, project_id): @require_context +@main_context_manager.reader def security_group_get_by_instance(context, instance_uuid): return _security_group_get_query(context, read_deleted="no").\ join(models.SecurityGroup.instances).\ @@ -4142,64 +4146,63 @@ def security_group_get_by_instance(context, instance_uuid): @require_context +@main_context_manager.reader def security_group_in_use(context, group_id): - session = get_session() - with session.begin(): - # Are there any instances that haven't been deleted - # that include this group? - inst_assoc = model_query(context, - models.SecurityGroupInstanceAssociation, - read_deleted="no", session=session).\ - filter_by(security_group_id=group_id).\ - all() - for ia in inst_assoc: - num_instances = model_query(context, models.Instance, - session=session, read_deleted="no").\ - filter_by(uuid=ia.instance_uuid).\ - count() - if num_instances: - return True + # Are there any instances that haven't been deleted + # that include this group? + inst_assoc = model_query(context, + models.SecurityGroupInstanceAssociation, + read_deleted="no").\ + filter_by(security_group_id=group_id).\ + all() + for ia in inst_assoc: + num_instances = model_query(context, models.Instance, + read_deleted="no").\ + filter_by(uuid=ia.instance_uuid).\ + count() + if num_instances: + return True return False @require_context -def security_group_create(context, values): - return _security_group_create(context, values) - - -@require_context +@main_context_manager.writer def security_group_update(context, security_group_id, values, columns_to_join=None): - session = get_session() - with session.begin(): - query = model_query(context, models.SecurityGroup, - session=session).filter_by(id=security_group_id) - if columns_to_join: - for column in columns_to_join: - query = query.options(joinedload_all(column)) - security_group_ref = query.first() + query = model_query(context, models.SecurityGroup).filter_by( + id=security_group_id) + if columns_to_join: + for column in columns_to_join: + query = query.options(joinedload_all(column)) + security_group_ref = query.first() - if not security_group_ref: - raise exception.SecurityGroupNotFound( - security_group_id=security_group_id) - security_group_ref.update(values) - name = security_group_ref['name'] - project_id = security_group_ref['project_id'] - try: - security_group_ref.save(session=session) - except db_exc.DBDuplicateEntry: - raise exception.SecurityGroupExists( - project_id=project_id, - security_group_name=name) + if not security_group_ref: + raise exception.SecurityGroupNotFound( + security_group_id=security_group_id) + security_group_ref.update(values) + name = security_group_ref['name'] + project_id = security_group_ref['project_id'] + try: + security_group_ref.save(context.session) + except db_exc.DBDuplicateEntry: + raise exception.SecurityGroupExists( + project_id=project_id, + security_group_name=name) return security_group_ref -def security_group_ensure_default(context, session=None): +def security_group_ensure_default(context): """Ensure default security group exists for a project_id.""" try: - return _security_group_ensure_default(context, session=session) + # NOTE(rpodolyaka): create the default security group, if it doesn't + # exist. This must be done in a separate transaction, so that + # this one is not aborted in case a concurrent one succeeds first + # and the unique constraint for security group names is violated + # by a concurrent INSERT + with main_context_manager.writer.independent.using(context): + return _security_group_ensure_default(context) except exception.SecurityGroupExists: # NOTE(rpodolyaka): a concurrent transaction has succeeded first, # suppress the error and proceed @@ -4207,83 +4210,67 @@ def security_group_ensure_default(context, session=None): 'default') -def _security_group_ensure_default(context, session=None): - if session is None: - session = get_session() +@main_context_manager.writer +def _security_group_ensure_default(context): + try: + default_group = _security_group_get_by_names(context, + context.project_id, + ['default'])[0] + except exception.NotFound: + values = {'name': 'default', + 'description': 'default', + 'user_id': context.user_id, + 'project_id': context.project_id} + default_group = security_group_create(context, values) + usage = model_query(context, models.QuotaUsage, read_deleted="no").\ + filter_by(project_id=context.project_id).\ + filter_by(user_id=context.user_id).\ + filter_by(resource='security_groups') + # Create quota usage for auto created default security group + if not usage.first(): + _quota_usage_create(context.project_id, + context.user_id, + 'security_groups', + 1, 0, + CONF.until_refresh, + context.session) + else: + usage.update({'in_use': int(usage.first().in_use) + 1}) - with session.begin(subtransactions=True): - try: - default_group = _security_group_get_by_names(context, - session, - context.project_id, - ['default'])[0] - except exception.NotFound: - values = {'name': 'default', - 'description': 'default', - 'user_id': context.user_id, - 'project_id': context.project_id} - default_group = _security_group_create(context, values, - session=session) - usage = model_query(context, models.QuotaUsage, - read_deleted="no", session=session).\ - filter_by(project_id=context.project_id).\ - filter_by(user_id=context.user_id).\ - filter_by(resource='security_groups') - # Create quota usage for auto created default security group - if not usage.first(): - _quota_usage_create(context.project_id, - context.user_id, - 'security_groups', - 1, 0, - CONF.until_refresh, - session) - else: - usage.update({'in_use': int(usage.first().in_use) + 1}) - - default_rules = _security_group_rule_get_default_query(context, - session=session).all() - for default_rule in default_rules: - # This is suboptimal, it should be programmatic to know - # the values of the default_rule - rule_values = {'protocol': default_rule.protocol, - 'from_port': default_rule.from_port, - 'to_port': default_rule.to_port, - 'cidr': default_rule.cidr, - 'parent_group_id': default_group.id, - } - _security_group_rule_create(context, - rule_values, - session=session) - return default_group + default_rules = _security_group_rule_get_default_query(context).all() + for default_rule in default_rules: + # This is suboptimal, it should be programmatic to know + # the values of the default_rule + rule_values = {'protocol': default_rule.protocol, + 'from_port': default_rule.from_port, + 'to_port': default_rule.to_port, + 'cidr': default_rule.cidr, + 'parent_group_id': default_group.id, + } + _security_group_rule_create(context, rule_values) + return default_group @require_context +@main_context_manager.writer def security_group_destroy(context, security_group_id): - session = get_session() - with session.begin(): - model_query(context, models.SecurityGroup, - session=session).\ - filter_by(id=security_group_id).\ - soft_delete() - model_query(context, models.SecurityGroupInstanceAssociation, - session=session).\ - filter_by(security_group_id=security_group_id).\ - soft_delete() - model_query(context, models.SecurityGroupIngressRule, - session=session).\ - filter_by(group_id=security_group_id).\ - soft_delete() - model_query(context, models.SecurityGroupIngressRule, - session=session).\ - filter_by(parent_group_id=security_group_id).\ - soft_delete() + model_query(context, models.SecurityGroup).\ + filter_by(id=security_group_id).\ + soft_delete() + model_query(context, models.SecurityGroupInstanceAssociation).\ + filter_by(security_group_id=security_group_id).\ + soft_delete() + model_query(context, models.SecurityGroupIngressRule).\ + filter_by(group_id=security_group_id).\ + soft_delete() + model_query(context, models.SecurityGroupIngressRule).\ + filter_by(parent_group_id=security_group_id).\ + soft_delete() -def _security_group_count_by_project_and_user(context, project_id, user_id, - session=None): +def _security_group_count_by_project_and_user(context, project_id, user_id): nova.context.authorize_project_context(context, project_id) - return model_query(context, models.SecurityGroup, read_deleted="no", - session=session).\ + return model_query(context, models.SecurityGroup, read_deleted="no").\ filter_by(project_id=project_id).\ filter_by(user_id=user_id).\ count() @@ -4292,19 +4279,19 @@ def _security_group_count_by_project_and_user(context, project_id, user_id, ################### -def _security_group_rule_create(context, values, session=None): +def _security_group_rule_create(context, values): security_group_rule_ref = models.SecurityGroupIngressRule() security_group_rule_ref.update(values) - security_group_rule_ref.save(session=session) + security_group_rule_ref.save(context.session) return security_group_rule_ref -def _security_group_rule_get_query(context, session=None): - return model_query(context, models.SecurityGroupIngressRule, - session=session) +def _security_group_rule_get_query(context): + return model_query(context, models.SecurityGroupIngressRule) @require_context +@main_context_manager.reader def security_group_rule_get(context, security_group_rule_id): result = (_security_group_rule_get_query(context). filter_by(id=security_group_rule_id). @@ -4318,6 +4305,7 @@ def security_group_rule_get(context, security_group_rule_id): @require_context +@main_context_manager.reader def security_group_rule_get_by_security_group(context, security_group_id, columns_to_join=None): if columns_to_join is None: @@ -4331,6 +4319,7 @@ def security_group_rule_get_by_security_group(context, security_group_id, @require_context +@main_context_manager.reader def security_group_rule_get_by_instance(context, instance_uuid): return (_security_group_rule_get_query(context). join('parent_group', 'instances'). @@ -4340,11 +4329,13 @@ def security_group_rule_get_by_instance(context, instance_uuid): @require_context +@main_context_manager.writer def security_group_rule_create(context, values): return _security_group_rule_create(context, values) @require_context +@main_context_manager.writer def security_group_rule_destroy(context, security_group_rule_id): count = (_security_group_rule_get_query(context). filter_by(id=security_group_rule_id). @@ -4355,22 +4346,23 @@ def security_group_rule_destroy(context, security_group_rule_id): @require_context +@main_context_manager.reader def security_group_rule_count_by_group(context, security_group_id): return (model_query(context, models.SecurityGroupIngressRule, read_deleted="no"). filter_by(parent_group_id=security_group_id). count()) -# + ################### -def _security_group_rule_get_default_query(context, session=None): - return model_query(context, models.SecurityGroupIngressDefaultRule, - session=session) +def _security_group_rule_get_default_query(context): + return model_query(context, models.SecurityGroupIngressDefaultRule) @require_context +@main_context_manager.reader def security_group_default_rule_get(context, security_group_rule_default_id): result = _security_group_rule_get_default_query(context).\ filter_by(id=security_group_rule_default_id).\ @@ -4383,30 +4375,29 @@ def security_group_default_rule_get(context, security_group_rule_default_id): return result +@main_context_manager.writer def security_group_default_rule_destroy(context, security_group_rule_default_id): - session = get_session() - with session.begin(): - count = _security_group_rule_get_default_query(context, - session=session).\ - filter_by(id=security_group_rule_default_id).\ - soft_delete() - if count == 0: - raise exception.SecurityGroupDefaultRuleNotFound( - rule_id=security_group_rule_default_id) + count = _security_group_rule_get_default_query(context).\ + filter_by(id=security_group_rule_default_id).\ + soft_delete() + if count == 0: + raise exception.SecurityGroupDefaultRuleNotFound( + rule_id=security_group_rule_default_id) +@main_context_manager.writer def security_group_default_rule_create(context, values): security_group_default_rule_ref = models.SecurityGroupIngressDefaultRule() security_group_default_rule_ref.update(values) - security_group_default_rule_ref.save() + security_group_default_rule_ref.save(context.session) return security_group_default_rule_ref @require_context +@main_context_manager.reader def security_group_default_rule_list(context): - return _security_group_rule_get_default_query(context).\ - all() + return _security_group_rule_get_default_query(context).all() ################### diff --git a/nova/tests/unit/compute/test_compute.py b/nova/tests/unit/compute/test_compute.py index 0cbb724253af..53965c0a591d 100644 --- a/nova/tests/unit/compute/test_compute.py +++ b/nova/tests/unit/compute/test_compute.py @@ -7844,7 +7844,8 @@ class ComputeAPITestCase(BaseTestCase): security_group=['testgroup']) db.instance_destroy(self.context, ref[0]['uuid']) - group = db.security_group_get(self.context, group['id']) + group = db.security_group_get(self.context, group['id'], + columns_to_join=['instances']) self.assertEqual(0, len(group['instances'])) def test_destroy_security_group_disassociates_instances(self): @@ -7860,7 +7861,8 @@ class ComputeAPITestCase(BaseTestCase): db.security_group_destroy(self.context, group['id']) admin_deleted_context = context.get_admin_context( read_deleted="only") - group = db.security_group_get(admin_deleted_context, group['id']) + group = db.security_group_get(admin_deleted_context, group['id'], + columns_to_join=['instances']) self.assertEqual(0, len(group['instances'])) def _test_rebuild(self, vm_state):