Improve Accept and Content-Type handling
Originally, if WSME received an Accept or Content-Type header that was not aligned with what it was prepared to handle it would error out with a 500 status code. This is not good behavior for a web service. In the process of trying to fix this it was discovered that the content-negotiation code within WSME (the code that, in part, looks for a suitable protocol handler for a request) and tests of that code are incorrect, violating expected HTTP behaviors. GET requests are passing Content-Type headers to declare the desired type of representation in the response. This is what Accept is for. Unfortunately the server-side code was perfectly willing to accept this behavior. These changes correct that. Closes-Bug: 1419110 Change-Id: I2b5c0075611490c047b27b1b43b0505fc5534b3b
This commit is contained in:
		| @@ -2,6 +2,9 @@ import weakref | |||||||
|  |  | ||||||
| import pkg_resources | import pkg_resources | ||||||
|  |  | ||||||
|  | from wsme.exc import ClientSideError | ||||||
|  |  | ||||||
|  |  | ||||||
| __all__ = [ | __all__ = [ | ||||||
|     'CallContext', |     'CallContext', | ||||||
|  |  | ||||||
| @@ -111,3 +114,35 @@ def getprotocol(name, **options): | |||||||
|             raise ValueError("Cannot find protocol '%s'" % name) |             raise ValueError("Cannot find protocol '%s'" % name) | ||||||
|         registered_protocols[name] = protocol_class |         registered_protocols[name] = protocol_class | ||||||
|     return protocol_class(**options) |     return protocol_class(**options) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def media_type_accept(request, content_types): | ||||||
|  |     """Return True if the requested media type is available. | ||||||
|  |  | ||||||
|  |     When request.method is GET or HEAD compare with the Accept header. | ||||||
|  |     When request.method is POST, PUT or PATCH compare with the Content-Type | ||||||
|  |     header. | ||||||
|  |     When request.method is DELETE media type is irrelevant, so return True. | ||||||
|  |     """ | ||||||
|  |     if request.method in ['GET', 'HEAD']: | ||||||
|  |         if request.accept: | ||||||
|  |             if request.accept.best_match(content_types): | ||||||
|  |                 return True | ||||||
|  |             error_message = ('Unacceptable Accept type: %s not in %s' | ||||||
|  |                              % (request.accept, content_types)) | ||||||
|  |             raise ClientSideError(error_message, status_code=406) | ||||||
|  |         return False | ||||||
|  |     elif request.method in ['PUT', 'POST', 'PATCH']: | ||||||
|  |         content_type = request.headers.get('Content-Type') | ||||||
|  |         if content_type: | ||||||
|  |             for ct in content_types: | ||||||
|  |                 if request.headers.get('Content-Type', '').startswith(ct): | ||||||
|  |                     return True | ||||||
|  |             error_message = ('Unacceptable Content-Type: %s not in %s' | ||||||
|  |                              % (content_type, content_types)) | ||||||
|  |             raise ClientSideError(error_message, status_code=415) | ||||||
|  |         else: | ||||||
|  |             raise ClientSideError('missing Content-Type header') | ||||||
|  |     elif request.method in ['DELETE']: | ||||||
|  |         return True | ||||||
|  |     return False | ||||||
|   | |||||||
| @@ -2,7 +2,7 @@ import os.path | |||||||
| import logging | import logging | ||||||
|  |  | ||||||
| from wsme.utils import OrderedDict | from wsme.utils import OrderedDict | ||||||
| from wsme.protocol import CallContext, Protocol | from wsme.protocol import CallContext, Protocol, media_type_accept | ||||||
|  |  | ||||||
| import wsme.rest | import wsme.rest | ||||||
| import wsme.rest.args | import wsme.rest.args | ||||||
| @@ -34,12 +34,7 @@ class RestProtocol(Protocol): | |||||||
|         for dataformat in self.dataformats: |         for dataformat in self.dataformats: | ||||||
|             if request.path.endswith('.' + dataformat): |             if request.path.endswith('.' + dataformat): | ||||||
|                 return True |                 return True | ||||||
|         if request.headers.get('Accept') in self.content_types: |         return media_type_accept(request, self.content_types) | ||||||
|             return True |  | ||||||
|         for ct in self.content_types: |  | ||||||
|             if request.headers['Content-Type'].startswith(ct): |  | ||||||
|                 return True |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     def iter_calls(self, request): |     def iter_calls(self, request): | ||||||
|         context = CallContext(request) |         context = CallContext(request) | ||||||
|   | |||||||
							
								
								
									
										20
									
								
								wsme/root.py
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								wsme/root.py
									
									
									
									
									
								
							| @@ -230,20 +230,34 @@ class WSRoot(object): | |||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             msg = None |             msg = None | ||||||
|  |             error_status = 500 | ||||||
|             protocol = self._select_protocol(request) |             protocol = self._select_protocol(request) | ||||||
|  |             if protocol is None: | ||||||
|  |                 if request.method in ['GET', 'HEAD']: | ||||||
|  |                     error_status = 406 | ||||||
|  |                 elif request.method in ['POST', 'PUT', 'PATCH']: | ||||||
|  |                     error_status = 415 | ||||||
|  |         except ClientSideError as e: | ||||||
|  |             error_status = e.code | ||||||
|  |             msg = e.faultstring | ||||||
|  |             protocol = None | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             msg = ("Error while selecting protocol: %s" % str(e)) |             msg = ("Unexpected error while selecting protocol: %s" % str(e)) | ||||||
|             log.exception(msg) |             log.exception(msg) | ||||||
|             protocol = None |             protocol = None | ||||||
|  |             error_status = 500 | ||||||
|  |  | ||||||
|         if protocol is None: |         if protocol is None: | ||||||
|             if msg is None: |             if msg is None: | ||||||
|                 msg = ("None of the following protocols can handle this " |                 msg = ("None of the following protocols can handle this " | ||||||
|                        "request : %s" % ','.join(( |                        "request : %s" % ','.join(( | ||||||
|                            p.name for p in self.protocols))) |                            p.name for p in self.protocols))) | ||||||
|             res.status = 500 |             res.status = error_status | ||||||
|             res.content_type = 'text/plain' |             res.content_type = 'text/plain' | ||||||
|             res.text = u(msg) |             try: | ||||||
|  |                 res.text = u(msg) | ||||||
|  |             except TypeError: | ||||||
|  |                 res.text = msg | ||||||
|             log.error(msg) |             log.error(msg) | ||||||
|             return res |             return res | ||||||
|  |  | ||||||
|   | |||||||
| @@ -200,7 +200,7 @@ Value should be one of:")) | |||||||
|         app = webtest.TestApp(r.wsgiapp()) |         app = webtest.TestApp(r.wsgiapp()) | ||||||
|  |  | ||||||
|         res = app.get('/', expect_errors=True) |         res = app.get('/', expect_errors=True) | ||||||
|         assert res.status_int == 500 |         assert res.status_int == 406 | ||||||
|         print(res.body) |         print(res.body) | ||||||
|         assert res.body.find( |         assert res.body.find( | ||||||
|             b("None of the following protocols can handle this request")) != -1 |             b("None of the following protocols can handle this request")) != -1 | ||||||
|   | |||||||
| @@ -351,7 +351,7 @@ class TestRestJson(wsme.tests.protocol.RestOnlyProtocolTestCase): | |||||||
|  |  | ||||||
|     def test_GET(self): |     def test_GET(self): | ||||||
|         headers = { |         headers = { | ||||||
|             'Content-Type': 'application/json', |             'Accept': 'application/json', | ||||||
|         } |         } | ||||||
|         res = self.app.get( |         res = self.app.get( | ||||||
|             '/crud?ref.id=1', |             '/crud?ref.id=1', | ||||||
| @@ -362,7 +362,58 @@ class TestRestJson(wsme.tests.protocol.RestOnlyProtocolTestCase): | |||||||
|         print(result) |         print(result) | ||||||
|         assert result['data']['id'] == 1 |         assert result['data']['id'] == 1 | ||||||
|         assert result['data']['name'] == u("test") |         assert result['data']['name'] == u("test") | ||||||
|         assert result['message'] == "read" |  | ||||||
|  |     def test_GET_complex_accept(self): | ||||||
|  |         headers = { | ||||||
|  |             'Accept': 'text/html,application/xml;q=0.9,*/*;q=0.8' | ||||||
|  |         } | ||||||
|  |         res = self.app.get( | ||||||
|  |             '/crud?ref.id=1', | ||||||
|  |             headers=headers, | ||||||
|  |             expect_errors=False) | ||||||
|  |         print("Received:", res.body) | ||||||
|  |         result = json.loads(res.text) | ||||||
|  |         print(result) | ||||||
|  |         assert result['data']['id'] == 1 | ||||||
|  |         assert result['data']['name'] == u("test") | ||||||
|  |  | ||||||
|  |     def test_GET_complex_choose_xml(self): | ||||||
|  |         headers = { | ||||||
|  |             'Accept': 'text/html,text/xml;q=0.9,*/*;q=0.8' | ||||||
|  |         } | ||||||
|  |         res = self.app.get( | ||||||
|  |             '/crud?ref.id=1', | ||||||
|  |             headers=headers, | ||||||
|  |             expect_errors=False) | ||||||
|  |         print("Received:", res.body) | ||||||
|  |         assert res.content_type == 'text/xml' | ||||||
|  |  | ||||||
|  |     def test_GET_complex_accept_no_match(self): | ||||||
|  |         headers = { | ||||||
|  |             'Accept': 'text/html,application/xml;q=0.9' | ||||||
|  |         } | ||||||
|  |         res = self.app.get( | ||||||
|  |             '/crud?ref.id=1', | ||||||
|  |             headers=headers, | ||||||
|  |             status=406) | ||||||
|  |         print("Received:", res.body) | ||||||
|  |         assert res.body == ("Unacceptable Accept type: " | ||||||
|  |                             "text/html, application/xml;q=0.9 not in " | ||||||
|  |                             "['application/json', 'text/javascript', " | ||||||
|  |                             "'application/javascript', 'text/xml']") | ||||||
|  |  | ||||||
|  |     def test_GET_bad_simple_accept(self): | ||||||
|  |         headers = { | ||||||
|  |             'Accept': 'text/plain', | ||||||
|  |         } | ||||||
|  |         res = self.app.get( | ||||||
|  |             '/crud?ref.id=1', | ||||||
|  |             headers=headers, | ||||||
|  |             status=406) | ||||||
|  |         print("Received:", res.body) | ||||||
|  |         assert res.body == ("Unacceptable Accept type: text/plain not in " | ||||||
|  |                             "['application/json', 'text/javascript', " | ||||||
|  |                             "'application/javascript', 'text/xml']") | ||||||
|  |  | ||||||
|     def test_POST(self): |     def test_POST(self): | ||||||
|         headers = { |         headers = { | ||||||
| @@ -380,6 +431,20 @@ class TestRestJson(wsme.tests.protocol.RestOnlyProtocolTestCase): | |||||||
|         assert result['data']['name'] == u("test") |         assert result['data']['name'] == u("test") | ||||||
|         assert result['message'] == "update" |         assert result['message'] == "update" | ||||||
|  |  | ||||||
|  |     def test_POST_bad_content_type(self): | ||||||
|  |         headers = { | ||||||
|  |             'Content-Type': 'text/plain', | ||||||
|  |         } | ||||||
|  |         res = self.app.post( | ||||||
|  |             '/crud', | ||||||
|  |             json.dumps(dict(data=dict(id=1, name=u('test')))), | ||||||
|  |             headers=headers, | ||||||
|  |             status=415) | ||||||
|  |         print("Received:", res.body) | ||||||
|  |         assert res.body == ("Unacceptable Content-Type: text/plain not in " | ||||||
|  |                             "['application/json', 'text/javascript', " | ||||||
|  |                             "'application/javascript', 'text/xml']") | ||||||
|  |  | ||||||
|     def test_DELETE(self): |     def test_DELETE(self): | ||||||
|         res = self.app.delete( |         res = self.app.delete( | ||||||
|             '/crud.json?ref.id=1', |             '/crud.json?ref.id=1', | ||||||
| @@ -393,7 +458,7 @@ class TestRestJson(wsme.tests.protocol.RestOnlyProtocolTestCase): | |||||||
|  |  | ||||||
|     def test_extra_arguments(self): |     def test_extra_arguments(self): | ||||||
|         headers = { |         headers = { | ||||||
|             'Content-Type': 'application/json', |             'Accept': 'application/json', | ||||||
|         } |         } | ||||||
|         res = self.app.get( |         res = self.app.get( | ||||||
|             '/crud?ref.id=1&extraarg=foo', |             '/crud?ref.id=1&extraarg=foo', | ||||||
|   | |||||||
| @@ -38,4 +38,5 @@ class TestRoot(unittest.TestCase): | |||||||
|         res = root._handle_request(req) |         res = root._handle_request(req) | ||||||
|         assert res.status_int == 500 |         assert res.status_int == 500 | ||||||
|         assert res.content_type == 'text/plain' |         assert res.content_type == 'text/plain' | ||||||
|         assert res.text == u('Error while selecting protocol: test'), req.text |         assert (res.text == | ||||||
|  |                 'Unexpected error while selecting protocol: test'), req.text | ||||||
|   | |||||||
| @@ -136,12 +136,26 @@ class TestCRUDController(): | |||||||
|         DBSession.flush() |         DBSession.flush() | ||||||
|         pid = p.id |         pid = p.id | ||||||
|         r = self.app.get('/person?ref.id=%s' % pid, |         r = self.app.get('/person?ref.id=%s' % pid, | ||||||
|                          headers={'Content-Type': 'application/json'}) |                          headers={'Accept': 'application/json'}) | ||||||
|         r = json.loads(r.text) |         r = json.loads(r.text) | ||||||
|         print(r) |         print(r) | ||||||
|         assert r['name'] == u('Pierre-Joseph') |         assert r['name'] == u('Pierre-Joseph') | ||||||
|         assert r['birthdate'] == u('1809-01-15') |         assert r['birthdate'] == u('1809-01-15') | ||||||
|  |  | ||||||
|  |     def test_GET_bad_accept(self): | ||||||
|  |         p = DBPerson( | ||||||
|  |             name=u('Pierre-Joseph'), | ||||||
|  |             birthdate=datetime.date(1809, 1, 15)) | ||||||
|  |         DBSession.add(p) | ||||||
|  |         DBSession.flush() | ||||||
|  |         pid = p.id | ||||||
|  |         r = self.app.get('/person?ref.id=%s' % pid, | ||||||
|  |                          headers={'Accept': 'text/plain'}, | ||||||
|  |                          status=406) | ||||||
|  |         assert r.text == ("Unacceptable Accept type: text/plain not in " | ||||||
|  |                           "['application/json', 'text/javascript', " | ||||||
|  |                           "'application/javascript', 'text/xml']") | ||||||
|  |  | ||||||
|     def test_update(self): |     def test_update(self): | ||||||
|         p = DBPerson( |         p = DBPerson( | ||||||
|             name=u('Pierre-Joseph'), |             name=u('Pierre-Joseph'), | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Chris Dent
					Chris Dent