diff --git a/oauth2client/client.py b/oauth2client/client.py index 0d59dea..1d6cf42 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -975,38 +975,46 @@ def _detect_gce_environment(urlopen=None): return False -def _get_environment(urlopen=None): - """Detect the environment the code is being run on. +def _in_gae_environment(): + """Detects if the code is running in the App Engine environment. + + Returns: + True if running in the GAE environment, False otherwise. + """ + if SETTINGS.env_name is not None: + return SETTINGS.env_name in ('GAE_PRODUCTION', 'GAE_LOCAL') + + try: + import google.appengine + server_software = os.environ.get('SERVER_SOFTWARE', '') + if server_software.startswith('Google App Engine/'): + SETTINGS.env_name = 'GAE_PRODUCTION' + return True + elif server_software.startswith('Development/'): + SETTINGS.env_name = 'GAE_LOCAL' + return True + except ImportError: + pass + + return False + + +def _in_gce_environment(urlopen=None): + """Detect if the code is running in the Compute Engine environment. Args: urlopen: Optional argument. Function used to open a connection to a URL. Returns: - The value of SETTINGS.env_name after being set. If already - set, simply returns the value. + True if running in the GCE environment, False otherwise. """ if SETTINGS.env_name is not None: - return SETTINGS.env_name + return SETTINGS.env_name == 'GCE_PRODUCTION' - # None is an unset value, not the default. - SETTINGS.env_name = DEFAULT_ENV_NAME - - try: - import google.appengine - has_gae_sdk = True - except ImportError: - has_gae_sdk = False - - if has_gae_sdk: - server_software = os.environ.get('SERVER_SOFTWARE', '') - if server_software.startswith('Google App Engine/'): - SETTINGS.env_name = 'GAE_PRODUCTION' - elif server_software.startswith('Development/'): - SETTINGS.env_name = 'GAE_LOCAL' - elif NO_GCE_CHECK != 'True' and _detect_gce_environment(urlopen=urlopen): + if NO_GCE_CHECK != 'True' and _detect_gce_environment(urlopen=urlopen): SETTINGS.env_name = 'GCE_PRODUCTION' - - return SETTINGS.env_name + return True + return False class GoogleCredentials(OAuth2Credentials): @@ -1085,56 +1093,45 @@ class GoogleCredentials(OAuth2Credentials): } @staticmethod - def _implicit_credentials_from_gae(env_name=None): + def _implicit_credentials_from_gae(): """Attempts to get implicit credentials in Google App Engine env. If the current environment is not detected as App Engine, returns None, indicating no Google App Engine credentials can be detected from the current environment. - Args: - env_name: String, indicating current environment. - Returns: None, if not in GAE, else an appengine.AppAssertionCredentials object. """ - env_name = env_name or _get_environment() - if env_name not in ('GAE_PRODUCTION', 'GAE_LOCAL'): + if not _in_gae_environment(): return None return _get_application_default_credential_GAE() @staticmethod - def _implicit_credentials_from_gce(env_name=None): + def _implicit_credentials_from_gce(): """Attempts to get implicit credentials in Google Compute Engine env. If the current environment is not detected as Compute Engine, returns None, indicating no Google Compute Engine credentials can be detected from the current environment. - Args: - env_name: String, indicating current environment. - Returns: None, if not in GCE, else a gce.AppAssertionCredentials object. """ - env_name = env_name or _get_environment() - if env_name != 'GCE_PRODUCTION': + if not _in_gce_environment(): return None return _get_application_default_credential_GCE() @staticmethod - def _implicit_credentials_from_files(env_name=None): + def _implicit_credentials_from_files(): """Attempts to get implicit credentials from local credential files. First checks if the environment variable GOOGLE_APPLICATION_CREDENTIALS is set with a filename and then falls back to a configuration file (the "well known" file) associated with the 'gcloud' command line tool. - Args: - env_name: Unused argument. - Returns: Credentials object associated with the GOOGLE_APPLICATION_CREDENTIALS file or the "well known" file if either exist. If neither file is @@ -1156,6 +1153,10 @@ class GoogleCredentials(OAuth2Credentials): if not credentials_filename: return + # If we can read the credentials from a file, we don't need to know what + # environment we are in. + SETTINGS.env_name = DEFAULT_ENV_NAME + try: return _get_application_default_credential_from_file(credentials_filename) except (ApplicationDefaultCredentialsError, ValueError) as error: @@ -1176,10 +1177,8 @@ class GoogleCredentials(OAuth2Credentials): ApplicationDefaultCredentialsError: raised when the credentials fail to be retrieved. """ - env_name = _get_environment() - # Environ checks (in order). Assumes each checker takes `env_name` - # as a kwarg. + # Environ checks (in order). environ_checkers = [ cls._implicit_credentials_from_gae, cls._implicit_credentials_from_files, @@ -1187,7 +1186,7 @@ class GoogleCredentials(OAuth2Credentials): ] for checker in environ_checkers: - credentials = checker(env_name=env_name) + credentials = checker() if credentials is not None: return credentials diff --git a/tests/test_oauth2client.py b/tests/test_oauth2client.py index 49cbdb4..7887a0f 100644 --- a/tests/test_oauth2client.py +++ b/tests/test_oauth2client.py @@ -27,6 +27,7 @@ import contextlib import datetime import json import os +import socket import sys import time import unittest @@ -65,9 +66,10 @@ from oauth2client.client import TokenRevokeError from oauth2client.client import VerifyJwtTokenError from oauth2client.client import _extract_id_token from oauth2client.client import _get_application_default_credential_from_file -from oauth2client.client import _get_environment from oauth2client.client import _get_environment_variable_file from oauth2client.client import _get_well_known_file +from oauth2client.client import _in_gae_environment +from oauth2client.client import _in_gce_environment from oauth2client.client import _raise_exception_for_missing_fields from oauth2client.client import _raise_exception_for_reading_json from oauth2client.client import _update_query_params @@ -218,32 +220,97 @@ class GoogleCredentialsTests(unittest.TestCase): self.assertEqual(credentials, credentials.create_scoped(['dummy_scope'])) - def test_get_environment_gae_production(self): + def test_environment_check_gae_production(self): with mock_module_import('google.appengine'): os.environ['SERVER_SOFTWARE'] = 'Google App Engine/XYZ' - self.assertEqual('GAE_PRODUCTION', _get_environment()) + self.assertTrue(_in_gae_environment()) + self.assertFalse(_in_gce_environment()) - def test_get_environment_gae_local(self): + def test_environment_check_gae_local(self): with mock_module_import('google.appengine'): os.environ['SERVER_SOFTWARE'] = 'Development/XYZ' - self.assertEqual('GAE_LOCAL', _get_environment()) + self.assertTrue(_in_gae_environment()) + self.assertFalse(_in_gce_environment()) - def test_get_environment_gce_production(self): + def test_environment_check_fastpath(self): + os.environ['SERVER_SOFTWARE'] = 'Development/XYZ' + with mock_module_import('google.appengine'): + with mock.patch.object(urllib.request, 'urlopen', + return_value=MockResponse({}), + autospec=True) as urlopen: + self.assertTrue(_in_gae_environment()) + self.assertFalse(_in_gce_environment()) + # We already know are in GAE, so we shouldn't actually do the urlopen. + self.assertFalse(urlopen.called) + + def test_environment_caching(self): + os.environ['SERVER_SOFTWARE'] = 'Development/XYZ' + with mock_module_import('google.appengine'): + self.assertTrue(_in_gae_environment()) + os.environ['SERVER_SOFTWARE'] = '' + # Even though we no longer pass the environment check, it is cached. + self.assertTrue(_in_gae_environment()) + + def test_environment_check_gae_module_on_gce(self): + with mock_module_import('google.appengine'): + os.environ['SERVER_SOFTWARE'] = '' + response = MockResponse({'Metadata-Flavor': 'Google'}) + with mock.patch.object(urllib.request, 'urlopen', + return_value=response, + autospec=True) as urlopen: + self.assertFalse(_in_gae_environment()) + self.assertTrue(_in_gce_environment()) + urlopen.assert_called_once_with( + 'http://169.254.169.254/', timeout=1) + + def test_environment_check_gae_module_unknown(self): + with mock_module_import('google.appengine'): + os.environ['SERVER_SOFTWARE'] = '' + with mock.patch.object(urllib.request, 'urlopen', + return_value=MockResponse({}), + autospec=True) as urlopen: + self.assertFalse(_in_gae_environment()) + self.assertFalse(_in_gce_environment()) + urlopen.assert_called_once_with( + 'http://169.254.169.254/', timeout=1) + + def test_environment_check_gce_production(self): os.environ['SERVER_SOFTWARE'] = '' response = MockResponse({'Metadata-Flavor': 'Google'}) with mock.patch.object(urllib.request, 'urlopen', return_value=response, autospec=True) as urlopen: - self.assertEqual('GCE_PRODUCTION', _get_environment()) + self.assertFalse(_in_gae_environment()) + self.assertTrue(_in_gce_environment()) urlopen.assert_called_once_with( 'http://169.254.169.254/', timeout=1) - def test_get_environment_unknown(self): + def test_environment_check_gce_timeout(self): + os.environ['SERVER_SOFTWARE'] = '' + response = MockResponse({'Metadata-Flavor': 'Google'}) + with mock.patch.object(urllib.request, 'urlopen', + return_value=response, + autospec=True) as urlopen: + urlopen.side_effect = socket.timeout() + self.assertFalse(_in_gce_environment()) + urlopen.assert_called_once_with( + 'http://169.254.169.254/', timeout=1) + + with mock.patch.object(urllib.request, 'urlopen', + return_value=response, + autospec=True) as urlopen: + urlopen.side_effect = urllib.error.URLError(socket.timeout()) + self.assertFalse(_in_gce_environment()) + urlopen.assert_called_once_with( + 'http://169.254.169.254/', timeout=1) + + def test_environment_check_unknown(self): os.environ['SERVER_SOFTWARE'] = '' with mock.patch.object(urllib.request, 'urlopen', return_value=MockResponse({}), autospec=True) as urlopen: - self.assertEqual(DEFAULT_ENV_NAME, _get_environment()) + self.assertFalse(_in_gce_environment()) + self.assertFalse(_in_gae_environment()) urlopen.assert_called_once_with( 'http://169.254.169.254/', timeout=1)