diff --git a/quark/drivers/nvp_driver.py b/quark/drivers/nvp_driver.py index 86be1c6..2af82d1 100644 --- a/quark/drivers/nvp_driver.py +++ b/quark/drivers/nvp_driver.py @@ -17,6 +17,8 @@ NVP client driver for Quark """ +import contextlib + import aiclib from neutron.extensions import securitygroup as sg_ext from neutron.openstack.common import log as logging @@ -89,6 +91,7 @@ class NVPDriver(base.BaseDriver): # NOTE(mdietz): What does default_tz actually mean? # We don't have one default. # NOTE(jkoelker): Transport Zone + # NOTE(mdietz): :-/ tz isn't the issue. default is default_tz = CONF.NVP.default_tz LOG.info("Loading NVP settings " + str(default_tz)) connections = CONF.NVP.controller_connection @@ -99,6 +102,7 @@ class NVPDriver(base.BaseDriver): 'max_rules_per_group': CONF.NVP.max_rules_per_group, 'max_rules_per_port': CONF.NVP.max_rules_per_port}) LOG.info("Loading NVP settings " + str(connections)) + for conn in connections: (ip, port, user, pw, req_timeout, http_timeout, retries, redirects) = conn.split(":") @@ -114,7 +118,11 @@ class NVPDriver(base.BaseDriver): default_tz=default_tz, backoff=backoff)) - def get_connection(self): + def _connection(self): + if len(self.nvp_connections) == 0: + raise exceptions.NoBackendConnectionsDefined( + msg="No NVP connections defined cannot continue") + conn = self.nvp_connections[self.conn_index] if "connection" not in conn: scheme = conn["port"] == "443" and "https" or "http" @@ -132,6 +140,22 @@ class NVPDriver(base.BaseDriver): backoff=backoff) return conn["connection"] + def _next_connection(self): + # TODO(anyone): Do we want to drop and create new connections at some + # point? What about recycling them after a certain + # number of usages or time, proactively? + conn_len = len(self.nvp_connections) + if conn_len: + self.conn_index = (self.conn_index + 1) % conn_len + + @contextlib.contextmanager + def get_connection(self): + try: + yield self._connection() + except Exception: + self._next_connection() + raise + def create_network(self, context, network_name, tags=None, network_id=None, **kwargs): return self._lswitch_create(context, network_name, tags, @@ -187,67 +211,67 @@ class NVPDriver(base.BaseDriver): security_groups = security_groups or [] tenant_id = context.tenant_id lswitch = self._create_or_choose_lswitch(context, network_id) - connection = self.get_connection() - port = connection.lswitch_port(lswitch) - port.admin_status_enabled(status) - nvp_group_ids = self._get_security_groups_for_port(context, - security_groups) - port.security_profiles(nvp_group_ids) - tags = [dict(tag=network_id, scope="neutron_net_id"), - dict(tag=port_id, scope="neutron_port_id"), - dict(tag=tenant_id, scope="os_tid"), - dict(tag=device_id, scope="vm_id")] - LOG.debug("Creating port on switch %s" % lswitch) - port.tags(tags) - res = port.create() - try: - """Catching odd NVP returns here will make it safe to assume that - NVP returned something correct.""" - res["lswitch"] = lswitch - except TypeError: - LOG.exception("Unexpected return from NVP: %s" % res) - raise - port = connection.lswitch_port(lswitch) - port.uuid = res["uuid"] - port.attachment_vif(port_id) - return res + with self.get_connection() as connection: + port = connection.lswitch_port(lswitch) + port.admin_status_enabled(status) + nvp_group_ids = self._get_security_groups_for_port(context, + security_groups) + port.security_profiles(nvp_group_ids) + tags = [dict(tag=network_id, scope="neutron_net_id"), + dict(tag=port_id, scope="neutron_port_id"), + dict(tag=tenant_id, scope="os_tid"), + dict(tag=device_id, scope="vm_id")] + LOG.debug("Creating port on switch %s" % lswitch) + port.tags(tags) + res = port.create() + try: + """Catching odd NVP returns here will make it safe to assume that + NVP returned something correct.""" + res["lswitch"] = lswitch + except TypeError: + LOG.exception("Unexpected return from NVP: %s" % res) + raise + port = connection.lswitch_port(lswitch) + port.uuid = res["uuid"] + port.attachment_vif(port_id) + return res def update_port(self, context, port_id, status=True, security_groups=None, **kwargs): security_groups = security_groups or [] - connection = self.get_connection() - lswitch_id = self._lswitch_from_port(context, port_id) - port = connection.lswitch_port(lswitch_id, port_id) - nvp_group_ids = self._get_security_groups_for_port(context, - security_groups) - if nvp_group_ids: - port.security_profiles(nvp_group_ids) - port.admin_status_enabled(status) - return port.update() + with self.get_connection() as connection: + lswitch_id = self._lswitch_from_port(context, port_id) + port = connection.lswitch_port(lswitch_id, port_id) + nvp_group_ids = self._get_security_groups_for_port(context, + security_groups) + if nvp_group_ids: + port.security_profiles(nvp_group_ids) + port.admin_status_enabled(status) + return port.update() def delete_port(self, context, port_id, **kwargs): - connection = self.get_connection() - lswitch_uuid = kwargs.get('lswitch_uuid', None) - try: - if not lswitch_uuid: - lswitch_uuid = self._lswitch_from_port(context, port_id) - LOG.debug("Deleting port %s from lswitch %s" - % (port_id, lswitch_uuid)) - connection.lswitch_port(lswitch_uuid, port_id).delete() - except aiclib.core.AICException as ae: - if ae.code == 404: - LOG.info("LSwitchPort/Port %s not found in NVP." - " Ignoring explicitly. Code: %s, Message: %s" - % (port_id, ae.code, ae.message)) - else: - LOG.info("AICException deleting LSwitchPort/Port %s in NVP." - " Ignoring explicitly. Code: %s, Message: %s" - % (port_id, ae.code, ae.message)) + with self.get_connection() as connection: + lswitch_uuid = kwargs.get('lswitch_uuid', None) + try: + if not lswitch_uuid: + lswitch_uuid = self._lswitch_from_port(context, port_id) + LOG.debug("Deleting port %s from lswitch %s" + % (port_id, lswitch_uuid)) + connection.lswitch_port(lswitch_uuid, port_id).delete() + except aiclib.core.AICException as ae: + if ae.code == 404: + LOG.info("LSwitchPort/Port %s not found in NVP." + " Ignoring explicitly. Code: %s, Message: %s" + % (port_id, ae.code, ae.message)) + else: + LOG.info("AICException deleting LSwitchPort/Port %s in " + "NVP. Ignoring explicitly. Code: %s, Message: %s" + % (port_id, ae.code, ae.message)) - except Exception as e: - LOG.info("Failed to delete LSwitchPort/Port %s in NVP." - " Ignoring explicitly. Message: %s" - % (port_id, e.args[0])) + except Exception as e: + LOG.info("Failed to delete LSwitchPort/Port %s in NVP." + " Ignoring explicitly. Message: %s" + % (port_id, e.args[0])) def _collect_lport_info(self, lport, get_status): info = { @@ -291,23 +315,23 @@ class NVPDriver(base.BaseDriver): return info def diag_port(self, context, port_id, get_status=False, **kwargs): - connection = self.get_connection() - lswitch_uuid = self._lswitch_from_port(context, port_id) - lswitch_port = connection.lswitch_port(lswitch_uuid, port_id) + with self.get_connection() as connection: + lswitch_uuid = self._lswitch_from_port(context, port_id) + lswitch_port = connection.lswitch_port(lswitch_uuid, port_id) - query = lswitch_port.query() - query.relations("LogicalPortAttachment") - results = query.results() - if results['result_count'] == 0: - return {'lport': "Logical port not found."} + query = lswitch_port.query() + query.relations("LogicalPortAttachment") + results = query.results() + if results['result_count'] == 0: + return {'lport': "Logical port not found."} - config = results['results'][0] - relations = config.pop('_relations') - config['attachment'] = relations['LogicalPortAttachment']['type'] - if get_status: - config['status'] = lswitch_port.status() - config['statistics'] = lswitch_port.statistics() - return {'lport': self._collect_lport_info(config, get_status)} + config = results['results'][0] + relations = config.pop('_relations') + config['attachment'] = relations['LogicalPortAttachment']['type'] + if get_status: + config['status'] = lswitch_port.status() + config['statistics'] = lswitch_port.statistics() + return {'lport': self._collect_lport_info(config, get_status)} def _get_network_details(self, context, network_id, switches): name, phys_net, phys_type, segment_id = None, None, None, None @@ -326,55 +350,55 @@ class NVPDriver(base.BaseDriver): def create_security_group(self, context, group_name, **group): tenant_id = context.tenant_id - connection = self.get_connection() - group_id = group.get('group_id') - profile = connection.securityprofile() - if group_name: - profile.display_name(group_name) - ingress_rules = group.get('port_ingress_rules', []) - egress_rules = group.get('port_egress_rules', []) + with self.get_connection() as connection: + group_id = group.get('group_id') + profile = connection.securityprofile() + if group_name: + profile.display_name(group_name) + ingress_rules = group.get('port_ingress_rules', []) + egress_rules = group.get('port_egress_rules', []) - if (len(ingress_rules) + len(egress_rules) > - self.limits['max_rules_per_group']): - raise exceptions.DriverLimitReached(limit="rules per group") + if (len(ingress_rules) + len(egress_rules) > + self.limits['max_rules_per_group']): + raise exceptions.DriverLimitReached(limit="rules per group") - if egress_rules: - profile.port_egress_rules(egress_rules) - if ingress_rules: - profile.port_ingress_rules(ingress_rules) - tags = [dict(tag=group_id, scope="neutron_group_id"), - dict(tag=tenant_id, scope="os_tid")] - LOG.debug("Creating security profile %s" % group_name) - profile.tags(tags) - return profile.create() + if egress_rules: + profile.port_egress_rules(egress_rules) + if ingress_rules: + profile.port_ingress_rules(ingress_rules) + tags = [dict(tag=group_id, scope="neutron_group_id"), + dict(tag=tenant_id, scope="os_tid")] + LOG.debug("Creating security profile %s" % group_name) + profile.tags(tags) + return profile.create() def delete_security_group(self, context, group_id, **kwargs): guuid = self._get_security_group_id(context, group_id) - connection = self.get_connection() - LOG.debug("Deleting security profile %s" % group_id) - connection.securityprofile(guuid).delete() + with self.get_connection() as connection: + LOG.debug("Deleting security profile %s" % group_id) + connection.securityprofile(guuid).delete() def update_security_group(self, context, group_id, **group): query = self._get_security_group(context, group_id) - connection = self.get_connection() - profile = connection.securityprofile(query.get('uuid')) + with self.get_connection() as connection: + profile = connection.securityprofile(query.get('uuid')) - ingress_rules = group.get('port_ingress_rules', - query.get('logical_port_ingress_rules')) - egress_rules = group.get('port_egress_rules', - query.get('logical_port_egress_rules')) + ingress_rules = group.get('port_ingress_rules', + query.get('logical_port_ingress_rules')) + egress_rules = group.get('port_egress_rules', + query.get('logical_port_egress_rules')) - if (len(ingress_rules) + len(egress_rules) > - self.limits['max_rules_per_group']): - raise exceptions.DriverLimitReached(limit="rules per group") + if (len(ingress_rules) + len(egress_rules) > + self.limits['max_rules_per_group']): + raise exceptions.DriverLimitReached(limit="rules per group") - if group.get('name', None): - profile.display_name(group['name']) - if group.get('port_ingress_rules', None) is not None: - profile.port_ingress_rules(ingress_rules) - if group.get('port_egress_rules', None) is not None: - profile.port_egress_rules(egress_rules) - return profile.update() + if group.get('name', None): + profile.display_name(group['name']) + if group.get('port_ingress_rules', None) is not None: + profile.port_ingress_rules(ingress_rules) + if group.get('port_egress_rules', None) is not None: + profile.port_egress_rules(egress_rules) + return profile.update() def _update_security_group_rules(self, context, group_id, rule, operation, checks): @@ -447,9 +471,9 @@ class NVPDriver(base.BaseDriver): return None def _lswitch_delete(self, context, lswitch_uuid): - connection = self.get_connection() - LOG.debug("Deleting lswitch %s" % lswitch_uuid) - connection.lswitch(lswitch_uuid).delete() + with self.get_connection() as connection: + LOG.debug("Deleting lswitch %s" % lswitch_uuid) + connection.lswitch(lswitch_uuid).delete() def _config_provider_attrs(self, connection, switch, phys_net, net_type, segment_id): @@ -491,64 +515,65 @@ class NVPDriver(base.BaseDriver): (context.tenant_id, network_name)) tenant_id = context.tenant_id - connection = self.get_connection() + with self.get_connection() as connection: + switch = connection.lswitch() + if network_name is None: + network_name = network_id + switch.display_name(network_name[:40]) + tags = tags or [] + tags.append({"tag": tenant_id, "scope": "os_tid"}) + if network_id: + tags.append({"tag": network_id, "scope": "neutron_net_id"}) + switch.tags(tags) + pnet = phys_net or CONF.NVP.default_tz + ptype = phys_type or CONF.NVP.default_tz_type + switch.transport_zone(pnet, ptype) + LOG.debug("Creating lswitch for network %s" % network_id) - switch = connection.lswitch() - if network_name is None: - network_name = network_id - switch.display_name(network_name[:40]) - tags = tags or [] - tags.append({"tag": tenant_id, "scope": "os_tid"}) - if network_id: - tags.append({"tag": network_id, "scope": "neutron_net_id"}) - switch.tags(tags) - pnet = phys_net or CONF.NVP.default_tz - ptype = phys_type or CONF.NVP.default_tz_type - switch.transport_zone(pnet, ptype) - LOG.debug("Creating lswitch for network %s" % network_id) - - # When connecting to public or snet, we need switches that are - # connected to their respective public/private transport zones - # using a "bridge" connector. Public uses no VLAN, whereas private - # uses VLAN 122 in netdev. Probably need this to be configurable - self._config_provider_attrs(connection, switch, phys_net, phys_type, - segment_id) - res = switch.create() - try: - uuid = res["uuid"] - return uuid - except TypeError: - LOG.exception("Unexpected return from NVP: %s" % res) - raise + # When connecting to public or snet, we need switches that are + # connected to their respective public/private transport zones + # using a "bridge" connector. Public uses no VLAN, whereas private + # uses VLAN 122 in netdev. Probably need this to be configurable + self._config_provider_attrs(connection, switch, phys_net, + phys_type, segment_id) + res = switch.create() + try: + uuid = res["uuid"] + return uuid + except TypeError: + LOG.exception("Unexpected return from NVP: %s" % res) + raise def _lswitches_for_network(self, context, network_id): - connection = self.get_connection() - query = connection.lswitch().query() - query.tagscopes(['os_tid', 'neutron_net_id']) - query.tags([context.tenant_id, network_id]) - return query + with self.get_connection() as connection: + query = connection.lswitch().query() + query.tagscopes(['os_tid', 'neutron_net_id']) + query.tags([context.tenant_id, network_id]) + return query def _lswitch_from_port(self, context, port_id): - connection = self.get_connection() - query = connection.lswitch_port("*").query() - query.relations("LogicalSwitchConfig") - query.uuid(port_id) - port = query.results() - if port['result_count'] > 1: - raise Exception("Could not identify lswitch for port %s" % port_id) - if port['result_count'] < 1: - raise Exception("No lswitch found for port %s" % port_id) - return port['results'][0]["_relations"]["LogicalSwitchConfig"]["uuid"] + with self.get_connection() as connection: + query = connection.lswitch_port("*").query() + query.relations("LogicalSwitchConfig") + query.uuid(port_id) + port = query.results() + if port['result_count'] > 1: + raise Exception("Could not identify lswitch for port %s" % + port_id) + if port['result_count'] < 1: + raise Exception("No lswitch found for port %s" % port_id) + cfg = port['results'][0]["_relations"]["LogicalSwitchConfig"] + return cfg["uuid"] def _get_security_group(self, context, group_id): - connection = self.get_connection() - query = connection.securityprofile().query() - query.tagscopes(['os_tid', 'neutron_group_id']) - query.tags([context.tenant_id, group_id]) - query = query.results() - if query['result_count'] != 1: - raise sg_ext.SecurityGroupNotFound(id=group_id) - return query['results'][0] + with self.get_connection() as connection: + query = connection.securityprofile().query() + query.tagscopes(['os_tid', 'neutron_group_id']) + query.tags([context.tenant_id, group_id]) + query = query.results() + if query['result_count'] != 1: + raise sg_ext.SecurityGroupNotFound(id=group_id) + return query['results'][0] def _get_security_group_id(self, context, group_id): return self._get_security_group(context, group_id)['uuid'] @@ -567,24 +592,25 @@ class NVPDriver(base.BaseDriver): if rule.get(key): rule_clone[key] = rule[key] - connection = self.get_connection() - secrule = connection.securityrule(ethertype, **rule_clone) + with self.get_connection() as connection: + secrule = connection.securityrule(ethertype, **rule_clone) - direction = rule.get('direction', '') - if direction not in ['ingress', 'egress']: - raise AttributeError( - "Direction not specified as 'ingress' or 'egress'.") - return (direction, secrule) + direction = rule.get('direction', '') + if direction not in ['ingress', 'egress']: + raise AttributeError( + "Direction not specified as 'ingress' or 'egress'.") + return (direction, secrule) def _check_rule_count_per_port(self, context, group_id): - connection = self.get_connection() - ports = connection.lswitch_port("*").query().security_profile_uuid( - '=', self._get_security_group_id( - context, group_id)).results().get('results', []) - groups = (port.get('security_profiles', []) for port in ports) - return max([self._check_rule_count_for_groups( - context, (connection.securityprofile(gp).read() for gp in group)) - for group in groups] or [0]) + with self.get_connection() as connection: + ports = connection.lswitch_port("*").query().security_profile_uuid( + '=', self._get_security_group_id( + context, group_id)).results().get('results', []) + groups = (port.get('security_profiles', []) for port in ports) + return max([self._check_rule_count_for_groups( + context, (connection.securityprofile(gp).read() + for gp in group)) + for group in groups] or [0]) def _check_rule_count_for_groups(self, context, groups): return sum(len(group['logical_port_ingress_rules']) + diff --git a/quark/exceptions.py b/quark/exceptions.py index d10c622..fc74863 100644 --- a/quark/exceptions.py +++ b/quark/exceptions.py @@ -123,3 +123,8 @@ class RedisConnectionFailure(exceptions.NeutronException): class RedisSlaveWritesForbidden(exceptions.NeutronException): message = _("No write actions can be applied to Slave redis nodes.") + + +class NoBackendConnectionsDefined(exceptions.NeutronException): + message = _("This driver cannot be used without a backend connection " + "definition. %(msg)") diff --git a/quark/tests/test_nvp_driver.py b/quark/tests/test_nvp_driver.py index 430af56..7b297b2 100644 --- a/quark/tests/test_nvp_driver.py +++ b/quark/tests/test_nvp_driver.py @@ -124,10 +124,10 @@ class TestNVPDriverCreateNetwork(TestNVPDriver): @contextlib.contextmanager def _stubs(self): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + ) as (conn,): connection = self._create_connection() - get_connection.return_value = connection + conn.return_value = connection yield connection def test_create_network(self): @@ -151,8 +151,8 @@ class TestNVPDriverProviderNetwork(TestNVPDriver): @contextlib.contextmanager def _stubs(self, tz): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + ) as (conn,): connection = self._create_connection() switch = self._create_lswitch(1, False) switch.transport_zone = mock.Mock() @@ -161,7 +161,7 @@ class TestNVPDriverProviderNetwork(TestNVPDriver): tz_query = mock.Mock() tz_query.query = mock.Mock(return_value=tz_results) connection.transportzone = mock.Mock(return_value=tz_query) - get_connection.return_value = connection + conn.return_value = connection yield connection, switch def test_config_provider_attrs_flat_net(self): @@ -286,11 +286,11 @@ class TestNVPDriverDeleteNetwork(TestNVPDriver): @contextlib.contextmanager def _stubs(self, network_exists=True): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._lswitches_for_network" % self.d_pkg), - ) as (get_connection, switch_list): + ) as (conn, switch_list): connection = self._create_connection() - get_connection.return_value = connection + conn.return_value = connection if network_exists: ret = {"results": [{"uuid": self.lswitch_uuid}]} else: @@ -318,12 +318,12 @@ class TestNVPDriverDeleteNetworkWithExceptions(TestNVPDriver): @contextlib.contextmanager def _stubs(self, network_exists=True, exception=None): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._lswitches_for_network" % self.d_pkg), mock.patch("%s._lswitch_delete" % self.d_pkg), - ) as (get_connection, switch_list, switch_delete): + ) as (conn, switch_list, switch_delete): connection = self._create_connection() - get_connection.return_value = connection + conn.return_value = connection if network_exists: ret = {"results": [{"uuid": self.lswitch_uuid}]} else: @@ -372,13 +372,14 @@ class TestNVPDriverCreatePort(TestNVPDriver): @contextlib.contextmanager def _stubs(self, has_lswitch=True, maxed_ports=False, net_details=None): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), + mock.patch("%s._next_connection" % self.d_pkg), mock.patch("%s._lswitches_for_network" % self.d_pkg), mock.patch("%s._get_network_details" % self.d_pkg), - ) as (get_connection, get_switches, get_net_dets): + ) as (conn, next_conn, get_switches, get_net_dets): connection = self._create_connection(has_switches=has_lswitch, maxed_ports=maxed_ports) - get_connection.return_value = connection + conn.return_value = connection get_switches.return_value = connection.lswitch().query() get_net_dets.return_value = net_details yield connection @@ -517,11 +518,12 @@ class TestNVPDriverUpdatePort(TestNVPDriver): @contextlib.contextmanager def _stubs(self): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + mock.patch("%s._next_connection" % self.d_pkg), + ) as (conn, next_conn): connection = self._create_connection() connection.securityprofile = self._create_security_profile() - get_connection.return_value = connection + conn.return_value = connection yield connection def test_update_port(self): @@ -550,10 +552,10 @@ class TestNVPDriverLswitchesForNetwork(TestNVPDriver): @contextlib.contextmanager def _stubs(self, single_switch=True): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + ) as (conn,): connection = self._create_connection(switch_count=1) - get_connection.return_value = connection + conn.return_value = connection yield connection def test_get_lswitches(self): @@ -606,10 +608,11 @@ class TestNVPDriverDeletePort(TestNVPDriver): @contextlib.contextmanager def _stubs(self, switch_count=1): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + mock.patch("%s._next_connection" % self.d_pkg), + ) as (conn, next_conn): connection = self._create_connection(switch_count=switch_count) - get_connection.return_value = connection + conn.return_value = connection yield connection def test_delete_port(self): @@ -645,11 +648,11 @@ class TestNVPDriverDeletePortWithExceptions(TestNVPDriver): @contextlib.contextmanager def _stubs(self, switch_exception=None, delete_exception=None): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._lswitch_from_port" % self.d_pkg), - ) as (get_connection, switch): + ) as (conn, switch): connection = self._create_connection() - get_connection.return_value = connection + conn.return_value = connection if switch_exception: switch.side_effect = switch_exception else: @@ -729,11 +732,12 @@ class TestNVPDriverCreateSecurityGroup(TestNVPDriver): @contextlib.contextmanager def _stubs(self): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + mock.patch("%s._next_connection" % self.d_pkg), + ) as (conn, next_conn): connection = self._create_connection() connection.securityprofile = self._create_security_profile() - get_connection.return_value = connection + conn.return_value = connection yield connection def test_security_group_create(self): @@ -783,11 +787,12 @@ class TestNVPDriverDeleteSecurityGroup(TestNVPDriver): @contextlib.contextmanager def _stubs(self): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + mock.patch("%s._next_connection" % self.d_pkg), + ) as (conn, next_conn): connection = self._create_connection() connection.securityprofile = self._create_security_profile() - get_connection.return_value = connection + conn.return_value = connection yield connection def test_security_group_delete(self): @@ -812,11 +817,12 @@ class TestNVPDriverUpdateSecurityGroup(TestNVPDriver): @contextlib.contextmanager def _stubs(self): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + mock.patch("%s._next_connection" % self.d_pkg), + ) as (conn, next_conn): connection = self._create_connection() connection.securityprofile = self._create_security_profile() - get_connection.return_value = connection + conn.return_value = connection yield connection def test_security_group_update(self): @@ -872,14 +878,15 @@ class TestNVPDriverCreateSecurityGroupRule(TestNVPDriver): @contextlib.contextmanager def _stubs(self): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + mock.patch("%s._next_connection" % self.d_pkg), + ) as (conn, next_conn): connection = self._create_connection() connection.securityprofile = self._create_security_profile() connection.securityrule = self._create_security_rule() connection.lswitch_port().query.return_value = ( self._create_lport_query(1, [self.profile_id])) - get_connection.return_value = connection + conn.return_value = connection yield connection def test_security_rule_create(self): @@ -955,13 +962,13 @@ class TestNVPDriverDeleteSecurityGroupRule(TestNVPDriver): rulelist['logical_port_%s_rules' % rule.pop('direction')].append( rule) with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + ) as (conn,): connection = self._create_connection() connection.securityprofile = self._create_security_profile() connection.securityrule = self._create_security_rule() connection.securityprofile().read().update(rulelist) - get_connection.return_value = connection + conn.return_value = connection yield connection def test_delete_security_group(self): @@ -1023,16 +1030,40 @@ class TestNVPGetConnection(TestNVPDriver): http_timeout=10, retries=1, backoff=0)) - with mock.patch("aiclib.nvp.Connection") as (aiclib_conn): - yield aiclib_conn + with contextlib.nested( + mock.patch("aiclib.nvp.Connection"), + mock.patch("%s._next_connection" % self.d_pkg) + ) as (aiclib_conn, next_conn): + yield aiclib_conn, next_conn cfg.CONF.clear_override("controller_connection", "NVP") def test_get_connection(self): - with self._stubs(has_conn=False) as aiclib_conn: - self.driver.get_connection() + with self._stubs(has_conn=False) as (aiclib_conn, next_conn): + with self.driver.get_connection(): + pass self.assertTrue(aiclib_conn.called) + self.assertFalse(next_conn.called) def test_get_connection_connection_defined(self): - with self._stubs(has_conn=True) as aiclib_conn: - self.driver.get_connection() + with self._stubs(has_conn=True) as (aiclib_conn, next_conn): + with self.driver.get_connection(): + pass self.assertFalse(aiclib_conn.called) + self.assertFalse(next_conn.called) + + def test_get_connection_iterates(self): + with self._stubs(has_conn=True) as (aiclib_conn, next_conn): + try: + with self.driver.get_connection(): + raise Exception("Failure") + except Exception: + pass + self.assertFalse(aiclib_conn.called) + self.assertTrue(next_conn.called) + + +class TestNVPGetConnectionNoneDefined(TestNVPDriver): + def test_get_connection(self): + with self.assertRaises(q_exc.NoBackendConnectionsDefined): + with self.driver.get_connection(): + pass diff --git a/quark/tests/test_optimized_nvp_driver.py b/quark/tests/test_optimized_nvp_driver.py index 0a14d4c..f6ba105 100644 --- a/quark/tests/test_optimized_nvp_driver.py +++ b/quark/tests/test_optimized_nvp_driver.py @@ -52,13 +52,13 @@ class TestOptimizedNVPDriverDeleteNetwork(TestOptimizedNVPDriver): @contextlib.contextmanager def _stubs(self, switch_count=1): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._lswitch_select_by_nvp_id" % self.d_pkg), mock.patch("%s._lswitches_for_network" % self.d_pkg), - ) as (get_connection, select_switch, get_switches): + ) as (conn, select_switch, get_switches): connection = self._create_connection() switch = self._create_lswitch_mock() - get_connection.return_value = connection + conn.return_value = connection select_switch.return_value = switch get_switches.return_value = [switch] * switch_count self.context.session.delete = mock.Mock(return_value=None) @@ -105,14 +105,14 @@ class TestOptimizedNVPDriverDeleteNetworkWithExceptions( @contextlib.contextmanager def _stubs(self, switch_count=1, error_code=500): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._lswitch_select_by_nvp_id" % self.d_pkg), mock.patch("%s._lswitches_for_network" % self.d_pkg), mock.patch("%s._lswitch_delete" % self.d_pkg) - ) as (get_connection, select_switch, get_switches, delete_switch): + ) as (conn, select_switch, get_switches, delete_switch): connection = self._create_connection() switch = self._create_lswitch_mock() - get_connection.return_value = connection + conn.return_value = connection select_switch.return_value = switch get_switches.return_value = [switch] * switch_count delete_switch.side_effect = aiclib.core.AICException( @@ -151,17 +151,17 @@ class TestOptimizedNVPDriverDeletePortMultiSwitch(TestOptimizedNVPDriver): @contextlib.contextmanager def _stubs(self, port_count=2, exception=None): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._lport_select_by_id" % self.d_pkg), mock.patch("%s._lswitch_select_by_nvp_id" % self.d_pkg), mock.patch("%s._lswitches_for_network" % self.d_pkg), mock.patch("%s._lport_delete" % self.d_pkg), - ) as (get_connection, select_port, select_switch, + ) as (conn, select_port, select_switch, two_switch, port_delete): connection = self._create_connection() port = self._create_lport_mock(port_count) switch = self._create_lswitch_mock() - get_connection.return_value = connection + conn.return_value = connection select_port.return_value = port select_switch.return_value = switch two_switch.return_value = [switch, switch] @@ -244,15 +244,15 @@ class TestOptimizedNVPDriverDeletePortSingleSwitch(TestOptimizedNVPDriver): @contextlib.contextmanager def _stubs(self, port_count=2): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._lport_select_by_id" % self.d_pkg), mock.patch("%s._lswitch_select_by_nvp_id" % self.d_pkg), mock.patch("%s._lswitches_for_network" % self.d_pkg), - ) as (get_connection, select_port, select_switch, one_switch): + ) as (conn, select_port, select_switch, one_switch): connection = self._create_connection() port = self._create_lport_mock(port_count) switch = self._create_lswitch_mock() - get_connection.return_value = connection + conn.return_value = connection select_port.return_value = port select_switch.return_value = switch one_switch.return_value = [switch] @@ -274,16 +274,16 @@ class TestOptimizedNVPDriverCreatePort(TestOptimizedNVPDriver): @contextlib.contextmanager def _stubs(self, has_lswitch=True, maxed_ports=False): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._lswitch_select_free" % self.d_pkg), mock.patch("%s._lswitch_select_first" % self.d_pkg), mock.patch("%s._lswitch_select_by_nvp_id" % self.d_pkg), mock.patch("%s._lswitch_create_optimized" % self.d_pkg), mock.patch("%s._get_network_details" % self.d_pkg) - ) as (get_connection, select_free, select_first, + ) as (conn, select_free, select_first, select_by_id, create_opt, get_net_dets): connection = self._create_connection() - get_connection.return_value = connection + conn.return_value = connection if has_lswitch: select_first.return_value = mock.Mock(nvp_id=self.lswitch_uuid) if not has_lswitch: @@ -427,7 +427,7 @@ class TestOptimizedNVPDriverUpdatePort(TestOptimizedNVPDriver): class TestCreateSecurityGroups(TestOptimizedNVPDriver): def test_create_security_group(self): - with mock.patch("%s.get_connection" % self.d_pkg): + with mock.patch("%s._connection" % self.d_pkg): self.driver.create_security_group(self.context, "newgroup") self.assertTrue(self.context.session.add.called) @@ -436,7 +436,7 @@ class TestDeleteSecurityGroups(TestOptimizedNVPDriver): def test_delete_security_group(self): mod_path = "quark.drivers.nvp_driver.NVPDriver" with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._query_security_group" % self.d_pkg), mock.patch("%s.delete_security_group" % mod_path)): @@ -452,10 +452,10 @@ class TestSecurityGroupRules(TestOptimizedNVPDriver): def _stubs(self, rules=None): rules = rules or [] with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._query_security_group" % self.d_pkg), mock.patch("%s._check_rule_count_per_port" % self.d_pkg), - ) as (get_connection, query_sec_group, rule_count): + ) as (conn, query_sec_group, rule_count): query_sec_group.return_value = (quark.drivers.optimized_nvp_driver. SecurityProfile()) connection = self._create_connection() @@ -464,7 +464,7 @@ class TestSecurityGroupRules(TestOptimizedNVPDriver): connection.securityrule = self._create_security_rule() connection.lswitch_port().query.return_value = ( self._create_lport_query(1, [self.profile_id])) - get_connection.return_value = connection + conn.return_value = connection old_query = self.context.session.query sec_group = quark.db.models.SecurityGroup()