diff --git a/tricircle/db/api.py b/tricircle/db/api.py index 3d156ce0..66e198f0 100644 --- a/tricircle/db/api.py +++ b/tricircle/db/api.py @@ -311,7 +311,8 @@ def find_pod_by_az_or_region(context, az_or_region): # if the pods list contains more than one pod, then we will raise an # exception if len(pods) > 1: - raise Exception('Multiple pods with the same az_name are found') + raise exceptions.InvalidInput( + reason='Multiple pods with the same az_name are found') def new_job(context, _type, resource_id): diff --git a/tricircle/tests/unit/db/test_api.py b/tricircle/tests/unit/db/test_api.py index 7b55eeb3..62b8d168 100644 --- a/tricircle/tests/unit/db/test_api.py +++ b/tricircle/tests/unit/db/test_api.py @@ -17,6 +17,7 @@ from six.moves import xrange import unittest from tricircle.common import context +from tricircle.common import exceptions from tricircle.db import api from tricircle.db import core @@ -29,6 +30,15 @@ class APITest(unittest.TestCase): core.ModelBase.metadata.create_all(core.get_engine()) self.context = context.Context() + def _create_pod(self, index, test_az_uuid): + pod_body = {'pod_id': 'test_pod_uuid_%d' % index, + 'region_name': 'test_pod_%d' % index, + 'pod_az_name': 'test_pod_az_name_%d' % index, + 'dc_name': 'test_dc_name_%d' % index, + 'az_name': test_az_uuid, + } + api.create_pod(self.context, pod_body) + def test_get_bottom_mappings_by_top_id(self): for i in xrange(3): pod = {'pod_id': 'test_pod_uuid_%d' % i, @@ -180,5 +190,32 @@ class APITest(unittest.TestCase): self.context, current_pod_id='test_pod_uuid_4') self.assertIsNone(next_pod) + def test_find_pod_by_az_or_region(self): + self._create_pod(0, 'test_az_uuid1') + self._create_pod(1, 'test_az_uuid1') + self._create_pod(2, 'test_az_uuid2') + + az_region = None + pod = api.find_pod_by_az_or_region(self.context, az_region) + self.assertIsNone(pod) + + az_region = 'test_pod_3' + self.assertRaises(exceptions.PodNotFound, + api.find_pod_by_az_or_region, + self.context, az_region) + + az_region = 'test_pod_0' + pod = api.find_pod_by_az_or_region(self.context, az_region) + self.assertEqual(pod['region_name'], az_region) + + az_region = 'test_az_uuid2' + pod = api.find_pod_by_az_or_region(self.context, az_region) + self.assertEqual(pod['az_name'], az_region) + + az_region = 'test_az_uuid1' + self.assertRaises(exceptions.InvalidInput, + api.find_pod_by_az_or_region, + self.context, az_region) + def tearDown(self): core.ModelBase.metadata.drop_all(core.get_engine())