diff --git a/falcon/api.py b/falcon/api.py index 00f8295..637c9ed 100644 --- a/falcon/api.py +++ b/falcon/api.py @@ -87,7 +87,7 @@ class API(object): if ex.headers is not None: resp.set_headers(ex.headers) - if req.client_accepts_json: + if req.client_accepts('application/json'): resp.body = ex.json() except TypeError as ex: diff --git a/falcon/request.py b/falcon/request.py index ed01bc0..836dbbe 100644 --- a/falcon/request.py +++ b/falcon/request.py @@ -138,18 +138,19 @@ class Request(object): @property def client_accepts_json(self): """Return True if the Accept header indicates JSON support.""" - - accept = self._get_header_by_wsgi_name('ACCEPT') - return ((accept is not None) and - (('application/json' in accept) or ('*/*' in accept))) + return self.client_accepts('application/json') @property def client_accepts_xml(self): """Return True if the Accept header indicates XML support.""" + return self.client_accepts('application/xml') + + def client_accepts(self, media_type): + """Return True if the Accept header indicates a media type support.""" accept = self._get_header_by_wsgi_name('ACCEPT') return ((accept is not None) and - (('application/xml' in accept) or ('*/*' in accept))) + ((media_type in accept) or ('*/*' in accept))) @property def accept(self): diff --git a/falcon/tests/test_req_vars.py b/falcon/tests/test_req_vars.py index 6d5bc41..7892bf0 100644 --- a/falcon/tests/test_req_vars.py +++ b/falcon/tests/test_req_vars.py @@ -83,22 +83,29 @@ class TestReqVars(testing.TestBase): self.assertEqual(req_noapp.relative_uri, self.relative_uri) - def test_accept_xml(self): + def test_client_accepts(self): headers = {'Accept': 'application/xml'} req = Request(testing.create_environ(headers=headers)) - self.assertTrue(req.client_accepts_xml) + self.assertTrue(req.client_accepts('application/xml')) headers = {'Accept': '*/*'} req = Request(testing.create_environ(headers=headers)) - self.assertTrue(req.client_accepts_xml) + self.assertTrue(req.client_accepts('application/xml')) headers = {'Accept': 'application/json'} req = Request(testing.create_environ(headers=headers)) - self.assertFalse(req.client_accepts_xml) + self.assertFalse(req.client_accepts('application/xml')) headers = {'Accept': 'application/xm'} req = Request(testing.create_environ(headers=headers)) - self.assertFalse(req.client_accepts_xml) + self.assertFalse(req.client_accepts('application/xml')) + + def test_client_accepts_props(self): + headers = {'Accept': 'application/xml'} + req = Request(testing.create_environ(headers=headers)) + + self.assertTrue(req.client_accepts_xml) + self.assertFalse(req.client_accepts_json) def test_range(self): headers = {'Range': '10-'}