From 3d88a48fbc77b0ff3d3d01877c5be0b8e01132b1 Mon Sep 17 00:00:00 2001 From: Kevin George Date: Wed, 12 Jun 2013 23:00:04 -0500 Subject: [PATCH] Security groups/rules implemented. Some serious refactoring. And some stylistic fixes. Implemented PUT on security groups. Proper quota management. Now with plugin_views. --- quark/db/api.py | 6 + quark/db/models.py | 45 ++++--- quark/drivers/base.py | 24 +++- quark/drivers/nvp_driver.py | 169 +++++++++++++++++++++-- quark/drivers/optimized_nvp_driver.py | 69 +++++++++- quark/plugin.py | 184 ++++++++++++++++++++++++-- quark/plugin_views.py | 9 +- quark/tests/test_nvp_driver.py | 2 +- quark/tests/test_quark_plugin.py | 5 +- 9 files changed, 460 insertions(+), 53 deletions(-) diff --git a/quark/db/api.py b/quark/db/api.py index fb5fdf2..1c42fa9 100644 --- a/quark/db/api.py +++ b/quark/db/api.py @@ -439,6 +439,12 @@ def security_group_create(context, **sec_group_dict): return new_group +def security_group_update(context, group, **kwargs): + group.update(kwargs) + context.session.add(group) + return group + + def security_group_delete(context, group): context.session.delete(group) diff --git a/quark/db/models.py b/quark/db/models.py index b742845..c5c6b9e 100644 --- a/quark/db/models.py +++ b/quark/db/models.py @@ -236,13 +236,13 @@ port_ip_association_table = sa.Table( sa.ForeignKey("quark_ip_addresses.id"))) -port_rule_association_table = sa.Table( - "quark_port_security_rule_associations", +port_group_association_table = sa.Table( + "quark_port_security_group_associations", BASEV2.metadata, sa.Column("port_id", sa.String(36), sa.ForeignKey("quark_ports.id")), - sa.Column("rule_id", sa.String(36), - sa.ForeignKey("quark_security_group_rule.id"))) + sa.Column("group_id", sa.String(36), + sa.ForeignKey("quark_security_groups.id"))) class SecurityGroupRule(BASEV2, models.HasId, models.HasTenant): @@ -255,7 +255,20 @@ class SecurityGroupRule(BASEV2, models.HasId, models.HasTenant): ethertype = sa.Column(sa.String(4), nullable=False) port_range_max = sa.Column(sa.Integer(), nullable=True) port_range_min = sa.Column(sa.Integer(), nullable=True) - protocol = sa.Column(sa.String(32), nullable=True) + protocol = sa.Column(sa.Integer(), nullable=True) + remote_ip_prefix = sa.Column(sa.String(22), nullable=True) + remote_group_id = sa.Column(sa.String(36), nullable=True) + + +class SecurityGroup(BASEV2, models.HasId, models.HasTenant): + __tablename__ = "quark_security_groups" + id = sa.Column(sa.String(36), primary_key=True) + name = sa.Column(sa.String(255), nullable=False) + description = sa.Column(sa.String(255), nullable=False) + join = "SecurityGroupRule.group_id==SecurityGroup.id" + rules = orm.relationship(SecurityGroupRule, backref='group', + cascade='delete', + primaryjoin=join) class Port(BASEV2, models.HasTenant, models.HasId): @@ -282,26 +295,16 @@ class Port(BASEV2, models.HasTenant, models.HasId): backref="ports") @declarative.declared_attr - def security_rules(cls): - primaryjoin = cls.id == port_rule_association_table.c.port_id - secondaryjoin = (port_rule_association_table.c.rule_id == - SecurityGroupRule.id) - return orm.relationship(SecurityGroupRule, primaryjoin=primaryjoin, + def security_groups(cls): + primaryjoin = cls.id == port_group_association_table.c.port_id + secondaryjoin = (port_group_association_table.c.group_id == + SecurityGroup.id) + return orm.relationship(SecurityGroup, primaryjoin=primaryjoin, secondaryjoin=secondaryjoin, - secondary=port_rule_association_table, + secondary=port_group_association_table, backref="ports") -class SecurityGroup(BASEV2, models.HasId, models.HasTenant): - __tablename__ = "quark_security_groups" - id = sa.Column(sa.String(36), primary_key=True) - name = sa.Column(sa.String(255), nullable=False) - description = sa.Column(sa.String(255), nullable=False) - join = "SecurityGroupRule.group_id==SecurityGroup.id" - rules = orm.relationship(SecurityGroupRule, backref='group', - cascade='delete', primaryjoin=join) - - class MacAddress(BASEV2, models.HasTenant): __tablename__ = "quark_mac_addresses" address = sa.Column(sa.BigInteger(), primary_key=True) diff --git a/quark/drivers/base.py b/quark/drivers/base.py index 4f761c5..d5048ed 100644 --- a/quark/drivers/base.py +++ b/quark/drivers/base.py @@ -37,10 +37,30 @@ class BaseDriver(object): def delete_network(self, context, network_id): LOG.info("delete_network %s" % network_id) - def create_port(self, context, network_id, port_id, status=True): + def create_port(self, context, network_id, port_id, **kwargs): LOG.info("create_port %s %s %s" % (context.tenant_id, network_id, port_id)) return {"uuid": port_id} - def delete_port(self, context, port_id, lswitch_uuid=None): + def update_port(self, context, port_id, **kwargs): + LOG.info("update_port %s %s" % (context.tenant_id, port_id)) + return {"uuid": port_id} + + def delete_port(self, context, port_id, **kwargs): LOG.info("delete_port %s %s" % (context.tenant_id, port_id)) + + def create_security_group(self, context, group_name, **group): + LOG.info("Creating security profile %s for tenant %s" % + (group_name, context.tenant_id)) + + def delete_security_group(self, context, group_id, **kwargs): + LOG.info("Deleting security profile %s for tenant %s" % + (group_id, context.tenant_id)) + + def create_security_group_rule(self, context, group_id, rule): + LOG.info("Creating security rule on group %s for tenant %s" % + (group_id, context.tenant_id)) + + def delete_security_group_rule(self, context, group_id, rule): + LOG.info("Deleting security rule on group %s for tenant %s" % + (group_id, context.tenant_id)) diff --git a/quark/drivers/nvp_driver.py b/quark/drivers/nvp_driver.py index efe55ed..0c59ada 100644 --- a/quark/drivers/nvp_driver.py +++ b/quark/drivers/nvp_driver.py @@ -20,6 +20,7 @@ NVP client driver for Quark from oslo.config import cfg import aiclib +from quantum.extensions import securitygroup as sg_ext from quantum.openstack.common import log as logging from quark.drivers import base @@ -39,6 +40,12 @@ nvp_opts = [ cfg.MultiStrOpt('controller_connection', default=[], help=_('NVP Controller connection string')), + cfg.IntOpt('max_rules_per_group', + default=30, + help=_('Maxiumum size of NVP SecurityRule list per group')), + cfg.IntOpt('max_rules_per_port', + default=30, + help=_('Maximum rules per NVP lport across all groups')), ] physical_net_type_map = { @@ -105,12 +112,17 @@ class NVPDriver(base.BaseDriver): LOG.debug("Deleting lswitch %s" % switch["uuid"]) connection.lswitch(switch["uuid"]).delete() - def create_port(self, context, network_id, port_id, status=True): + def create_port(self, context, network_id, port_id, + status=True, security_groups=[], allowed_pairs=[]): tenant_id = context.tenant_id lswitch = self._create_or_choose_lswitch(context, network_id) connection = self.get_connection() port = connection.lswitch_port(lswitch) port.admin_status_enabled(status) + port.allowed_address_pairs(allowed_pairs) + nvp_group_ids = self._get_security_groups_for_port(context, + security_groups) + port.security_profiles(nvp_group_ids) tags = [dict(tag=network_id, scope="quantum_net_id"), dict(tag=port_id, scope="quantum_port_id"), dict(tag=tenant_id, scope="os_tid")] @@ -120,17 +132,26 @@ class NVPDriver(base.BaseDriver): res["lswitch"] = lswitch return res - def delete_port(self, context, port_id, lswitch_uuid=None): + def update_port(self, context, port_id, status=True, + security_groups=[], allowed_pairs=[]): connection = self.get_connection() + lswitch_id = self._lswitch_from_port(context, port_id) + port = connection.lswitch_port(lswitch_id, port_id) + nvp_group_ids = self._get_security_groups_for_port(context, + security_groups) + self._count_security_rules_on_port(context, nvp_group_ids) + if nvp_group_ids: + port.security_profiles(nvp_group_ids) + if allowed_pairs: + port.allowed_address_pairs(allowed_pairs) + port.admin_status_enabled(status) + return port.update() + + def delete_port(self, context, port_id, **kwargs): + connection = self.get_connection() + lswitch_uuid = kwargs.get('lswitch_uuid', None) if not lswitch_uuid: - query = connection.lswitch_port("*").query() - query.relations("LogicalSwitchConfig") - query.uuid(port_id) - port = query.results() - if port["result_count"] > 1: - raise Exception("More than one lswitch for port %s" % port_id) - for r in port["results"]: - lswitch_uuid = r["_relations"]["LogicalSwitchConfig"]["uuid"] + lswitch_uuid = self._lswitch_from_port(context, port_id) LOG.debug("Deleting port %s from lswitch %s" % (port_id, lswitch_uuid)) connection.lswitch_port(lswitch_uuid, port_id).delete() @@ -149,6 +170,76 @@ class NVPDriver(base.BaseDriver): phys_type=phys_type, segment_id=segment_id) return {} + def create_security_group(self, context, group_name, **group): + tenant_id = context.tenant_id + connection = self.get_connection() + group_id = group.get('group_id') + profile = connection.securityprofile() + if group_name: + profile.display_name(group_name) + profile.port_egress_rules(group.get('port_egress_rules', [])) + profile.port_ingress_rules(group.get('port_ingress_rules', [])) + tags = [dict(tag=group_id, scope="quantum_group_id"), + dict(tag=tenant_id, scope="os_tid")] + LOG.debug("Creating security profile %s" % group_name) + profile.tags(tags) + return profile.create() + + def delete_security_group(self, context, group_id): + guuid = self._get_security_group_id(context, group_id) + connection = self.get_connection() + LOG.debug("Deleting security profile %s" % group_id) + connection.securityprofile(guuid).delete() + + def update_security_group(self, context, group_id, **group): + prof_id = self._get_security_group_id(context, group_id) + connection = self.get_connection() + profile = connection.securityprofile(prof_id) + + ingress_rules = group.get('port_ingress_rules', []) + egress_rules = group.get('port_egress_rules', []) + if (len(egress_rules) + len(ingress_rules) > + CONF.NVP.max_rules_per_group): + raise sg_ext.qexception.InvalidInput( + error_message="Max rules per group for %s" % group_id) + + if group.get('name', None): + profile.display_name(group['name']) + if ingress_rules: + profile.port_ingress_rules(ingress_rules) + if egress_rules: + profile.port_egress_rules(egress_rules) + return profile.update() + + def _update_security_group_rules(self, context, group_id, rule, operation, + check, raises): + groupd = self._get_security_group(context, group_id) + direction, secrule = self._get_security_group_rule_object(context, + rule) + rulelist = groupd['logical_port_%s_rules' % direction] + if not check(secrule, rulelist): + if raises: + raise raises + else: + getattr(rulelist, operation)(secrule) + + LOG.debug("Adding rule on security group %s" % groupd['uuid']) + group = {'port_%s_rules' % direction: rulelist} + return self.update_security_group(context, group_id, **group) + + def create_security_group_rule(self, context, group_id, rule): + return self._update_security_group_rules( + context, group_id, rule, 'append', + lambda x, y: x not in y, + sg_ext.SecurityGroupRuleExists(id=group_id)) + + def delete_security_group_rule(self, context, group_id, rule): + return self._update_security_group_rules( + context, group_id, rule, 'remove', + lambda x, y: x in y, + sg_ext.SecurityGroupRuleNotFound(id="with group_id %s" % + group_id)) + def _create_or_choose_lswitch(self, context, network_id): switches = self._lswitch_status_query(context, network_id) switch = self._lswitch_select_open(context, network_id=network_id, @@ -257,3 +348,61 @@ class NVPDriver(base.BaseDriver): query.tagscopes(['os_tid', 'quantum_net_id']) query.tags([context.tenant_id, network_id]) return query + + def _lswitch_from_port(self, context, port_id): + connection = self.get_connection() + query = connection.lswitch_port("*").query() + query.relations("LogicalSwitchConfig") + query.uuid(port_id) + port = query.results() + if port['result_count'] > 1: + raise Exception("Could not identify lswitch for port %s" % port_id) + if port['result_count'] < 1: + raise Exception("No lswitch found for port %s" % port_id) + return port['results'][0]["_relations"]["LogicalSwitchConfig"]["uuid"] + + def _get_security_group(self, context, group_id): + connection = self.get_connection() + query = connection.securityprofile().query() + query.tagscopes(['os_tid', 'quantum_group_id']) + query.tags([context.tenant_id, group_id]) + query = query.results() + return query['results'][0] + + def _get_security_group_id(self, context, group_id): + return self._get_security_group(context, group_id)['uuid'] + + def _get_security_group_rule_object(self, context, rule): + ethertype = rule.get('ethertype', None) + rule_clone = {} + + ip_prefix = rule.get('remote_ip_prefix', None) + if ip_prefix: + rule_clone['ip_prefix'] = ip_prefix + profile_uuid = rule.get('remote_group_id', None) + if profile_uuid: + rule_clone['profile_uuid'] = profile_uuid + for key in ['protocol', 'port_range_min', 'port_range_max']: + if rule.get(key): + rule_clone[key] = rule[key] + + connection = self.get_connection() + secrule = connection.securityrule(ethertype, **rule_clone) + + direction = rule.get('direction', '') + if direction not in ['ingress', 'egress']: + raise AttributeError( + "Direction not specified as 'ingress' or 'egress'.") + return (direction, secrule) + + def _get_security_groups_for_port(self, context, groups): + rulecount = 0 + nvp_group_ids = [] + for group in groups: + nvp_group = self._get_security_group(context, group) + rulecount += (len(nvp_group['logical_port_ingress_rules']) + + len(nvp_group['logical_port_egress_rules'])) + nvp_group_ids.append(nvp_group['uuid']) + if rulecount > CONF.NVP.max_rules_per_port: + raise sg_ext.exceptions.OverQuota(over='security rules per port') + return nvp_group_ids diff --git a/quark/drivers/optimized_nvp_driver.py b/quark/drivers/optimized_nvp_driver.py index 10e028c..7f97b08 100644 --- a/quark/drivers/optimized_nvp_driver.py +++ b/quark/drivers/optimized_nvp_driver.py @@ -32,10 +32,13 @@ class OptimizedNVPDriver(NVPDriver): for switch in lswitches: self._lswitch_delete(context, switch.nvp_id) - def create_port(self, context, network_id, port_id, status=True): + def create_port(self, context, network_id, port_id, + status=True, security_groups=[], allowed_pairs=[]): nvp_port = super(OptimizedNVPDriver, self).\ create_port(context, network_id, - port_id, status) + port_id, status=status, + security_groups=security_groups, + allowed_pairs=allowed_pairs) switch_nvp_id = nvp_port["lswitch"] # slightly inefficient for the sake of brevity. Lets the @@ -51,6 +54,17 @@ class OptimizedNVPDriver(NVPDriver): switch.port_count = switch.port_count + 1 return nvp_port + def update_port(self, context, port_id, + status=True, security_groups=[], allowed_pairs=[]): + nvp_port = super(OptimizedNVPDriver, self).\ + update_port(context, port_id, status=status, + security_groups=security_groups, + allowed_pairs=allowed_pairs) + port = context.session.query(LSwitchPort).\ + filter(LSwitchPort.port_id == port_id).\ + first() + port.update(nvp_port) + def delete_port(self, context, port_id, lswitch_uuid=None): port = self._lport_select_by_id(context, port_id) switch = port.switch @@ -61,6 +75,21 @@ class OptimizedNVPDriver(NVPDriver): if switch.port_count == 0: self._lswitch_delete(context, switch.nvp_id) + def create_security_group(self, context, group_name, **group): + nvp_group = super(OptimizedNVPDriver, self).create_security_group( + context, group_name, **group) + group_id = group.get('group_id') + profile = SecurityProfile(id=group_id, nvp_id=nvp_group['uuid']) + context.session.add(profile) + + def delete_security_group(self, context, group_id): + super(OptimizedNVPDriver, self).\ + delete_security_group(context, group_id) + group = context.session.query(SecurityProfile).\ + filter(SecurityProfile.id == group_id).\ + first() + context.session.delete(group) + def _lport_select_by_id(self, context, port_id): port = context.session.query(LSwitchPort).\ filter(LSwitchPort.port_id == port_id).\ @@ -144,6 +173,37 @@ class OptimizedNVPDriver(NVPDriver): all() return switches + def _lswitch_from_port(self, context, port_id): + port = self._lport_select_by_id(context, port_id) + return port.switch.nvp_id + + def _get_security_group_id(self, context, group_id): + return context.session.query(SecurityProfile).\ + filter(SecurityProfile.id == group_id).first().nvp_id + + def _make_security_rule_dict(self, rule): + res = {"port_range_min": rule.get("port_range_min"), + "port_range_max": rule.get("port_range_max"), + "protocol": rule.get("protocol"), + "ip_prefix": rule.get("remote_ip_prefix"), + "group_id": rule.get("remote_group_id"), + "ethertype": rule.get("ethertype")} + for key, value in res.items(): + if value is None: + res.pop(key) + return res + + def _get_security_group(self, context, group_id): + group = context.session.query(models.SecurityGroup).\ + filter(models.SecurityGroup.id == group_id).first() + rulelist = {'ingress': [], 'egress': []} + for rule in group.rules: + rulelist[rule.direction].append( + self._make_security_rule_dict(rule)) + return {'uuid': self._get_security_group_id(context, group_id), + 'logical_port_ingress_rules': rulelist['ingress'], + 'logical_port_egress_rules': rulelist['egress']} + class LSwitchPort(models.BASEV2, models.HasId): __tablename__ = "quark_nvp_driver_lswitchport" @@ -170,3 +230,8 @@ class QOS(models.BASEV2, models.HasId): display_name = sa.Column(sa.String(255), nullable=False) max_bandwidth_rate = sa.Column(sa.Integer(), nullable=False) min_bandwidth_rate = sa.Column(sa.Integer(), nullable=False) + + +class SecurityProfile(models.BASEV2, models.HasId): + __tablename__ = "quark_nvp_driver_security_profile" + nvp_id = sa.Column(sa.String(36), nullable=False) diff --git a/quark/plugin.py b/quark/plugin.py index 8af900b..abbe483 100644 --- a/quark/plugin.py +++ b/quark/plugin.py @@ -64,9 +64,10 @@ quark_opts = [ help=_("Path to the config for the net driver")) ] - STRATEGY = network_strategy.STRATEGY CONF.register_opts(quark_opts, "QUARK") +CONF.set_default('quota_security_group', 5, "QUOTAS") +CONF.set_default('quota_security_group_rule', 20, "QUOTAS") def _pop_param(attrs, param, default=None): @@ -124,6 +125,50 @@ class Plugin(quantum_plugin_base_v2.QuantumPluginBaseV2, self.ipam_reuse_after = CONF.QUARK.ipam_reuse_after models.BASEV2.metadata.create_all(quantum_db_api._ENGINE) + def _make_security_group_list(self, context, group_ids): + if not group_ids or group_ids is attributes.ATTR_NOT_SPECIFIED: + return ([], []) + group_ids = list(set(group_ids)) + security_groups = [] + for id in group_ids: + group = db_api.security_group_find(context, id=id, + scope=db_api.ONE) + if not group: + raise sg_ext.SecurityGroupNotFound(id=id) + security_groups.append(group) + return (group_ids, security_groups) + + def _validate_security_group_rule(self, context, rule): + + if (rule.get('remote_ip_prefix', None) and + rule.get('remote_group_id', None)): + raise sg_ext.SecurityGroupRemoteGroupAndRemoteIpPrefix() + + protocol = rule.get('protocol', None) + if protocol is not None: + if (protocol in [6, 17] and + (type(rule.get('port_range_min', None)) != + type(rule.get('port_range_max', None)))): + raise exceptions.InvalidInput( + error_message="For TCP/UDP rules, cannot wildcard only " + "one end of port range.") + try: + protonumber = int(rule['protocol']) + if protonumber < 0 or protonumber > 255: + raise sg_ext.SecurityGroupRuleInvalidProtocol( + protocol=protocol, + values=['udp', 'tcp', 'icmp']) + except (ValueError, TypeError): + raise sg_ext.SecurityGroupRuleInvalidProtocol( + protocol=protocol, values=['udp', 'tcp', 'icmp']) + else: + rule.pop('protocol', None) + if (rule.get('port_range_min', None) is not None or + rule.get('port_range_max', None)) is not None: + raise sg_ext.SecurityGroupProtocolRequiredWithPorts() + + return rule + def _validate_subnet_cidr(self, context, network_id, new_subnet_cidr): """Validate the CIDR for a subnet. @@ -407,6 +452,11 @@ class Plugin(quantum_plugin_base_v2.QuantumPluginBaseV2, s = db_api.subnet_create(context, **sub["subnet"]) new_subnets.append(s) new_net["subnets"] = new_subnets + + if not self.get_security_groups( + context, + filters={"id": '00000000-0000-0000-0000-000000000000'}): + self._create_default_security_group(context) return v._make_network_dict(new_net) def update_network(self, context, id, network): @@ -553,16 +603,26 @@ class Plugin(quantum_plugin_base_v2.QuantumPluginBaseV2, addresses.append(self.ipam_driver.allocate_ip_address( context, net["id"], port_id, self.ipam_reuse_after)) + (group_ids, security_groups) = self._make_security_group_list( + context, port["port"].pop("security_groups", None)) mac = self.ipam_driver.allocate_mac_address(context, net["id"], port_id, self.ipam_reuse_after, mac_address=mac_address) + mac_address_string = str(netaddr.EUI(mac['address'], + dialect=netaddr.mac_unix)) + address_pairs = [{'mac_address': mac_address_string, + 'ip_address': address.get('address_readable') or ''} + for address in addresses] backend_port = self.net_driver.create_port(context, net["id"], - port_id=port_id) + port_id=port_id, + security_groups=group_ids, + allowed_pairs=address_pairs) port_attrs["network_id"] = net["id"] port_attrs["id"] = port_id + port_attrs["security_groups"] = security_groups new_port = db_api.port_create( context, addresses=addresses, mac_address=mac["address"], backend_key=backend_port["uuid"], **port_attrs) @@ -583,6 +643,7 @@ class Plugin(quantum_plugin_base_v2.QuantumPluginBaseV2, if not port_db: raise exceptions.PortNotFound(port_id=id) + address_pairs = [] fixed_ips = port["port"].pop("fixed_ips", None) if fixed_ips: self.ipam_driver.deallocate_ip_address( @@ -601,7 +662,21 @@ class Plugin(quantum_plugin_base_v2.QuantumPluginBaseV2, context, port_db["network_id"], id, self.ipam_reuse_after, ip_address=ip_address)) port["port"]["addresses"] = addresses + mac_address_string = str(netaddr.EUI(port_db.mac_address, + dialect=netaddr.mac_unix)) + address_pairs = [{'mac_address': mac_address_string, + 'ip_address': + address.get('address_readable') or ''} + for address in addresses] + (group_ids, security_groups) = self._make_security_group_list( + context, port["port"].pop("security_groups", None)) + self.net_driver.update_port(context, + port_id=port_db.backend_key, + security_groups=group_ids, + allowed_pairs=address_pairs) + + port["port"]["security_groups"] = security_groups port = db_api.port_update(context, port_db, **port["port"]) @@ -996,20 +1071,79 @@ class Plugin(quantum_plugin_base_v2.QuantumPluginBaseV2, def create_security_group(self, context, security_group): LOG.info("create_security_group for tenant %s" % (context.tenant_id)) - g = security_group["security_group"] - group = db_api.security_group_create(context, **g) - return v._make_security_group_dict(group) + if (db_api.security_group_find(context).count() >= + CONF.QUOTAS.quota_security_group): + raise exceptions.OverQuota(overs="security groups") + group = security_group["security_group"] + group_name = group.get('name', '') + if group_name == "default": + raise sg_ext.SecurityGroupDefaultAlreadyExists() + group_id = uuidutils.generate_uuid() + + self.net_driver.create_security_group( + context, + group_name, + group_id=group_id, + **group) + + group["id"] = group_id + group["name"] = group_name + group["tenant_id"] = context.tenant_id + dbgroup = db_api.security_group_create(context, **group) + return v._make_security_group_dict(dbgroup) + + def _create_default_security_group(self, context): + default_group = { + 'name': 'default', 'description': '', + 'group_id': '00000000-0000-0000-0000-000000000000', + 'port_egress_rules': [], + 'port_ingress_rules': [ + {'ethertype': 'IPv4', 'protocol': 6, + 'port_range_min': 0, 'port_range_max': 65535}, + {'ethertype': 'IPv4', 'protocol': 17, + 'port_range_min': 0, 'port_range_max': 65535}, + {'ethertype': 'IPv6', 'protocol': 6, + 'port_range_min': 0, 'port_range_max': 65535}, + {'ethertype': 'IPv6', 'protocol': 17, + 'port_range_min': 0, 'port_range_max': 65535}]} + + self.net_driver.create_security_group( + context, + "default", + **default_group) + + default_group["id"] = '00000000-0000-0000-0000-000000000000' + default_group["tenant_id"] = context.tenant_id + for rule in default_group.pop('port_ingress_rules'): + db_api.security_group_rule_create( + context, security_group_id= + "00000000-0000-0000-0000-000000000000", + tenant_id=context.tenant_id, direction='ingress', + **rule) + db_api.security_group_create(context, **default_group) def create_security_group_rule(self, context, security_group_rule): LOG.info("create_security_group for tenant %s" % (context.tenant_id)) - r = security_group_rule["security_group_rule"] - group_id = r["security_group_id"] + if (db_api.security_group_rule_find(context).count() >= + CONF.QUOTAS.quota_security_group_rule): + raise exceptions.OverQuota(overs="security group rules") + rule = self._validate_security_group_rule( + context, security_group_rule["security_group_rule"]) + rule['id'] = uuidutils.generate_uuid() + + group_id = rule["security_group_id"] group = db_api.security_group_find(context, id=group_id) if not group: raise sg_ext.SecurityGroupNotFound(group_id=group_id) - rule = db_api.security_group_rule_create(context, **r) - return v._make_security_group_rule_dict(rule) + + self.net_driver.create_security_group_rule( + context, + group_id, + rule) + + return v._make_security_group_rule_dict( + db_api.security_group_rule_create(context, **rule)) def delete_security_group(self, context, id): LOG.info("delete_security_group %s for tenant %s" % @@ -1017,6 +1151,11 @@ class Plugin(quantum_plugin_base_v2.QuantumPluginBaseV2, group = db_api.security_group_find(context, id=id, scope=db_api.ONE) if not group: raise sg_ext.SecurityGroupNotFound(group_id=id) + if group.name == 'default': + raise sg_ext.SecurityGroupCannotRemoveDefault() + if group.ports: + raise sg_ext.SecurityGroupInUse(id=id) + self.net_driver.delete_security_group(context, group['id']) db_api.security_group_delete(context, group) def delete_security_group_rule(self, context, id): @@ -1026,6 +1165,18 @@ class Plugin(quantum_plugin_base_v2.QuantumPluginBaseV2, scope=db_api.ONE) if not rule: raise sg_ext.SecurityGroupRuleNotFound(group_id=id) + + group = db_api.security_group_find(context, id=rule['group_id'], + scope=db_api.ONE) + if not group: + raise sg_ext.SecurityGroupNotFound(id=id) + + self.net_driver.delete_security_group_rule( + context, + group.id, + v._make_security_group_rule_dict(rule)) + + rule['id'] = id db_api.security_group_rule_delete(context, rule) def get_security_group(self, context, id, fields=None): @@ -1050,7 +1201,7 @@ class Plugin(quantum_plugin_base_v2.QuantumPluginBaseV2, page_reverse=False): LOG.info("get_security_groups for tenant %s" % (context.tenant_id)) - groups = db_api.security_group_find(context, filters=filters) + groups = db_api.security_group_find(context, **filters) return [v._make_security_group_dict(group) for group in groups] def get_security_group_rules(self, context, filters=None, fields=None, @@ -1062,4 +1213,15 @@ class Plugin(quantum_plugin_base_v2.QuantumPluginBaseV2, return [v._make_security_group_rule_dict(rule) for rule in rules] def update_security_group(self, context, id, security_group): - raise NotImplementedError() + newgroup = security_group['security_group'] + group = db_api.security_group_find(context, id=id, scope=db_api.ONE) + self.net_driver.update_security_group( + context, + id, + **newgroup) + + dbgroup = db_api.security_group_update( + context, + group, + **newgroup) + return v._make_security_group_dict(dbgroup) diff --git a/quark/plugin_views.py b/quark/plugin_views.py index 7c979a9..a3ff8e5 100644 --- a/quark/plugin_views.py +++ b/quark/plugin_views.py @@ -72,20 +72,21 @@ def _make_security_group_dict(security_group, fields=None): "description": security_group.get("description"), "name": security_group.get("name"), "tenant_id": security_group.get("tenant_id")} - res["security_group_rules"] =\ - [_make_security_group_rule_dict(r) for r in security_group['rules']] + res["security_group_rules"] = [ + r.id for r in security_group['rules']] return res def _make_security_group_rule_dict(security_rule, fields=None): res = {"id": security_rule.get("id"), + "ethertype": security_rule.get("ethertype"), "direction": security_rule.get("direction"), "tenant_id": security_rule.get("tenant_id"), "port_range_max": security_rule.get("port_range_max"), - "port_range_mid": security_rule.get("port_range_mid"), + "port_range_min": security_rule.get("port_range_min"), "protocol": security_rule.get("protocol"), "remote_ip_prefix": security_rule.get("remote_ip_prefix"), - "security_group_id": security_rule.get("security_group_id"), + "security_group_id": security_rule.get("group_id"), "remote_group_id": security_rule.get("remote_group_id")} return res diff --git a/quark/tests/test_nvp_driver.py b/quark/tests/test_nvp_driver.py index 9cad2a0..f37e3b8 100644 --- a/quark/tests/test_nvp_driver.py +++ b/quark/tests/test_nvp_driver.py @@ -459,7 +459,7 @@ class TestNVPDriverDeletePort(TestNVPDriver): def test_delete_port_switch_given(self): with self._stubs() as (connection): self.driver.delete_port(self.context, self.port_id, - self.lswitch_uuid) + lswitch_uuid=self.lswitch_uuid) self.assertFalse(connection.lswitch_port().query.called) self.assertTrue(connection.lswitch_port().delete.called) diff --git a/quark/tests/test_quark_plugin.py b/quark/tests/test_quark_plugin.py index 1644738..99037b1 100644 --- a/quark/tests/test_quark_plugin.py +++ b/quark/tests/test_quark_plugin.py @@ -1250,7 +1250,8 @@ class TestQuarkUpdatePort(TestQuarkPlugin): port_update.assert_called_once_with( self.context, port_find(), - name="ourport") + name="ourport", + security_groups=[]) def test_update_port_fixed_ip_bad_request(self): with self._stubs( @@ -1264,7 +1265,7 @@ class TestQuarkUpdatePort(TestQuarkPlugin): def test_update_port_fixed_ip(self): with self._stubs( - port=dict(id=1, name="myport") + port=dict(id=1, name="myport", mac_address="0:0:0:0:0:1") ) as (port_find, port_update, alloc_ip, dealloc_ip): new_port = dict(port=dict( fixed_ips=[dict(subnet_id=1,