diff --git a/neutron/db/securitygroups_db.py b/neutron/db/securitygroups_db.py index 2cce69c5a69..17a30bb17ba 100644 --- a/neutron/db/securitygroups_db.py +++ b/neutron/db/securitygroups_db.py @@ -44,6 +44,7 @@ from neutron.extensions import security_groups_default_rules as \ from neutron.extensions import securitygroup as ext_sg from neutron.objects import base as base_obj from neutron.objects import ports as port_obj +from neutron.objects import rbac_db as rbac_db_obj from neutron.objects import securitygroup as sg_obj from neutron.objects import securitygroup_default_rules as sg_default_rules_obj from neutron import quota @@ -131,8 +132,8 @@ class SecurityGroupDbMixin( # be used here otherwise, SG will not be found and error 500 will # be returned through the API get_context = context.elevated() if default_sg else context - sg = sg_obj.SecurityGroup.get_object(get_context, id=sg.id) - secgroup_dict = self._make_security_group_dict(sg) + sg = self._get_security_group(get_context, sg.id) + secgroup_dict = self._make_security_group_dict(context, sg) self._registry_publish(resources.SECURITY_GROUP, events.PRECOMMIT_CREATE, exc_cls=ext_sg.SecurityGroupConflict, @@ -174,9 +175,10 @@ class SecurityGroupDbMixin( sg_objs = sg_obj.SecurityGroup.get_objects( context, _pager=pager, validate_filters=False, - fields=fields, **filters) + fields=fields, return_db_obj=True, **filters) - return [self._make_security_group_dict(obj, fields) for obj in sg_objs] + return [self._make_security_group_dict(context, obj, fields) + for obj in sg_objs] @db_api.retry_if_session_inactive() def get_security_groups_count(self, context, filters=None): @@ -195,8 +197,8 @@ class SecurityGroupDbMixin( try: with db_api.CONTEXT_READER.using(context): - ret = self._make_security_group_dict(self._get_security_group( - context, id, fields=fields), fields) + sg = self._get_security_group(context, id, fields=fields) + ret = self._make_security_group_dict(context, sg, fields) if (fields is None or len(fields) == 0 or 'security_group_rules' in fields): rules = self.get_security_group_rules( @@ -209,12 +211,21 @@ class SecurityGroupDbMixin( context.tenant_id = tmp_context_tenant_id return ret - def _get_security_group(self, context, id, fields=None): - sg = sg_obj.SecurityGroup.get_object(context, fields=fields, id=id) + @staticmethod + def _get_security_group(context, _id, fields=None): + sg = sg_obj.SecurityGroup.get_object(context, fields=fields, id=_id) if sg is None: - raise ext_sg.SecurityGroupNotFound(id=id) + raise ext_sg.SecurityGroupNotFound(id=_id) return sg + @staticmethod + def _get_security_group_db(context, _id, fields=None): + sg_db = sg_obj.SecurityGroup.get_object( + context, fields=fields, id=_id, return_db_obj=True) + if sg_db is None: + raise ext_sg.SecurityGroupNotFound(id=_id) + return sg_db + def _check_security_group(self, context, id, tenant_id=None): if tenant_id: tmp_context_tenant_id = context.tenant_id @@ -258,7 +269,7 @@ class SecurityGroupDbMixin( # consistency with deleted rules sg = self._get_security_group(context, id) sgr_ids = [r['id'] for r in sg.rules] - sec_group = self._make_security_group_dict(sg) + sec_group = self._make_security_group_dict(context, sg) self._registry_publish(resources.SECURITY_GROUP, events.PRECOMMIT_DELETE, exc_cls=ext_sg.SecurityGroupInUse, @@ -282,8 +293,8 @@ class SecurityGroupDbMixin( if 'stateful' in s: with db_api.CONTEXT_READER.using(context): - sg = self._get_security_group(context, id) - if s['stateful'] != sg['stateful']: + sg_db = self._get_security_group_db(context, id) + if s['stateful'] != sg_db['stateful']: filters = {'security_group_id': [id]} ports = self._get_port_security_group_bindings(context, filters) @@ -299,11 +310,11 @@ class SecurityGroupDbMixin( sg = self._get_security_group(context, id) if sg.name == 'default' and 'name' in s: raise ext_sg.SecurityGroupCannotUpdateDefault() - sg_dict = self._make_security_group_dict(sg) + sg_dict = self._make_security_group_dict(context, sg) original_security_group = sg_dict sg.update_fields(s) sg.update() - sg_dict = self._make_security_group_dict(sg) + sg_dict = self._make_security_group_dict(context, sg) self._registry_publish( resources.SECURITY_GROUP, events.PRECOMMIT_UPDATE, @@ -320,24 +331,33 @@ class SecurityGroupDbMixin( return sg_dict - def _make_security_group_dict(self, security_group, fields=None): + def _make_security_group_dict(self, context, security_group, fields=None): + """Return the security group in a dictionary + + :param context: Neutron API request context. + :param security_group: DB object or OVO of the security group. + :param fields: list of fields to filter the returned dictionary. + :return: a dictionary with the security group definition. + """ + rules = security_group.rules or [] + if isinstance(security_group, sg_obj.SecurityGroup): + shared = security_group.shared + security_group = security_group.db_obj + else: + rbac_entries = security_group['rbac_entries'] + shared = rbac_db_obj.RbacNeutronDbObjectMixin.is_network_shared( + context, rbac_entries) res = {'id': security_group['id'], 'name': security_group['name'], 'stateful': security_group['stateful'], 'tenant_id': security_group['tenant_id'], 'description': security_group['description'], - 'standard_attr_id': security_group.db_obj.standard_attr_id, - 'shared': security_group['shared'], + 'standard_attr_id': security_group.standard_attr_id, + 'shared': shared, + 'security_group_rules': [self._make_security_group_rule_dict(r) + for r in rules], } - if security_group.rules: - res['security_group_rules'] = [ - self._make_security_group_rule_dict(r) - for r in security_group.rules - ] - else: - res['security_group_rules'] = [] - resource_extend.apply_funcs(ext_sg.SECURITYGROUPS, res, - security_group.db_obj) + resource_extend.apply_funcs(ext_sg.SECURITYGROUPS, res, security_group) return db_utils.resource_fields(res, fields) @staticmethod diff --git a/neutron/objects/base.py b/neutron/objects/base.py index c2c9601ef6e..acde5c1b04b 100644 --- a/neutron/objects/base.py +++ b/neutron/objects/base.py @@ -608,7 +608,7 @@ class NeutronDbObject(NeutronObject, metaclass=DeclarativeObject): return db_api.CONTEXT_READER.using(context) @classmethod - def get_object(cls, context, fields=None, **kwargs): + def get_object(cls, context, fields=None, return_db_obj=False, **kwargs): """Fetch a single object Return the first result of given context or None if the result doesn't @@ -620,6 +620,8 @@ class NeutronDbObject(NeutronObject, metaclass=DeclarativeObject): avoid loading synthetic fields when possible, and does not affect db queries. Default is None, which is the same as []. Example: ['id', 'name'] + :param return_db_obj: return the DB model object instead of loading + the OVO; that could save some time. :param kwargs: multiple keys defined by key=value pairs :return: single object of NeutronDbObject class or None """ @@ -633,6 +635,8 @@ class NeutronDbObject(NeutronObject, metaclass=DeclarativeObject): with cls.db_context_reader(context): db_obj = obj_db_api.get_object( cls, context, **cls.modify_fields_to_db(kwargs)) + if return_db_obj: + return db_obj if db_obj: return cls._load_object(context, db_obj, fields=fields)