diff --git a/wsme/controller.py b/wsme/controller.py index 7dc1495..0d4346f 100644 --- a/wsme/controller.py +++ b/wsme/controller.py @@ -223,7 +223,7 @@ class WSRoot(object): def __init__(self, protocols=[], webpath=''): self._debug = True self._webpath = webpath - self.protocols = {} + self.protocols = [] for protocol in protocols: self.addprotocol(protocol) @@ -238,7 +238,7 @@ class WSRoot(object): """ if isinstance(protocol, str): protocol = getprotocol(protocol, **options) - self.protocols[protocol.name] = protocol + self.protocols.append(protocol) protocol.root = weakref.proxy(self) def getapi(self): @@ -251,6 +251,11 @@ class WSRoot(object): self._api = [i for i in scan_api(self)] return self._api + def _get_protocol(self, name): + for protocol in self.protocols: + if protocol.name == name: + return protocol + def _select_protocol(self, request): log.debug("Selecting a protocol for the following request :\n" "headers: %s\nbody: %s", request.headers, @@ -259,10 +264,10 @@ class WSRoot(object): or request.body) protocol = None if 'wsmeproto' in request.params: - protocol = self.protocols[request.params['wsmeproto']] + return self._get_protocol(request.params['wsmeproto']) else: - for p in self.protocols.values(): + for p in self.protocols: if p.accept(request): protocol = p break @@ -275,7 +280,8 @@ class WSRoot(object): protocol = self._select_protocol(request) if protocol is None: msg = ("None of the following protocols can handle this " - "request : %s" % ','.join(self.protocols.keys())) + "request : %s" % ','.join( + (p.name for p in self.protocols))) res.status = 500 res.content_type = 'text/plain' res.body = msg @@ -298,7 +304,10 @@ class WSRoot(object): res.status = 200 if funcdef.protocol_specific and funcdef.return_type is None: - res.body = result + if isinstance(result, unicode): + res.unicode_body = result + else: + res.body = result else: # TODO make sure result type == a._wsme_definition.return_type res.body = protocol.encode_result(funcdef, result) @@ -341,7 +350,7 @@ class WSRoot(object): isprotocol_specific = path[0] == '_protocol' if isprotocol_specific: - a = self.protocols[path[1]] + a = self._get_protocol(path[1]) path = path[2:] for name in path: