Eagerly load resource_provider from foreignkey
This allows inventory/allocatoin to reference 'resource_provider' field that were eagerly loaded from DB. Change-Id: I2eb1c9e27a7328b96036dd03476650abe0c5393e
This commit is contained in:
		@@ -20,6 +20,7 @@ from oslo_db.sqlalchemy import utils as db_utils
 | 
			
		||||
from oslo_utils import strutils
 | 
			
		||||
from oslo_utils import timeutils
 | 
			
		||||
from oslo_utils import uuidutils
 | 
			
		||||
from sqlalchemy.orm import contains_eager
 | 
			
		||||
from sqlalchemy.orm.exc import MultipleResultsFound
 | 
			
		||||
from sqlalchemy.orm.exc import NoResultFound
 | 
			
		||||
from sqlalchemy.sql import func
 | 
			
		||||
@@ -535,8 +536,11 @@ class Connection(object):
 | 
			
		||||
 | 
			
		||||
    def list_inventories(self, context, filters=None, limit=None,
 | 
			
		||||
                         marker=None, sort_key=None, sort_dir=None):
 | 
			
		||||
        query = model_query(models.Inventory)
 | 
			
		||||
        session = get_session()
 | 
			
		||||
        query = model_query(models.Inventory, session=session)
 | 
			
		||||
        query = self._add_inventories_filters(query, filters)
 | 
			
		||||
        query = query.join(models.Inventory.resource_provider)
 | 
			
		||||
        query = query.options(contains_eager('resource_provider'))
 | 
			
		||||
        return _paginate_query(models.Inventory, limit, marker,
 | 
			
		||||
                               sort_key, sort_dir, query)
 | 
			
		||||
 | 
			
		||||
@@ -551,32 +555,17 @@ class Connection(object):
 | 
			
		||||
            raise exception.UniqueConstraintViolated(fields=fields)
 | 
			
		||||
        return inventory
 | 
			
		||||
 | 
			
		||||
    def get_inventory(self, context, inventory_ident):
 | 
			
		||||
        if strutils.is_int_like(inventory_ident):
 | 
			
		||||
            return self._get_inventory_by_id(context, inventory_ident)
 | 
			
		||||
        else:
 | 
			
		||||
            return self._get_inventory_by_name(context, inventory_ident)
 | 
			
		||||
 | 
			
		||||
    def _get_inventory_by_id(self, context, inventory_id):
 | 
			
		||||
        query = model_query(models.Inventory)
 | 
			
		||||
    def get_inventory(self, context, inventory_id):
 | 
			
		||||
        session = get_session()
 | 
			
		||||
        query = model_query(models.Inventory, session=session)
 | 
			
		||||
        query = query.join(models.Inventory.resource_provider)
 | 
			
		||||
        query = query.options(contains_eager('resource_provider'))
 | 
			
		||||
        query = query.filter_by(id=inventory_id)
 | 
			
		||||
        try:
 | 
			
		||||
            return query.one()
 | 
			
		||||
        except NoResultFound:
 | 
			
		||||
            raise exception.InventoryNotFound(inventory=inventory_id)
 | 
			
		||||
 | 
			
		||||
    def _get_inventory_by_name(self, context, inventory_name):
 | 
			
		||||
        query = model_query(models.Inventory)
 | 
			
		||||
        query = query.filter_by(name=inventory_name)
 | 
			
		||||
        try:
 | 
			
		||||
            return query.one()
 | 
			
		||||
        except NoResultFound:
 | 
			
		||||
            raise exception.InventoryNotFound(inventory=inventory_name)
 | 
			
		||||
        except MultipleResultsFound:
 | 
			
		||||
            raise exception.Conflict('Multiple inventories exist with same '
 | 
			
		||||
                                     'name. Please use the inventory id '
 | 
			
		||||
                                     'instead.')
 | 
			
		||||
 | 
			
		||||
    def destroy_inventory(self, context, inventory_id):
 | 
			
		||||
        session = get_session()
 | 
			
		||||
        with session.begin():
 | 
			
		||||
@@ -613,8 +602,11 @@ class Connection(object):
 | 
			
		||||
 | 
			
		||||
    def list_allocations(self, context, filters=None, limit=None,
 | 
			
		||||
                         marker=None, sort_key=None, sort_dir=None):
 | 
			
		||||
        query = model_query(models.Allocation)
 | 
			
		||||
        session = get_session()
 | 
			
		||||
        query = model_query(models.Allocation, session=session)
 | 
			
		||||
        query = self._add_allocations_filters(query, filters)
 | 
			
		||||
        query = query.join(models.Allocation.resource_provider)
 | 
			
		||||
        query = query.options(contains_eager('resource_provider'))
 | 
			
		||||
        return _paginate_query(models.Allocation, limit, marker,
 | 
			
		||||
                               sort_key, sort_dir, query)
 | 
			
		||||
 | 
			
		||||
@@ -629,7 +621,10 @@ class Connection(object):
 | 
			
		||||
        return allocation
 | 
			
		||||
 | 
			
		||||
    def get_allocation(self, context, allocation_id):
 | 
			
		||||
        query = model_query(models.Allocation)
 | 
			
		||||
        session = get_session()
 | 
			
		||||
        query = model_query(models.Allocation, session=session)
 | 
			
		||||
        query = query.join(models.Allocation.resource_provider)
 | 
			
		||||
        query = query.options(contains_eager('resource_provider'))
 | 
			
		||||
        query = query.filter_by(id=allocation_id)
 | 
			
		||||
        try:
 | 
			
		||||
            return query.one()
 | 
			
		||||
 
 | 
			
		||||
@@ -11,6 +11,7 @@
 | 
			
		||||
#    under the License.
 | 
			
		||||
 | 
			
		||||
from oslo_config import cfg
 | 
			
		||||
from oslo_utils import uuidutils
 | 
			
		||||
 | 
			
		||||
from zun.common import exception
 | 
			
		||||
import zun.conf
 | 
			
		||||
@@ -31,7 +32,10 @@ class DbAllocationTestCase(base.DbTestCase):
 | 
			
		||||
        utils.create_test_allocation(context=self.context)
 | 
			
		||||
 | 
			
		||||
    def test_get_allocation_by_id(self):
 | 
			
		||||
        allocation = utils.create_test_allocation(context=self.context)
 | 
			
		||||
        provider = utils.create_test_resource_provider(
 | 
			
		||||
            context=self.context)
 | 
			
		||||
        allocation = utils.create_test_allocation(
 | 
			
		||||
            resource_provider_id=provider.id, context=self.context)
 | 
			
		||||
        res = dbapi.get_allocation(self.context, allocation.id)
 | 
			
		||||
        self.assertEqual(allocation.id, res.id)
 | 
			
		||||
 | 
			
		||||
@@ -43,29 +47,39 @@ class DbAllocationTestCase(base.DbTestCase):
 | 
			
		||||
                          allocation_id)
 | 
			
		||||
 | 
			
		||||
    def test_list_allocations(self):
 | 
			
		||||
        rcs = []
 | 
			
		||||
        cids = []
 | 
			
		||||
        for i in range(1, 6):
 | 
			
		||||
            provider = utils.create_test_resource_provider(
 | 
			
		||||
                id=i,
 | 
			
		||||
                uuid=uuidutils.generate_uuid(),
 | 
			
		||||
                context=self.context)
 | 
			
		||||
            allocation = utils.create_test_allocation(
 | 
			
		||||
                id=i,
 | 
			
		||||
                resource_class_id=i,
 | 
			
		||||
                resource_provider_id=provider.id,
 | 
			
		||||
                consumer_id=uuidutils.generate_uuid(),
 | 
			
		||||
                context=self.context)
 | 
			
		||||
            rcs.append(allocation['resource_class_id'])
 | 
			
		||||
            cids.append(allocation['consumer_id'])
 | 
			
		||||
        res = dbapi.list_allocations(self.context)
 | 
			
		||||
        res_rcs = [r.resource_class_id for r in res]
 | 
			
		||||
        self.assertEqual(sorted(rcs), sorted(res_rcs))
 | 
			
		||||
        res_cids = [r.consumer_id for r in res]
 | 
			
		||||
        self.assertEqual(sorted(cids), sorted(res_cids))
 | 
			
		||||
 | 
			
		||||
    def test_list_allocations_sorted(self):
 | 
			
		||||
        rcs = []
 | 
			
		||||
        cids = []
 | 
			
		||||
        for i in range(5):
 | 
			
		||||
            provider = utils.create_test_resource_provider(
 | 
			
		||||
                id=i,
 | 
			
		||||
                uuid=uuidutils.generate_uuid(),
 | 
			
		||||
                context=self.context)
 | 
			
		||||
            allocation = utils.create_test_allocation(
 | 
			
		||||
                id=i,
 | 
			
		||||
                resource_class_id=i,
 | 
			
		||||
                resource_provider_id=provider.id,
 | 
			
		||||
                consumer_id=uuidutils.generate_uuid(),
 | 
			
		||||
                context=self.context)
 | 
			
		||||
            rcs.append(allocation.resource_class_id)
 | 
			
		||||
            cids.append(allocation['consumer_id'])
 | 
			
		||||
        res = dbapi.list_allocations(self.context,
 | 
			
		||||
                                     sort_key='resource_class_id')
 | 
			
		||||
        res_rcs = [r.resource_class_id for r in res]
 | 
			
		||||
        self.assertEqual(sorted(rcs), res_rcs)
 | 
			
		||||
                                     sort_key='consumer_id')
 | 
			
		||||
        res_cids = [r.consumer_id for r in res]
 | 
			
		||||
        self.assertEqual(sorted(cids), res_cids)
 | 
			
		||||
 | 
			
		||||
        self.assertRaises(exception.InvalidParameterValue,
 | 
			
		||||
                          dbapi.list_allocations,
 | 
			
		||||
@@ -73,13 +87,17 @@ class DbAllocationTestCase(base.DbTestCase):
 | 
			
		||||
                          sort_key='foo')
 | 
			
		||||
 | 
			
		||||
    def test_list_allocations_with_filters(self):
 | 
			
		||||
        provider = utils.create_test_resource_provider(
 | 
			
		||||
            id=1,
 | 
			
		||||
            uuid=uuidutils.generate_uuid(),
 | 
			
		||||
            context=self.context)
 | 
			
		||||
        allocation1 = utils.create_test_allocation(
 | 
			
		||||
            used=0,
 | 
			
		||||
            resource_class_id=1,
 | 
			
		||||
            resource_provider_id=provider.id,
 | 
			
		||||
            context=self.context)
 | 
			
		||||
        allocation2 = utils.create_test_allocation(
 | 
			
		||||
            used=1,
 | 
			
		||||
            resource_class_id=2,
 | 
			
		||||
            resource_provider_id=provider.id,
 | 
			
		||||
            context=self.context)
 | 
			
		||||
 | 
			
		||||
        res = dbapi.list_allocations(
 | 
			
		||||
 
 | 
			
		||||
@@ -11,6 +11,7 @@
 | 
			
		||||
#    under the License.
 | 
			
		||||
 | 
			
		||||
from oslo_config import cfg
 | 
			
		||||
from oslo_utils import uuidutils
 | 
			
		||||
 | 
			
		||||
from zun.common import exception
 | 
			
		||||
import zun.conf
 | 
			
		||||
@@ -41,7 +42,10 @@ class DbInventoryTestCase(base.DbTestCase):
 | 
			
		||||
                resource_class_id=1)
 | 
			
		||||
 | 
			
		||||
    def test_get_inventory_by_id(self):
 | 
			
		||||
        inventory = utils.create_test_inventory(context=self.context)
 | 
			
		||||
        provider = utils.create_test_resource_provider(
 | 
			
		||||
            context=self.context)
 | 
			
		||||
        inventory = utils.create_test_inventory(
 | 
			
		||||
            resource_provider_id=provider.id, context=self.context)
 | 
			
		||||
        res = dbapi.get_inventory(self.context, inventory.id)
 | 
			
		||||
        self.assertEqual(inventory.id, res.id)
 | 
			
		||||
 | 
			
		||||
@@ -53,29 +57,39 @@ class DbInventoryTestCase(base.DbTestCase):
 | 
			
		||||
                          inventory_id)
 | 
			
		||||
 | 
			
		||||
    def test_list_inventories(self):
 | 
			
		||||
        rcs = []
 | 
			
		||||
        totals = []
 | 
			
		||||
        for i in range(1, 6):
 | 
			
		||||
            provider = utils.create_test_resource_provider(
 | 
			
		||||
                id=i,
 | 
			
		||||
                uuid=uuidutils.generate_uuid(),
 | 
			
		||||
                context=self.context)
 | 
			
		||||
            inventory = utils.create_test_inventory(
 | 
			
		||||
                id=i,
 | 
			
		||||
                resource_class_id=i,
 | 
			
		||||
                resource_provider_id=provider.id,
 | 
			
		||||
                total=i,
 | 
			
		||||
                context=self.context)
 | 
			
		||||
            rcs.append(inventory['resource_class_id'])
 | 
			
		||||
            totals.append(inventory['total'])
 | 
			
		||||
        res = dbapi.list_inventories(self.context)
 | 
			
		||||
        res_rcs = [r.resource_class_id for r in res]
 | 
			
		||||
        self.assertEqual(sorted(rcs), sorted(res_rcs))
 | 
			
		||||
        res_totals = [r.total for r in res]
 | 
			
		||||
        self.assertEqual(sorted(totals), sorted(res_totals))
 | 
			
		||||
 | 
			
		||||
    def test_list_inventories_sorted(self):
 | 
			
		||||
        rcs = []
 | 
			
		||||
        totals = []
 | 
			
		||||
        for i in range(5):
 | 
			
		||||
            provider = utils.create_test_resource_provider(
 | 
			
		||||
                id=i,
 | 
			
		||||
                uuid=uuidutils.generate_uuid(),
 | 
			
		||||
                context=self.context)
 | 
			
		||||
            inventory = utils.create_test_inventory(
 | 
			
		||||
                id=i,
 | 
			
		||||
                resource_class_id=i,
 | 
			
		||||
                resource_provider_id=provider.id,
 | 
			
		||||
                total=10 - i,
 | 
			
		||||
                context=self.context)
 | 
			
		||||
            rcs.append(inventory.resource_class_id)
 | 
			
		||||
            totals.append(inventory['total'])
 | 
			
		||||
        res = dbapi.list_inventories(self.context,
 | 
			
		||||
                                     sort_key='resource_class_id')
 | 
			
		||||
        res_rcs = [r.resource_class_id for r in res]
 | 
			
		||||
        self.assertEqual(sorted(rcs), res_rcs)
 | 
			
		||||
                                     sort_key='total')
 | 
			
		||||
        res_totals = [r.total for r in res]
 | 
			
		||||
        self.assertEqual(sorted(totals), res_totals)
 | 
			
		||||
 | 
			
		||||
        self.assertRaises(exception.InvalidParameterValue,
 | 
			
		||||
                          dbapi.list_inventories,
 | 
			
		||||
@@ -83,12 +97,16 @@ class DbInventoryTestCase(base.DbTestCase):
 | 
			
		||||
                          sort_key='foo')
 | 
			
		||||
 | 
			
		||||
    def test_list_inventories_with_filters(self):
 | 
			
		||||
        provider = utils.create_test_resource_provider(
 | 
			
		||||
            context=self.context)
 | 
			
		||||
        inventory1 = utils.create_test_inventory(
 | 
			
		||||
            total=10,
 | 
			
		||||
            resource_provider_id=provider.id,
 | 
			
		||||
            resource_class_id=1,
 | 
			
		||||
            context=self.context)
 | 
			
		||||
        inventory2 = utils.create_test_inventory(
 | 
			
		||||
            total=20,
 | 
			
		||||
            resource_provider_id=provider.id,
 | 
			
		||||
            resource_class_id=2,
 | 
			
		||||
            context=self.context)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user