[svn r123] Fixed infinite redirect problem that Nat discovered.

This commit is contained in:
which.linden
2008-06-09 17:48:55 -07:00
parent cc2c94ab7b
commit ed9e9b61e0
2 changed files with 24 additions and 7 deletions

View File

@@ -589,13 +589,27 @@ class HttpSuite(object):
def head(self, *args, **kwargs):
return self.head_(*args, **kwargs)[-1]
def get_(self, url, headers=None, use_proxy=False, ok=None, aux=None):
def get_(self, url, headers=None, use_proxy=False, ok=None, aux=None, max_retries=8):
if headers is None:
headers = {}
headers['accept'] = self.fallback_content_type+';q=1,*/*;q=0'
return self.request_(_Params(url, 'GET', headers=headers,
loader=self.loader, dumper=self.dumper,
use_proxy=use_proxy, ok=ok, aux=aux))
def req():
return self.request_(_Params(url, 'GET', headers=headers,
loader=self.loader, dumper=self.dumper,
use_proxy=use_proxy, ok=ok, aux=aux))
def retry_response(err):
def doit():
return err.retry_()
return doit
retried = 0
while retried <= max_retries:
try:
return req()
except (Found, TemporaryRedirect, MovedPermanently, SeeOther), e:
if retried >= max_retries:
raise
retried += 1
req = retry_response(e)
def get(self, *args, **kwargs):
return self.get_(*args, **kwargs)[-1]

View File

@@ -263,11 +263,12 @@ class TestHttpc301(TestBase, tests.TestCase):
def test_get(self):
try:
httpc.get(self.base_url() + 'hello')
httpc.get(self.base_url() + 'hello', max_retries=0)
self.assert_(False)
except httpc.MovedPermanently, err:
response = err.retry()
self.assertEquals(response, 'hello world')
self.assertEquals(httpc.get(self.base_url() + 'hello', max_retries=1), 'hello world')
def test_post(self):
data = 'qunge'
@@ -284,19 +285,21 @@ class TestHttpc302(TestBase, tests.TestCase):
def test_get_expired(self):
try:
httpc.get(self.base_url() + 'expired/hello')
httpc.get(self.base_url() + 'expired/hello', max_retries=0)
self.assert_(False)
except httpc.Found, err:
response = err.retry()
self.assertEquals(response, 'hello world')
self.assertEquals(httpc.get(self.base_url() + 'expired/hello', max_retries=1), 'hello world')
def test_get_expires(self):
try:
httpc.get(self.base_url() + 'expires/hello')
httpc.get(self.base_url() + 'expires/hello', max_retries=0)
self.assert_(False)
except httpc.Found, err:
response = err.retry()
self.assertEquals(response, 'hello world')
self.assertEquals(httpc.get(self.base_url() + 'expires/hello', max_retries=1), 'hello world')
class TestHttpc303(TestBase, tests.TestCase):