# Copyright 2012, Nachi Ueno, NTT MCL, Inc. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. import netaddr from neutron_lib import constants as const from neutron_lib.utils import helpers from oslo_log import log as logging from neutron._i18n import _ from neutron.db import api as db_api from neutron.db.models import allowed_address_pair as aap_models from neutron.db.models import securitygroup as sg_models from neutron.db import models_v2 from neutron.db import securitygroups_db as sg_db from neutron.extensions import securitygroup as ext_sg LOG = logging.getLogger(__name__) DIRECTION_IP_PREFIX = {'ingress': 'source_ip_prefix', 'egress': 'dest_ip_prefix'} DHCP_RULE_PORT = {4: (67, 68, const.IPv4), 6: (547, 546, const.IPv6)} class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin): """Mixin class to add agent-based security group implementation.""" def get_port_from_device(self, context, device): """Get port dict from device name on an agent. Subclass must provide this method or get_ports_from_devices. :param device: device name which identifies a port on the agent side. What is specified in "device" depends on a plugin agent implementation. For example, it is a port ID in OVS agent and netdev name in Linux Bridge agent. :return: port dict returned by DB plugin get_port(). In addition, it must contain the following fields in the port dict returned. - device - security_groups - security_group_rules, - security_group_source_groups - fixed_ips """ raise NotImplementedError(_("%s must implement get_port_from_device " "or get_ports_from_devices.") % self.__class__.__name__) def get_ports_from_devices(self, context, devices): """Bulk method of get_port_from_device. Subclasses may override this to provide better performance for DB queries, backend calls, etc. """ return [self.get_port_from_device(context, device) for device in devices] def create_security_group_rule(self, context, security_group_rule): rule = super(SecurityGroupServerRpcMixin, self).create_security_group_rule(context, security_group_rule) sgids = [rule['security_group_id']] self.notifier.security_groups_rule_updated(context, sgids) return rule def create_security_group_rule_bulk(self, context, security_group_rules): rules = super(SecurityGroupServerRpcMixin, self).create_security_group_rule_bulk_native( context, security_group_rules) sgids = set([r['security_group_id'] for r in rules]) self.notifier.security_groups_rule_updated(context, list(sgids)) return rules def delete_security_group_rule(self, context, sgrid): rule = self.get_security_group_rule(context, sgrid) super(SecurityGroupServerRpcMixin, self).delete_security_group_rule(context, sgrid) self.notifier.security_groups_rule_updated(context, [rule['security_group_id']]) def check_and_notify_security_group_member_changed( self, context, original_port, updated_port): sg_change = not helpers.compare_elements( original_port.get(ext_sg.SECURITYGROUPS), updated_port.get(ext_sg.SECURITYGROUPS)) if sg_change: self.notify_security_groups_member_updated_bulk( context, [original_port, updated_port]) elif original_port['fixed_ips'] != updated_port['fixed_ips']: self.notify_security_groups_member_updated(context, updated_port) def is_security_group_member_updated(self, context, original_port, updated_port): """Check security group member updated or not. This method returns a flag which indicates request notification is required and does not perform notification itself. It is because another changes for the port may require notification. """ need_notify = False if (original_port['fixed_ips'] != updated_port['fixed_ips'] or original_port['mac_address'] != updated_port['mac_address'] or not helpers.compare_elements( original_port.get(ext_sg.SECURITYGROUPS), updated_port.get(ext_sg.SECURITYGROUPS))): need_notify = True return need_notify def notify_security_groups_member_updated_bulk(self, context, ports): """Notify update event of security group members for ports. The agent setups the iptables rule to allow ingress packet from the dhcp server (as a part of provider rules), so we need to notify an update of dhcp server ip address to the plugin agent. security_groups_provider_updated() just notifies that an event occurs and the plugin agent fetches the update provider rule in the other RPC call (security_group_rules_for_devices). """ sg_provider_updated_networks = set() sec_groups = set() for port in ports: if port['device_owner'] == const.DEVICE_OWNER_DHCP: sg_provider_updated_networks.add( port['network_id']) # For IPv6, provider rule need to be updated in case router # interface is created or updated after VM port is created. # NOTE (Swami): ROUTER_INTERFACE_OWNERS check is required # since it includes the legacy router interface device owners # and DVR router interface device owners. elif port['device_owner'] in const.ROUTER_INTERFACE_OWNERS: if any(netaddr.IPAddress(fixed_ip['ip_address']).version == 6 for fixed_ip in port['fixed_ips']): sg_provider_updated_networks.add( port['network_id']) else: sec_groups |= set(port.get(ext_sg.SECURITYGROUPS)) if sg_provider_updated_networks: ports_query = context.session.query(models_v2.Port.id).filter( models_v2.Port.network_id.in_( sg_provider_updated_networks)).all() ports_to_update = [p.id for p in ports_query] self.notifier.security_groups_provider_updated( context, ports_to_update) if sec_groups: self.notifier.security_groups_member_updated( context, list(sec_groups)) def notify_security_groups_member_updated(self, context, port): self.notify_security_groups_member_updated_bulk(context, [port]) @db_api.retry_if_session_inactive() def security_group_info_for_ports(self, context, ports): sg_info = {'devices': ports, 'security_groups': {}, 'sg_member_ips': {}} rules_in_db = self._select_rules_for_ports(context, ports) remote_security_group_info = {} for (port_id, rule_in_db) in rules_in_db: remote_gid = rule_in_db.get('remote_group_id') security_group_id = rule_in_db.get('security_group_id') ethertype = rule_in_db['ethertype'] if ('security_group_source_groups' not in sg_info['devices'][port_id]): sg_info['devices'][port_id][ 'security_group_source_groups'] = [] if remote_gid: if (remote_gid not in sg_info['devices'][port_id][ 'security_group_source_groups']): sg_info['devices'][port_id][ 'security_group_source_groups'].append(remote_gid) if remote_gid not in remote_security_group_info: remote_security_group_info[remote_gid] = {} if ethertype not in remote_security_group_info[remote_gid]: # this set will be serialized into a list by rpc code remote_security_group_info[remote_gid][ethertype] = set() direction = rule_in_db['direction'] rule_dict = { 'direction': direction, 'ethertype': ethertype} for key in ('protocol', 'port_range_min', 'port_range_max', 'remote_ip_prefix', 'remote_group_id'): if rule_in_db.get(key) is not None: if key == 'remote_ip_prefix': direction_ip_prefix = DIRECTION_IP_PREFIX[direction] rule_dict[direction_ip_prefix] = rule_in_db[key] continue rule_dict[key] = rule_in_db[key] if security_group_id not in sg_info['security_groups']: sg_info['security_groups'][security_group_id] = [] if rule_dict not in sg_info['security_groups'][security_group_id]: sg_info['security_groups'][security_group_id].append( rule_dict) # Update the security groups info if they don't have any rules sg_ids = self._select_sg_ids_for_ports(context, ports) for (sg_id, ) in sg_ids: if sg_id not in sg_info['security_groups']: sg_info['security_groups'][sg_id] = [] sg_info['sg_member_ips'] = remote_security_group_info # the provider rules do not belong to any security group, so these # rules still reside in sg_info['devices'] [port_id] self._apply_provider_rule(context, sg_info['devices']) return self._get_security_group_member_ips(context, sg_info) def _get_security_group_member_ips(self, context, sg_info): ips = self._select_ips_for_remote_group( context, sg_info['sg_member_ips'].keys()) for sg_id, member_ips in ips.items(): for ip in member_ips: ethertype = 'IPv%d' % netaddr.IPNetwork(ip).version if ethertype in sg_info['sg_member_ips'][sg_id]: sg_info['sg_member_ips'][sg_id][ethertype].add(ip) return sg_info def _select_sg_ids_for_ports(self, context, ports): if not ports: return [] sg_binding_port = sg_models.SecurityGroupPortBinding.port_id sg_binding_sgid = sg_models.SecurityGroupPortBinding.security_group_id query = context.session.query(sg_binding_sgid) query = query.filter(sg_binding_port.in_(ports.keys())) return query.all() def _select_rules_for_ports(self, context, ports): if not ports: return [] sg_binding_port = sg_models.SecurityGroupPortBinding.port_id sg_binding_sgid = sg_models.SecurityGroupPortBinding.security_group_id sgr_sgid = sg_models.SecurityGroupRule.security_group_id query = context.session.query(sg_binding_port, sg_models.SecurityGroupRule) query = query.join(sg_models.SecurityGroupRule, sgr_sgid == sg_binding_sgid) query = query.filter(sg_binding_port.in_(ports.keys())) return query.all() def _select_ips_for_remote_group(self, context, remote_group_ids): ips_by_group = {} if not remote_group_ids: return ips_by_group for remote_group_id in remote_group_ids: ips_by_group[remote_group_id] = set() ip_port = models_v2.IPAllocation.port_id sg_binding_port = sg_models.SecurityGroupPortBinding.port_id sg_binding_sgid = sg_models.SecurityGroupPortBinding.security_group_id # Join the security group binding table directly to the IP allocation # table instead of via the Port table skip an unnecessary intermediary query = context.session.query(sg_binding_sgid, models_v2.IPAllocation.ip_address, aap_models.AllowedAddressPair.ip_address) query = query.join(models_v2.IPAllocation, ip_port == sg_binding_port) # Outerjoin because address pairs may be null and we still want the # IP for the port. query = query.outerjoin( aap_models.AllowedAddressPair, sg_binding_port == aap_models.AllowedAddressPair.port_id) query = query.filter(sg_binding_sgid.in_(remote_group_ids)) # Each allowed address pair IP record for a port beyond the 1st # will have a duplicate regular IP in the query response since # the relationship is 1-to-many. Dedup with a set for security_group_id, ip_address, allowed_addr_ip in query: ips_by_group[security_group_id].add(ip_address) if allowed_addr_ip: ips_by_group[security_group_id].add(allowed_addr_ip) return ips_by_group def _select_remote_group_ids(self, ports): remote_group_ids = [] for port in ports.values(): for rule in port.get('security_group_rules'): remote_group_id = rule.get('remote_group_id') if remote_group_id: remote_group_ids.append(remote_group_id) return remote_group_ids def _convert_remote_group_id_to_ip_prefix(self, context, ports): remote_group_ids = self._select_remote_group_ids(ports) ips = self._select_ips_for_remote_group(context, remote_group_ids) for port in ports.values(): updated_rule = [] for rule in port.get('security_group_rules'): remote_group_id = rule.get('remote_group_id') direction = rule.get('direction') direction_ip_prefix = DIRECTION_IP_PREFIX[direction] if not remote_group_id: updated_rule.append(rule) continue port['security_group_source_groups'].append(remote_group_id) base_rule = rule for ip in ips[remote_group_id]: if ip in port.get('fixed_ips', []): continue ip_rule = base_rule.copy() version = netaddr.IPNetwork(ip).version ethertype = 'IPv%s' % version if base_rule['ethertype'] != ethertype: continue ip_rule[direction_ip_prefix] = str( netaddr.IPNetwork(ip).cidr) updated_rule.append(ip_rule) port['security_group_rules'] = updated_rule return ports def _add_ingress_dhcp_rule(self, port): for ip_version in (4, 6): # only allow DHCP servers to talk to the appropriate IP address # to avoid getting leases that don't match the Neutron IPs prefix = '32' if ip_version == 4 else '128' dests = ['%s/%s' % (ip, prefix) for ip in port['fixed_ips'] if netaddr.IPNetwork(ip).version == ip_version] if ip_version == 4: # v4 dhcp servers can also talk to broadcast dests.append('255.255.255.255/32') elif ip_version == 6: # v6 dhcp responses can target link-local addresses dests.append('fe80::/64') source_port, dest_port, ethertype = DHCP_RULE_PORT[ip_version] for dest in dests: dhcp_rule = {'direction': 'ingress', 'ethertype': ethertype, 'protocol': 'udp', 'port_range_min': dest_port, 'port_range_max': dest_port, 'source_port_range_min': source_port, 'source_port_range_max': source_port, 'dest_ip_prefix': dest} port['security_group_rules'].append(dhcp_rule) def _add_ingress_ra_rule(self, port): has_v6 = [ip for ip in port['fixed_ips'] if netaddr.IPNetwork(ip).version == 6] if not has_v6: return ra_rule = {'direction': 'ingress', 'ethertype': const.IPv6, 'protocol': const.PROTO_NAME_IPV6_ICMP, 'source_port_range_min': const.ICMPV6_TYPE_RA} port['security_group_rules'].append(ra_rule) def _apply_provider_rule(self, context, ports): for port in ports.values(): self._add_ingress_ra_rule(port) self._add_ingress_dhcp_rule(port) @db_api.retry_if_session_inactive() def security_group_rules_for_ports(self, context, ports): rules_in_db = self._select_rules_for_ports(context, ports) for (port_id, rule_in_db) in rules_in_db: port = ports[port_id] direction = rule_in_db['direction'] rule_dict = { 'security_group_id': rule_in_db['security_group_id'], 'direction': direction, 'ethertype': rule_in_db['ethertype'], } for key in ('protocol', 'port_range_min', 'port_range_max', 'remote_ip_prefix', 'remote_group_id'): if rule_in_db.get(key) is not None: if key == 'remote_ip_prefix': direction_ip_prefix = DIRECTION_IP_PREFIX[direction] rule_dict[direction_ip_prefix] = rule_in_db[key] continue rule_dict[key] = rule_in_db[key] port['security_group_rules'].append(rule_dict) self._apply_provider_rule(context, ports) return self._convert_remote_group_id_to_ip_prefix(context, ports)