From 02e50aa24e2fcef2ddc7d06a8f8522cc433be641 Mon Sep 17 00:00:00 2001 From: Hongbin Lu Date: Wed, 15 Feb 2017 20:55:57 -0600 Subject: [PATCH] Eagerly load resource_provider from foreignkey This allows inventory/allocatoin to reference 'resource_provider' field that were eagerly loaded from DB. Change-Id: I2eb1c9e27a7328b96036dd03476650abe0c5393e --- zun/db/sqlalchemy/api.py | 41 +++++++++++-------------- zun/tests/unit/db/test_allocation.py | 46 +++++++++++++++++++--------- zun/tests/unit/db/test_inventory.py | 42 +++++++++++++++++-------- 3 files changed, 80 insertions(+), 49 deletions(-) diff --git a/zun/db/sqlalchemy/api.py b/zun/db/sqlalchemy/api.py index 5f5444701..a96efd5a7 100644 --- a/zun/db/sqlalchemy/api.py +++ b/zun/db/sqlalchemy/api.py @@ -20,6 +20,7 @@ from oslo_db.sqlalchemy import utils as db_utils from oslo_utils import strutils from oslo_utils import timeutils from oslo_utils import uuidutils +from sqlalchemy.orm import contains_eager from sqlalchemy.orm.exc import MultipleResultsFound from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.sql import func @@ -535,8 +536,11 @@ class Connection(object): def list_inventories(self, context, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): - query = model_query(models.Inventory) + session = get_session() + query = model_query(models.Inventory, session=session) query = self._add_inventories_filters(query, filters) + query = query.join(models.Inventory.resource_provider) + query = query.options(contains_eager('resource_provider')) return _paginate_query(models.Inventory, limit, marker, sort_key, sort_dir, query) @@ -551,32 +555,17 @@ class Connection(object): raise exception.UniqueConstraintViolated(fields=fields) return inventory - def get_inventory(self, context, inventory_ident): - if strutils.is_int_like(inventory_ident): - return self._get_inventory_by_id(context, inventory_ident) - else: - return self._get_inventory_by_name(context, inventory_ident) - - def _get_inventory_by_id(self, context, inventory_id): - query = model_query(models.Inventory) + def get_inventory(self, context, inventory_id): + session = get_session() + query = model_query(models.Inventory, session=session) + query = query.join(models.Inventory.resource_provider) + query = query.options(contains_eager('resource_provider')) query = query.filter_by(id=inventory_id) try: return query.one() except NoResultFound: raise exception.InventoryNotFound(inventory=inventory_id) - def _get_inventory_by_name(self, context, inventory_name): - query = model_query(models.Inventory) - query = query.filter_by(name=inventory_name) - try: - return query.one() - except NoResultFound: - raise exception.InventoryNotFound(inventory=inventory_name) - except MultipleResultsFound: - raise exception.Conflict('Multiple inventories exist with same ' - 'name. Please use the inventory id ' - 'instead.') - def destroy_inventory(self, context, inventory_id): session = get_session() with session.begin(): @@ -613,8 +602,11 @@ class Connection(object): def list_allocations(self, context, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): - query = model_query(models.Allocation) + session = get_session() + query = model_query(models.Allocation, session=session) query = self._add_allocations_filters(query, filters) + query = query.join(models.Allocation.resource_provider) + query = query.options(contains_eager('resource_provider')) return _paginate_query(models.Allocation, limit, marker, sort_key, sort_dir, query) @@ -629,7 +621,10 @@ class Connection(object): return allocation def get_allocation(self, context, allocation_id): - query = model_query(models.Allocation) + session = get_session() + query = model_query(models.Allocation, session=session) + query = query.join(models.Allocation.resource_provider) + query = query.options(contains_eager('resource_provider')) query = query.filter_by(id=allocation_id) try: return query.one() diff --git a/zun/tests/unit/db/test_allocation.py b/zun/tests/unit/db/test_allocation.py index 9095c99d6..31792fa26 100644 --- a/zun/tests/unit/db/test_allocation.py +++ b/zun/tests/unit/db/test_allocation.py @@ -11,6 +11,7 @@ # under the License. from oslo_config import cfg +from oslo_utils import uuidutils from zun.common import exception import zun.conf @@ -31,7 +32,10 @@ class DbAllocationTestCase(base.DbTestCase): utils.create_test_allocation(context=self.context) def test_get_allocation_by_id(self): - allocation = utils.create_test_allocation(context=self.context) + provider = utils.create_test_resource_provider( + context=self.context) + allocation = utils.create_test_allocation( + resource_provider_id=provider.id, context=self.context) res = dbapi.get_allocation(self.context, allocation.id) self.assertEqual(allocation.id, res.id) @@ -43,29 +47,39 @@ class DbAllocationTestCase(base.DbTestCase): allocation_id) def test_list_allocations(self): - rcs = [] + cids = [] for i in range(1, 6): + provider = utils.create_test_resource_provider( + id=i, + uuid=uuidutils.generate_uuid(), + context=self.context) allocation = utils.create_test_allocation( id=i, - resource_class_id=i, + resource_provider_id=provider.id, + consumer_id=uuidutils.generate_uuid(), context=self.context) - rcs.append(allocation['resource_class_id']) + cids.append(allocation['consumer_id']) res = dbapi.list_allocations(self.context) - res_rcs = [r.resource_class_id for r in res] - self.assertEqual(sorted(rcs), sorted(res_rcs)) + res_cids = [r.consumer_id for r in res] + self.assertEqual(sorted(cids), sorted(res_cids)) def test_list_allocations_sorted(self): - rcs = [] + cids = [] for i in range(5): + provider = utils.create_test_resource_provider( + id=i, + uuid=uuidutils.generate_uuid(), + context=self.context) allocation = utils.create_test_allocation( id=i, - resource_class_id=i, + resource_provider_id=provider.id, + consumer_id=uuidutils.generate_uuid(), context=self.context) - rcs.append(allocation.resource_class_id) + cids.append(allocation['consumer_id']) res = dbapi.list_allocations(self.context, - sort_key='resource_class_id') - res_rcs = [r.resource_class_id for r in res] - self.assertEqual(sorted(rcs), res_rcs) + sort_key='consumer_id') + res_cids = [r.consumer_id for r in res] + self.assertEqual(sorted(cids), res_cids) self.assertRaises(exception.InvalidParameterValue, dbapi.list_allocations, @@ -73,13 +87,17 @@ class DbAllocationTestCase(base.DbTestCase): sort_key='foo') def test_list_allocations_with_filters(self): + provider = utils.create_test_resource_provider( + id=1, + uuid=uuidutils.generate_uuid(), + context=self.context) allocation1 = utils.create_test_allocation( used=0, - resource_class_id=1, + resource_provider_id=provider.id, context=self.context) allocation2 = utils.create_test_allocation( used=1, - resource_class_id=2, + resource_provider_id=provider.id, context=self.context) res = dbapi.list_allocations( diff --git a/zun/tests/unit/db/test_inventory.py b/zun/tests/unit/db/test_inventory.py index 054124170..9917057cb 100644 --- a/zun/tests/unit/db/test_inventory.py +++ b/zun/tests/unit/db/test_inventory.py @@ -11,6 +11,7 @@ # under the License. from oslo_config import cfg +from oslo_utils import uuidutils from zun.common import exception import zun.conf @@ -41,7 +42,10 @@ class DbInventoryTestCase(base.DbTestCase): resource_class_id=1) def test_get_inventory_by_id(self): - inventory = utils.create_test_inventory(context=self.context) + provider = utils.create_test_resource_provider( + context=self.context) + inventory = utils.create_test_inventory( + resource_provider_id=provider.id, context=self.context) res = dbapi.get_inventory(self.context, inventory.id) self.assertEqual(inventory.id, res.id) @@ -53,29 +57,39 @@ class DbInventoryTestCase(base.DbTestCase): inventory_id) def test_list_inventories(self): - rcs = [] + totals = [] for i in range(1, 6): + provider = utils.create_test_resource_provider( + id=i, + uuid=uuidutils.generate_uuid(), + context=self.context) inventory = utils.create_test_inventory( id=i, - resource_class_id=i, + resource_provider_id=provider.id, + total=i, context=self.context) - rcs.append(inventory['resource_class_id']) + totals.append(inventory['total']) res = dbapi.list_inventories(self.context) - res_rcs = [r.resource_class_id for r in res] - self.assertEqual(sorted(rcs), sorted(res_rcs)) + res_totals = [r.total for r in res] + self.assertEqual(sorted(totals), sorted(res_totals)) def test_list_inventories_sorted(self): - rcs = [] + totals = [] for i in range(5): + provider = utils.create_test_resource_provider( + id=i, + uuid=uuidutils.generate_uuid(), + context=self.context) inventory = utils.create_test_inventory( id=i, - resource_class_id=i, + resource_provider_id=provider.id, + total=10 - i, context=self.context) - rcs.append(inventory.resource_class_id) + totals.append(inventory['total']) res = dbapi.list_inventories(self.context, - sort_key='resource_class_id') - res_rcs = [r.resource_class_id for r in res] - self.assertEqual(sorted(rcs), res_rcs) + sort_key='total') + res_totals = [r.total for r in res] + self.assertEqual(sorted(totals), res_totals) self.assertRaises(exception.InvalidParameterValue, dbapi.list_inventories, @@ -83,12 +97,16 @@ class DbInventoryTestCase(base.DbTestCase): sort_key='foo') def test_list_inventories_with_filters(self): + provider = utils.create_test_resource_provider( + context=self.context) inventory1 = utils.create_test_inventory( total=10, + resource_provider_id=provider.id, resource_class_id=1, context=self.context) inventory2 = utils.create_test_inventory( total=20, + resource_provider_id=provider.id, resource_class_id=2, context=self.context)