diff --git a/ironic/api/controllers/v1/port.py b/ironic/api/controllers/v1/port.py index 99db96cb4f..ef04da0736 100644 --- a/ironic/api/controllers/v1/port.py +++ b/ironic/api/controllers/v1/port.py @@ -209,7 +209,7 @@ class PortsController(rest.RestController): """ try: - port = pecan.request.dbapi.get_port(address) + port = objects.Port.get_by_address(pecan.request.context, address) return [port] except exception.PortNotFound: return [] diff --git a/ironic/db/api.py b/ironic/db/api.py index b5c27d332e..060f49781d 100644 --- a/ironic/db/api.py +++ b/ironic/db/api.py @@ -190,10 +190,26 @@ class Connection(object): """ @abc.abstractmethod - def get_port(self, port_id): + def get_port_by_id(self, port_id): """Return a network port representation. - :param port_id: The id or MAC of a port. + :param port_id: The id of a port. + :returns: A port. + """ + + @abc.abstractmethod + def get_port_by_uuid(self, port_uuid): + """Return a network port representation. + + :param port_uuid: The uuid of a port. + :returns: A port. + """ + + @abc.abstractmethod + def get_port_by_address(self, address): + """Return a network port representation. + + :param address: The MAC address of a port. :returns: A port. """ diff --git a/ironic/db/sqlalchemy/api.py b/ironic/db/sqlalchemy/api.py index 2901b1aaf0..1d29353ae7 100644 --- a/ironic/db/sqlalchemy/api.py +++ b/ironic/db/sqlalchemy/api.py @@ -364,17 +364,26 @@ class Connection(api.Connection): ref.update(values) return ref - @objects.objectify(objects.Port) - def get_port(self, port_id): - query = model_query(models.Port) - query = add_port_filter(query, port_id) - + def get_port_by_id(self, port_id): + query = model_query(models.Port).filter_by(id=port_id) try: - result = query.one() + return query.one() except NoResultFound: raise exception.PortNotFound(port=port_id) - return result + def get_port_by_uuid(self, port_uuid): + query = model_query(models.Port).filter_by(uuid=port_uuid) + try: + return query.one() + except NoResultFound: + raise exception.PortNotFound(port=port_uuid) + + def get_port_by_address(self, address): + query = model_query(models.Port).filter_by(address=address) + try: + return query.one() + except NoResultFound: + raise exception.PortNotFound(port=address) @objects.objectify(objects.Port) def get_port_by_vif(self, vif): diff --git a/ironic/objects/port.py b/ironic/objects/port.py index 07c400e2a9..0028836194 100644 --- a/ironic/objects/port.py +++ b/ironic/objects/port.py @@ -13,20 +13,27 @@ # License for the specific language governing permissions and limitations # under the License. +from ironic.common import exception +from ironic.common import utils from ironic.db import api as dbapi from ironic.objects import base -from ironic.objects import utils +from ironic.objects import utils as obj_utils class Port(base.IronicObject): + # Version 1.0: Initial version + # Version 1.1: Add get() and get_by_id() and get_by_address() and + # make get_by_uuid() only work with a uuid + VERSION = '1.1' + dbapi = dbapi.get_instance() fields = { 'id': int, - 'uuid': utils.str_or_none, - 'node_id': utils.int_or_none, - 'address': utils.str_or_none, - 'extra': utils.dict_or_none, + 'uuid': obj_utils.str_or_none, + 'node_id': obj_utils.int_or_none, + 'address': obj_utils.str_or_none, + 'extra': obj_utils.dict_or_none, } @staticmethod @@ -39,23 +46,74 @@ class Port(base.IronicObject): return port @base.remotable_classmethod - def get_by_uuid(cls, context, uuid=None): - """Find a port based on uuid and return a Port object. + def get(cls, context, port_id): + """Find a port based on its id or uuid and return a Port object. - :param uuid: the uuid of a port. + :param port_id: the id *or* uuid of a port. :returns: a :class:`Port` object. """ - db_port = cls.dbapi.get_port(uuid) - return Port._from_db_object(cls(), db_port) + if utils.is_int_like(port_id): + return cls.get_by_id(context, port_id) + elif utils.is_uuid_like(port_id): + return cls.get_by_uuid(context, port_id) + elif utils.is_valid_mac(port_id): + return cls.get_by_address(context, port_id) + else: + raise exception.InvalidIdentity(identity=port_id) + + @base.remotable_classmethod + def get_by_id(cls, context, port_id): + """Find a port based on its integer id and return a Port object. + + :param port_id: the id of a port. + :returns: a :class:`Port` object. + """ + db_port = cls.dbapi.get_port_by_id(port_id) + port = Port._from_db_object(cls(), db_port) + # FIXME(comstud): Setting of the context should be moved to + # _from_db_object(). + port._context = context + return port + + @base.remotable_classmethod + def get_by_uuid(cls, context, uuid): + """Find a port based on uuid and return a :class:`Port` object. + + :param uuid: the uuid of a port. + :param context: Security context + :returns: a :class:`Port` object. + """ + db_port = cls.dbapi.get_port_by_uuid(uuid) + port = Port._from_db_object(cls(), db_port) + # FIXME(comstud): Setting of the context should be moved to + # _from_db_object(). + port._context = context + return port + + @base.remotable_classmethod + def get_by_address(cls, context, address): + """Find a port based on address and return a :class:`Port` object. + + :param address: the address of a port. + :param context: Security context + :returns: a :class:`Port` object. + """ + db_port = cls.dbapi.get_port_by_address(address) + port = Port._from_db_object(cls(), db_port) + # FIXME(comstud): Setting of the context should be moved to + # _from_db_object(). + port._context = context + return port @base.remotable - def save(self, context): + def save(self, context=None): """Save updates to this Port. Updates will be made column by column based on the result of self.what_changed(). - :param context: Security context + :param context: Security context. NOTE: This is only used + internally by the indirection_api. """ updates = self.obj_get_changes() self.dbapi.update_port(self.uuid, updates) @@ -63,14 +121,15 @@ class Port(base.IronicObject): self.obj_reset_changes() @base.remotable - def refresh(self, context): + def refresh(self, context=None): """Loads updates for this Port. Loads a port with the same uuid from the database and checks for updated attributes. Updates are applied from the loaded port column by column, if there are any updates. - :param context: Security context + :param context: Security context. NOTE: This is only used + internally by the indirection_api. """ current = self.__class__.get_by_uuid(context, uuid=self.uuid) for field in self.fields: diff --git a/ironic/tests/api/v1/test_ports.py b/ironic/tests/api/v1/test_ports.py index 87aebefd7a..3b345f9442 100644 --- a/ironic/tests/api/v1/test_ports.py +++ b/ironic/tests/api/v1/test_ports.py @@ -580,7 +580,7 @@ class TestPost(base.FunctionalTest): pdict = post_get_test_port(node_uuid=self.node['uuid']) self.post_json('/ports', pdict) # GET doesn't return the node_id it's an internal value - port = self.dbapi.get_port(pdict['uuid']) + port = self.dbapi.get_port_by_uuid(pdict['uuid']) self.assertEqual(self.node['id'], port.node_id) def test_create_port_node_uuid_not_found(self): diff --git a/ironic/tests/db/test_nodes.py b/ironic/tests/db/test_nodes.py index 632cbff4cf..0f81e62a5a 100644 --- a/ironic/tests/db/test_nodes.py +++ b/ironic/tests/db/test_nodes.py @@ -279,7 +279,8 @@ class DbNodeTestCase(base.DbTestCase): self.dbapi.destroy_node(node_id) - self.assertRaises(exception.PortNotFound, self.dbapi.get_port, p.id) + self.assertRaises(exception.PortNotFound, + self.dbapi.get_port_by_id, p.id) def test_ports_get_destroyed_after_destroying_a_node_by_uuid(self): n = self._create_test_node() @@ -290,7 +291,8 @@ class DbNodeTestCase(base.DbTestCase): self.dbapi.destroy_node(n['uuid']) - self.assertRaises(exception.PortNotFound, self.dbapi.get_port, p.id) + self.assertRaises(exception.PortNotFound, + self.dbapi.get_port_by_id, p.id) def test_update_node(self): n = self._create_test_node() diff --git a/ironic/tests/db/test_ports.py b/ironic/tests/db/test_ports.py index 02a63dc8f2..e1c7a9ebf8 100644 --- a/ironic/tests/db/test_ports.py +++ b/ironic/tests/db/test_ports.py @@ -38,12 +38,17 @@ class DbPortTestCase(base.DbTestCase): def test_get_port_by_id(self): self.dbapi.create_port(self.p) - res = self.dbapi.get_port(self.p['id']) + res = self.dbapi.get_port_by_id(self.p['id']) self.assertEqual(self.p['address'], res.address) def test_get_port_by_uuid(self): self.dbapi.create_port(self.p) - res = self.dbapi.get_port(self.p['uuid']) + res = self.dbapi.get_port_by_uuid(self.p['uuid']) + self.assertEqual(self.p['id'], res.id) + + def test_get_port_by_address(self): + self.dbapi.create_port(self.p) + res = self.dbapi.get_port_by_address(self.p['address']) self.assertEqual(self.p['id'], res.id) def test_get_port_list(self): @@ -57,19 +62,6 @@ class DbPortTestCase(base.DbTestCase): res_uuids = [r.uuid for r in res] self.assertEqual(uuids.sort(), res_uuids.sort()) - def test_get_port_by_address(self): - self.dbapi.create_port(self.p) - - res = self.dbapi.get_port(self.p['address']) - self.assertEqual(self.p['id'], res.id) - - self.assertRaises(exception.PortNotFound, - self.dbapi.get_port, 99) - self.assertRaises(exception.PortNotFound, - self.dbapi.get_port, 'aa:bb:cc:dd:ee:ff') - self.assertRaises(exception.InvalidIdentity, - self.dbapi.get_port, 'not-a-mac') - def test_get_ports_by_node_id(self): p = db_utils.get_test_port(node_id=self.n.id) self.dbapi.create_port(p) diff --git a/ironic/tests/objects/test_port.py b/ironic/tests/objects/test_port.py index d599e5346e..f459f5f027 100644 --- a/ironic/tests/objects/test_port.py +++ b/ironic/tests/objects/test_port.py @@ -15,6 +15,7 @@ import mock +from ironic.common import exception from ironic.db import api as db_api from ironic.db.sqlalchemy import models from ironic import objects @@ -29,19 +30,43 @@ class TestPortObject(base.DbTestCase): self.fake_port = utils.get_test_port() self.dbapi = db_api.get_instance() - def test_load(self): - uuid = self.fake_port['uuid'] - with mock.patch.object(self.dbapi, 'get_port', + def test_get_by_id(self): + port_id = self.fake_port['id'] + with mock.patch.object(self.dbapi, 'get_port_by_id', autospec=True) as mock_get_port: mock_get_port.return_value = self.fake_port - objects.Port.get_by_uuid(self.context, uuid) + objects.Port.get(self.context, port_id) + + mock_get_port.assert_called_once_with(port_id) + + def test_get_by_uuid(self): + uuid = self.fake_port['uuid'] + with mock.patch.object(self.dbapi, 'get_port_by_uuid', + autospec=True) as mock_get_port: + mock_get_port.return_value = self.fake_port + + objects.Port.get(self.context, uuid) mock_get_port.assert_called_once_with(uuid) + def test_get_by_address(self): + address = self.fake_port['address'] + with mock.patch.object(self.dbapi, 'get_port_by_address', + autospec=True) as mock_get_port: + mock_get_port.return_value = self.fake_port + + objects.Port.get(self.context, address) + + mock_get_port.assert_called_once_with(address) + + def test_get_bad_id_and_uuid_and_address(self): + self.assertRaises(exception.InvalidIdentity, + objects.Port.get, self.context, 'not-a-uuid') + def test_save(self): uuid = self.fake_port['uuid'] - with mock.patch.object(self.dbapi, 'get_port', + with mock.patch.object(self.dbapi, 'get_port_by_uuid', autospec=True) as mock_get_port: mock_get_port.return_value = self.fake_port with mock.patch.object(self.dbapi, 'update_port', @@ -59,8 +84,9 @@ class TestPortObject(base.DbTestCase): returns = [self.fake_port, utils.get_test_port(address="c3:54:00:cf:2d:40")] expected = [mock.call(uuid), mock.call(uuid)] - with mock.patch.object(self.dbapi, 'get_port', side_effect=returns, - autospec=True) as mock_get_port: + with mock.patch.object(self.dbapi, 'get_port_by_uuid', + side_effect=returns, autospec=True) \ + as mock_get_port: p = objects.Port.get_by_uuid(self.context, uuid) self.assertEqual("52:54:00:cf:2d:31", p.address) p.refresh()