diff --git a/ironic/drivers/modules/agent.py b/ironic/drivers/modules/agent.py index 52b131e671..0ed95edba2 100644 --- a/ironic/drivers/modules/agent.py +++ b/ironic/drivers/modules/agent.py @@ -29,11 +29,10 @@ from ironic.common import states from ironic.common import utils from ironic.conductor import task_manager from ironic.conductor import utils as manager_utils -from ironic.db import api as dbapi from ironic.drivers import base from ironic.drivers.modules import agent_client from ironic.drivers.modules import image_cache -from ironic.objects import node as node_module +from ironic import objects from ironic.openstack.common import excutils from ironic.openstack.common import fileutils from ironic.openstack.common import log @@ -336,7 +335,6 @@ class AgentVendorInterface(base.VendorInterface): 'lookup': self._lookup, } self.supported_payload_versions = ['2'] - self.dbapi = dbapi.get_instance() self._client = _get_client() def get_properties(self): @@ -584,7 +582,7 @@ class AgentVendorInterface(base.VendorInterface): 'database.') % mac_addresses) node_id = self._get_node_id(ports) try: - node = node_module.Node.get_by_id(context, node_id) + node = objects.Node.get_by_id(context, node_id) except exception.NodeNotFound: with excutils.save_and_reraise_exception(): LOG.exception(_('Could not find matching node for the ' @@ -601,9 +599,7 @@ class AgentVendorInterface(base.VendorInterface): for mac in mac_addresses: # Will do a search by mac if the mac isn't malformed try: - # TODO(JoshNang) add port.get_by_mac() to Ironic - # port.get_by_uuid() would technically work but shouldn't. - port_ob = self.dbapi.get_port(port_id=mac) + port_ob = objects.Port.get_by_address(context, mac) ports.append(port_ob) except exception.PortNotFound: diff --git a/ironic/tests/drivers/test_agent.py b/ironic/tests/drivers/test_agent.py index a6cf9d5e06..bd3a188cee 100644 --- a/ironic/tests/drivers/test_agent.py +++ b/ironic/tests/drivers/test_agent.py @@ -20,8 +20,8 @@ from ironic.common import neutron from ironic.common import pxe_utils from ironic.common import states from ironic.conductor import task_manager -from ironic.db import api as dbapi from ironic.drivers.modules import agent +from ironic import objects from ironic.openstack.common import context from ironic.tests.conductor import utils as mgr_utils from ironic.tests.db import base as db_base @@ -39,7 +39,6 @@ class TestAgentDeploy(db_base.DbTestCase): def setUp(self): super(TestAgentDeploy, self).setUp() mgr_utils.mock_the_extension_manager(driver='fake_agent') - self.dbapi = dbapi.get_instance() self.driver = agent.AgentDeploy() self.context = context.get_admin_context() n = { @@ -49,10 +48,6 @@ class TestAgentDeploy(db_base.DbTestCase): } self.node = object_utils.create_test_node(self.context, **n) - def _create_test_port(self, **kwargs): - p = db_utils.get_test_port(**kwargs) - return self.dbapi.create_port(p) - def test_validate(self): with task_manager.acquire( self.context, self.node['uuid'], shared=False) as task: @@ -99,9 +94,7 @@ class TestAgentVendor(db_base.DbTestCase): def setUp(self): super(TestAgentVendor, self).setUp() mgr_utils.mock_the_extension_manager(driver="fake_pxe") - self.dbapi = dbapi.get_instance() self.passthru = agent.AgentVendorInterface() - self.passthru.db_connection = mock.Mock(autospec=True) self.context = context.get_admin_context() n = { 'driver': 'fake_pxe', @@ -110,10 +103,6 @@ class TestAgentVendor(db_base.DbTestCase): } self.node = object_utils.create_test_node(self.context, **n) - def _create_test_port(self, **kwargs): - p = db_utils.get_test_port(**kwargs) - return self.dbapi.create_port(p) - def test_validate(self): with task_manager.acquire(self.context, self.node.uuid) as task: self.passthru.validate(task) @@ -197,13 +186,13 @@ class TestAgentVendor(db_base.DbTestCase): version='2', inventory={'interfaces': []}) - def test_find_ports_by_macs(self): - fake_port = self._create_test_port() + @mock.patch.object(objects.Port, 'get_by_address') + def test_find_ports_by_macs(self, mock_get_port): + fake_port = object_utils.get_test_port(self.context) + mock_get_port.return_value = fake_port macs = ['aa:bb:cc:dd:ee:ff'] - self.passthru.dbapi = mock.Mock() - self.passthru.dbapi.get_port.return_value = fake_port with task_manager.acquire( self.context, self.node['uuid'], shared=True) as task: ports = self.passthru._find_ports_by_macs(task, macs) @@ -211,10 +200,9 @@ class TestAgentVendor(db_base.DbTestCase): self.assertEqual(fake_port.uuid, ports[0].uuid) self.assertEqual(fake_port.node_id, ports[0].node_id) - def test_find_ports_by_macs_bad_params(self): - self.passthru.dbapi = mock.Mock() - self.passthru.dbapi.get_port.side_effect = exception.PortNotFound( - port="123") + @mock.patch.object(objects.Port, 'get_by_address') + def test_find_ports_by_macs_bad_params(self, mock_get_port): + mock_get_port.side_effect = exception.PortNotFound(port="123") macs = ['aa:bb:cc:dd:ee:ff'] with task_manager.acquire( @@ -228,7 +216,7 @@ class TestAgentVendor(db_base.DbTestCase): @mock.patch('ironic.drivers.modules.agent.AgentVendorInterface' '._find_ports_by_macs') def test_find_node_by_macs(self, ports_mock, node_id_mock, node_mock): - ports_mock.return_value = [self._create_test_port()] + ports_mock.return_value = object_utils.get_test_port(self.context) node_id_mock.return_value = '1' node_mock.return_value = self.node @@ -258,7 +246,7 @@ class TestAgentVendor(db_base.DbTestCase): '._find_ports_by_macs') def test_find_node_by_macs_nodenotfound(self, ports_mock, node_id_mock, node_mock): - port = self._create_test_port() + port = object_utils.get_test_port(self.context) ports_mock.return_value = [port] node_id_mock.return_value = self.node['uuid'] node_mock.side_effect = [self.node, @@ -273,25 +261,29 @@ class TestAgentVendor(db_base.DbTestCase): macs) def test_get_node_id(self): - fake_port1 = self._create_test_port(node_id=123, - address="aa:bb:cc:dd:ee:fe") - fake_port2 = self._create_test_port(node_id=123, - id=42, - address="aa:bb:cc:dd:ee:fb", - uuid='1be26c0b-03f2-4d2e-ae87-c02' - 'd7f33c782') + fake_port1 = object_utils.get_test_port(self.context, + node_id=123, + address="aa:bb:cc:dd:ee:fe") + fake_port2 = object_utils.get_test_port(self.context, + node_id=123, + id=42, + address="aa:bb:cc:dd:ee:fb", + uuid='1be26c0b-03f2-4d2e-ae87-' + 'c02d7f33c782') node_id = self.passthru._get_node_id([fake_port1, fake_port2]) self.assertEqual(fake_port2.node_id, node_id) def test_get_node_id_exception(self): - fake_port1 = self._create_test_port(node_id=123, - address="aa:bb:cc:dd:ee:fc") - fake_port2 = self._create_test_port(node_id=321, - id=42, - address="aa:bb:cc:dd:ee:fd", - uuid='1be26c0b-03f2-4d2e-ae87-c02' - 'd7f33c782') + fake_port1 = object_utils.get_test_port(self.context, + node_id=123, + address="aa:bb:cc:dd:ee:fc") + fake_port2 = object_utils.get_test_port(self.context, + node_id=321, + id=42, + address="aa:bb:cc:dd:ee:fd", + uuid='1be26c0b-03f2-4d2e-ae87-' + 'c02d7f33c782') self.assertRaises(exception.NodeNotFound, self.passthru._get_node_id, diff --git a/ironic/tests/objects/utils.py b/ironic/tests/objects/utils.py index 65b1c70ea1..2e6d566dda 100644 --- a/ironic/tests/objects/utils.py +++ b/ironic/tests/objects/utils.py @@ -38,3 +38,25 @@ def create_test_node(ctxt, **kw): node = get_test_node(ctxt, **kw) node.create() return node + + +def get_test_port(ctxt, **kw): + """Return a Port object with appropriate attributes. + + NOTE: The object leaves the attributes marked as changed, such + that a create() could be used to commit it to the DB. + """ + db_port = db_utils.get_test_port(**kw) + port = objects.Port(context=ctxt) + for key in db_port: + setattr(port, key, db_port[key]) + return port + + +def create_test_port(ctxt, **kw): + """Create a port in the DB and return a Port object with appropriate + attributes. + """ + port = get_test_port(ctxt, **kw) + port.create() + return port