Merge "Remove dbapi calls from agent driver"

This commit is contained in:
Jenkins
2014-08-01 16:56:47 +00:00
committed by Gerrit Code Review
3 changed files with 53 additions and 43 deletions

View File

@@ -29,11 +29,10 @@ from ironic.common import states
from ironic.common import utils from ironic.common import utils
from ironic.conductor import task_manager from ironic.conductor import task_manager
from ironic.conductor import utils as manager_utils from ironic.conductor import utils as manager_utils
from ironic.db import api as dbapi
from ironic.drivers import base from ironic.drivers import base
from ironic.drivers.modules import agent_client from ironic.drivers.modules import agent_client
from ironic.drivers.modules import image_cache from ironic.drivers.modules import image_cache
from ironic.objects import node as node_module from ironic import objects
from ironic.openstack.common import excutils from ironic.openstack.common import excutils
from ironic.openstack.common import fileutils from ironic.openstack.common import fileutils
from ironic.openstack.common import log from ironic.openstack.common import log
@@ -336,7 +335,6 @@ class AgentVendorInterface(base.VendorInterface):
'lookup': self._lookup, 'lookup': self._lookup,
} }
self.supported_payload_versions = ['2'] self.supported_payload_versions = ['2']
self.dbapi = dbapi.get_instance()
self._client = _get_client() self._client = _get_client()
def get_properties(self): def get_properties(self):
@@ -584,7 +582,7 @@ class AgentVendorInterface(base.VendorInterface):
'database.') % mac_addresses) 'database.') % mac_addresses)
node_id = self._get_node_id(ports) node_id = self._get_node_id(ports)
try: try:
node = node_module.Node.get_by_id(context, node_id) node = objects.Node.get_by_id(context, node_id)
except exception.NodeNotFound: except exception.NodeNotFound:
with excutils.save_and_reraise_exception(): with excutils.save_and_reraise_exception():
LOG.exception(_('Could not find matching node for the ' LOG.exception(_('Could not find matching node for the '
@@ -601,9 +599,7 @@ class AgentVendorInterface(base.VendorInterface):
for mac in mac_addresses: for mac in mac_addresses:
# Will do a search by mac if the mac isn't malformed # Will do a search by mac if the mac isn't malformed
try: try:
# TODO(JoshNang) add port.get_by_mac() to Ironic port_ob = objects.Port.get_by_address(context, mac)
# port.get_by_uuid() would technically work but shouldn't.
port_ob = self.dbapi.get_port(port_id=mac)
ports.append(port_ob) ports.append(port_ob)
except exception.PortNotFound: except exception.PortNotFound:

View File

@@ -20,8 +20,8 @@ from ironic.common import neutron
from ironic.common import pxe_utils from ironic.common import pxe_utils
from ironic.common import states from ironic.common import states
from ironic.conductor import task_manager from ironic.conductor import task_manager
from ironic.db import api as dbapi
from ironic.drivers.modules import agent from ironic.drivers.modules import agent
from ironic import objects
from ironic.openstack.common import context from ironic.openstack.common import context
from ironic.tests.conductor import utils as mgr_utils from ironic.tests.conductor import utils as mgr_utils
from ironic.tests.db import base as db_base from ironic.tests.db import base as db_base
@@ -39,7 +39,6 @@ class TestAgentDeploy(db_base.DbTestCase):
def setUp(self): def setUp(self):
super(TestAgentDeploy, self).setUp() super(TestAgentDeploy, self).setUp()
mgr_utils.mock_the_extension_manager(driver='fake_agent') mgr_utils.mock_the_extension_manager(driver='fake_agent')
self.dbapi = dbapi.get_instance()
self.driver = agent.AgentDeploy() self.driver = agent.AgentDeploy()
self.context = context.get_admin_context() self.context = context.get_admin_context()
n = { n = {
@@ -49,10 +48,6 @@ class TestAgentDeploy(db_base.DbTestCase):
} }
self.node = object_utils.create_test_node(self.context, **n) self.node = object_utils.create_test_node(self.context, **n)
def _create_test_port(self, **kwargs):
p = db_utils.get_test_port(**kwargs)
return self.dbapi.create_port(p)
def test_validate(self): def test_validate(self):
with task_manager.acquire( with task_manager.acquire(
self.context, self.node['uuid'], shared=False) as task: self.context, self.node['uuid'], shared=False) as task:
@@ -99,9 +94,7 @@ class TestAgentVendor(db_base.DbTestCase):
def setUp(self): def setUp(self):
super(TestAgentVendor, self).setUp() super(TestAgentVendor, self).setUp()
mgr_utils.mock_the_extension_manager(driver="fake_pxe") mgr_utils.mock_the_extension_manager(driver="fake_pxe")
self.dbapi = dbapi.get_instance()
self.passthru = agent.AgentVendorInterface() self.passthru = agent.AgentVendorInterface()
self.passthru.db_connection = mock.Mock(autospec=True)
self.context = context.get_admin_context() self.context = context.get_admin_context()
n = { n = {
'driver': 'fake_pxe', 'driver': 'fake_pxe',
@@ -110,10 +103,6 @@ class TestAgentVendor(db_base.DbTestCase):
} }
self.node = object_utils.create_test_node(self.context, **n) self.node = object_utils.create_test_node(self.context, **n)
def _create_test_port(self, **kwargs):
p = db_utils.get_test_port(**kwargs)
return self.dbapi.create_port(p)
def test_validate(self): def test_validate(self):
with task_manager.acquire(self.context, self.node.uuid) as task: with task_manager.acquire(self.context, self.node.uuid) as task:
self.passthru.validate(task) self.passthru.validate(task)
@@ -197,13 +186,13 @@ class TestAgentVendor(db_base.DbTestCase):
version='2', version='2',
inventory={'interfaces': []}) inventory={'interfaces': []})
def test_find_ports_by_macs(self): @mock.patch.object(objects.Port, 'get_by_address')
fake_port = self._create_test_port() def test_find_ports_by_macs(self, mock_get_port):
fake_port = object_utils.get_test_port(self.context)
mock_get_port.return_value = fake_port
macs = ['aa:bb:cc:dd:ee:ff'] macs = ['aa:bb:cc:dd:ee:ff']
self.passthru.dbapi = mock.Mock()
self.passthru.dbapi.get_port.return_value = fake_port
with task_manager.acquire( with task_manager.acquire(
self.context, self.node['uuid'], shared=True) as task: self.context, self.node['uuid'], shared=True) as task:
ports = self.passthru._find_ports_by_macs(task, macs) ports = self.passthru._find_ports_by_macs(task, macs)
@@ -211,10 +200,9 @@ class TestAgentVendor(db_base.DbTestCase):
self.assertEqual(fake_port.uuid, ports[0].uuid) self.assertEqual(fake_port.uuid, ports[0].uuid)
self.assertEqual(fake_port.node_id, ports[0].node_id) self.assertEqual(fake_port.node_id, ports[0].node_id)
def test_find_ports_by_macs_bad_params(self): @mock.patch.object(objects.Port, 'get_by_address')
self.passthru.dbapi = mock.Mock() def test_find_ports_by_macs_bad_params(self, mock_get_port):
self.passthru.dbapi.get_port.side_effect = exception.PortNotFound( mock_get_port.side_effect = exception.PortNotFound(port="123")
port="123")
macs = ['aa:bb:cc:dd:ee:ff'] macs = ['aa:bb:cc:dd:ee:ff']
with task_manager.acquire( with task_manager.acquire(
@@ -228,7 +216,7 @@ class TestAgentVendor(db_base.DbTestCase):
@mock.patch('ironic.drivers.modules.agent.AgentVendorInterface' @mock.patch('ironic.drivers.modules.agent.AgentVendorInterface'
'._find_ports_by_macs') '._find_ports_by_macs')
def test_find_node_by_macs(self, ports_mock, node_id_mock, node_mock): def test_find_node_by_macs(self, ports_mock, node_id_mock, node_mock):
ports_mock.return_value = [self._create_test_port()] ports_mock.return_value = object_utils.get_test_port(self.context)
node_id_mock.return_value = '1' node_id_mock.return_value = '1'
node_mock.return_value = self.node node_mock.return_value = self.node
@@ -258,7 +246,7 @@ class TestAgentVendor(db_base.DbTestCase):
'._find_ports_by_macs') '._find_ports_by_macs')
def test_find_node_by_macs_nodenotfound(self, ports_mock, node_id_mock, def test_find_node_by_macs_nodenotfound(self, ports_mock, node_id_mock,
node_mock): node_mock):
port = self._create_test_port() port = object_utils.get_test_port(self.context)
ports_mock.return_value = [port] ports_mock.return_value = [port]
node_id_mock.return_value = self.node['uuid'] node_id_mock.return_value = self.node['uuid']
node_mock.side_effect = [self.node, node_mock.side_effect = [self.node,
@@ -273,25 +261,29 @@ class TestAgentVendor(db_base.DbTestCase):
macs) macs)
def test_get_node_id(self): def test_get_node_id(self):
fake_port1 = self._create_test_port(node_id=123, fake_port1 = object_utils.get_test_port(self.context,
address="aa:bb:cc:dd:ee:fe") node_id=123,
fake_port2 = self._create_test_port(node_id=123, address="aa:bb:cc:dd:ee:fe")
id=42, fake_port2 = object_utils.get_test_port(self.context,
address="aa:bb:cc:dd:ee:fb", node_id=123,
uuid='1be26c0b-03f2-4d2e-ae87-c02' id=42,
'd7f33c782') address="aa:bb:cc:dd:ee:fb",
uuid='1be26c0b-03f2-4d2e-ae87-'
'c02d7f33c782')
node_id = self.passthru._get_node_id([fake_port1, fake_port2]) node_id = self.passthru._get_node_id([fake_port1, fake_port2])
self.assertEqual(fake_port2.node_id, node_id) self.assertEqual(fake_port2.node_id, node_id)
def test_get_node_id_exception(self): def test_get_node_id_exception(self):
fake_port1 = self._create_test_port(node_id=123, fake_port1 = object_utils.get_test_port(self.context,
address="aa:bb:cc:dd:ee:fc") node_id=123,
fake_port2 = self._create_test_port(node_id=321, address="aa:bb:cc:dd:ee:fc")
id=42, fake_port2 = object_utils.get_test_port(self.context,
address="aa:bb:cc:dd:ee:fd", node_id=321,
uuid='1be26c0b-03f2-4d2e-ae87-c02' id=42,
'd7f33c782') address="aa:bb:cc:dd:ee:fd",
uuid='1be26c0b-03f2-4d2e-ae87-'
'c02d7f33c782')
self.assertRaises(exception.NodeNotFound, self.assertRaises(exception.NodeNotFound,
self.passthru._get_node_id, self.passthru._get_node_id,

View File

@@ -38,3 +38,25 @@ def create_test_node(ctxt, **kw):
node = get_test_node(ctxt, **kw) node = get_test_node(ctxt, **kw)
node.create() node.create()
return node return node
def get_test_port(ctxt, **kw):
"""Return a Port object with appropriate attributes.
NOTE: The object leaves the attributes marked as changed, such
that a create() could be used to commit it to the DB.
"""
db_port = db_utils.get_test_port(**kw)
port = objects.Port(context=ctxt)
for key in db_port:
setattr(port, key, db_port[key])
return port
def create_test_port(ctxt, **kw):
"""Create a port in the DB and return a Port object with appropriate
attributes.
"""
port = get_test_port(ctxt, **kw)
port.create()
return port