feat(Request): Add a 'default' kwarg to get_header() (#988)

This commit is contained in:
Kurt Griffiths
2017-01-26 15:50:50 -07:00
committed by John Vrbanac
parent ca79d36c9c
commit 93a315baac
2 changed files with 10 additions and 4 deletions

View File

@@ -778,7 +778,7 @@ class Request(object):
return (preferred_type if preferred_type else None) return (preferred_type if preferred_type else None)
def get_header(self, name, required=False): def get_header(self, name, required=False, default=None):
"""Retrieve the raw string value for the given header. """Retrieve the raw string value for the given header.
Args: Args:
@@ -786,10 +786,13 @@ class Request(object):
required (bool, optional): Set to ``True`` to raise required (bool, optional): Set to ``True`` to raise
``HTTPBadRequest`` instead of returning gracefully when the ``HTTPBadRequest`` instead of returning gracefully when the
header is not found (default ``False``). header is not found (default ``False``).
default (any, optional): Value to return if the header
is not found (default ``None``).
Returns: Returns:
str: The value of the specified header if it exists, or ``None`` if str: The value of the specified header if it exists, or
the header is not found and is not required. the default value if the header is not found and is not
required.
Raises: Raises:
HTTPBadRequest: The header was not found in the request, but HTTPBadRequest: The header was not found in the request, but
@@ -818,7 +821,7 @@ class Request(object):
pass pass
if not required: if not required:
return None return default
raise errors.HTTPMissingHeader(name) raise errors.HTTPMissingHeader(name)

View File

@@ -186,6 +186,9 @@ class TestHeaders(testing.TestCase):
value = req.get_header('X-Not-Found') or '876' value = req.get_header('X-Not-Found') or '876'
self.assertEqual(value, '876') self.assertEqual(value, '876')
value = req.get_header('X-Not-Found', default='some-value')
self.assertEqual(value, 'some-value')
def test_required_header(self): def test_required_header(self):
self.simulate_get() self.simulate_get()