diff --git a/os_collect_config/ec2.py b/os_collect_config/ec2.py index b4970c8..ab57458 100644 --- a/os_collect_config/ec2.py +++ b/os_collect_config/ec2.py @@ -29,10 +29,9 @@ opts = [ ] -def _fetch_metadata(fetch_url): - global h +def _fetch_metadata(fetch_url, session): try: - r = requests.get(fetch_url) + r = session.get(fetch_url) r.raise_for_status() except (requests.HTTPError, requests.ConnectionError, @@ -48,11 +47,12 @@ def _fetch_metadata(fetch_url): sub_fetch_url = fetch_url + subkey if subkey[-1] == '/': subkey = subkey[:-1] - new_content[subkey] = _fetch_metadata(sub_fetch_url) + new_content[subkey] = _fetch_metadata(sub_fetch_url, session) content = new_content return content def collect(): root_url = '%s/' % (CONF.ec2.metadata_url) - return _fetch_metadata(root_url) + session = requests.Session() + return _fetch_metadata(root_url, session) diff --git a/os_collect_config/tests/test_collect.py b/os_collect_config/tests/test_collect.py index 2a23f1d..2b032dc 100644 --- a/os_collect_config/tests/test_collect.py +++ b/os_collect_config/tests/test_collect.py @@ -29,7 +29,7 @@ class TestCollect(testtools.TestCase): super(TestCollect, self).setUp() self.useFixture( fixtures.MonkeyPatch( - 'requests.get', test_ec2.fake_get)) + 'requests.Session', test_ec2.FakeSession)) def tearDown(self): super(TestCollect, self).tearDown() diff --git a/os_collect_config/tests/test_ec2.py b/os_collect_config/tests/test_ec2.py index 663e0f4..fc1502f 100644 --- a/os_collect_config/tests/test_ec2.py +++ b/os_collect_config/tests/test_ec2.py @@ -58,22 +58,24 @@ class FakeResponse(dict): pass -def fake_get(url): - url = urlparse.urlparse(url) +class FakeSession(object): + def get(self, url): + url = urlparse.urlparse(url) - if url.path == '/latest/meta-data/': - # Remove keys which have anything after / - ks = [x for x in META_DATA.keys() if ('/' not in x - or not len(x.split('/')[1]))] - return FakeResponse("\n".join(ks)) + if url.path == '/latest/meta-data/': + # Remove keys which have anything after / + ks = [x for x in META_DATA.keys() if ('/' not in x + or not len(x.split('/')[1]))] + return FakeResponse("\n".join(ks)) - path = url.path - path = path.replace('/latest/meta-data/', '') - return FakeResponse(META_DATA[path]) + path = url.path + path = path.replace('/latest/meta-data/', '') + return FakeResponse(META_DATA[path]) -def fake_fail_get(url): - raise requests.exceptions.HTTPError(403, 'Forbidden') +class FakeFailSession(object): + def get(self, url): + raise requests.exceptions.HTTPError(403, 'Forbidden') class TestCollect(testtools.TestCase): @@ -83,7 +85,7 @@ class TestCollect(testtools.TestCase): def test_collect_ec2(self): self.useFixture( - fixtures.MonkeyPatch('requests.get', fake_get)) + fixtures.MonkeyPatch('requests.Session', FakeSession)) collect.setup_conf() ec2_md = ec2.collect() self.assertThat(ec2_md, matchers.IsInstance(dict)) @@ -103,7 +105,7 @@ class TestCollect(testtools.TestCase): def test_collect_ec2_fail(self): self.useFixture( fixtures.MonkeyPatch( - 'requests.get', fake_fail_get)) + 'requests.Session', FakeFailSession)) collect.setup_conf() self.assertRaises(exc.Ec2MetadataNotAvailable, ec2.collect) self.assertIn('Forbidden', self.log.output)