import logging
import sys
import traceback
import weakref

import webob

from wsme import exc
from wsme.protocols import getprotocol
from wsme.api import scan_api

log = logging.getLogger(__name__)

# HTML page used by WSRoot._html_format to present a response body to a
# web browser.  ``css`` receives the stylesheet produced by the pygments
# formatter ('' when highlighting is unavailable) and ``content`` the
# already highlighted or HTML-escaped body.
html_body = """
<html>
<head>
  <style type='text/css'>
    %(css)s
  </style>
</head>
<body>
%(content)s
</body>
</html>
"""


def _html_escape(text):
    """Escape ``&``, ``<`` and ``>`` so *text* renders literally in HTML.

    ``&`` is replaced first so the entities produced for ``<`` and ``>``
    are not themselves escaped a second time.
    """
    return (text.replace('&', '&amp;')
                .replace('<', '&lt;')
                .replace('>', '&gt;'))


class WSRoot(object):
    """
    Root controller for webservices.

    :param protocols: A list of protocols to enable (see :meth:`addprotocol`)

    :param webpath: The web path where the webservice is published.
    """

    def __init__(self, protocols=None, webpath=''):
        # When true, _format_exception includes the server-side traceback
        # in the error payload sent to the client.
        # NOTE(review): enabled by default, which exposes tracebacks to
        # clients -- consider disabling it for production deployments.
        self._debug = True
        self._webpath = webpath
        self.protocols = []
        # ``None`` default instead of a shared mutable ``[]``.
        for protocol in protocols or ():
            self.addprotocol(protocol)
        self._api = None

    def addprotocol(self, protocol, **options):
        """
        Enable a new protocol on the controller.

        :param protocol: A registered protocol name or an instance
                         of a protocol.
        :param options: Forwarded to :func:`getprotocol` when *protocol*
                        is a name; ignored for protocol instances.
        """
        if isinstance(protocol, str):
            protocol = getprotocol(protocol, **options)
        self.protocols.append(protocol)
        # A weak proxy so the protocol does not keep the root alive
        # through a strong root <-> protocol reference cycle.
        protocol.root = weakref.proxy(self)

    def getapi(self):
        """
        Returns the api description.

        The result of :func:`scan_api` is computed once and cached.

        :rtype: list of (path, :class:`FunctionDefinition`)
        """
        if self._api is None:
            self._api = list(scan_api(self))
        return self._api

    def _get_protocol(self, name):
        """Return the enabled protocol whose ``name`` matches, or None."""
        for protocol in self.protocols:
            if protocol.name == name:
                return protocol
        return None

    def _select_protocol(self, request):
        """Pick the protocol that should handle *request*.

        An explicit ``wsmeproto`` request parameter wins; otherwise the
        first enabled protocol whose ``accept(request)`` is true is used.
        Returns None when no protocol matches.
        """
        log.debug("Selecting a protocol for the following request :\n"
                  "headers: %s\nbody: %s",
                  request.headers, request.body[:512])
        if 'wsmeproto' in request.params:
            return self._get_protocol(request.params['wsmeproto'])
        for p in self.protocols:
            if p.accept(request):
                return p
        return None

    def _do_call(self, protocol, context):
        """Execute a single call described by *context*.

        The context is appended to ``context.request.calls``.  Any
        exception is converted to an encoded error payload, and the
        request's ``client_errorcount`` (for :class:`exc.ClientSideError`)
        or ``server_errorcount`` (anything else) is incremented.

        :returns: the encoded result, the raw result for protocol-specific
                  functions without a return type, or the encoded error.
        """
        request = context.request
        request.calls.append(context)
        try:
            if context.path is None:
                context.path = protocol.extract_path(context)

            if context.path is None:
                raise exc.ClientSideError(
                    u'The %s protocol was unable to extract a function '
                    u'path from the request' % protocol.name)

            context.func, context.funcdef = self._lookup_function(
                context.path)
            kw = protocol.read_arguments(context)

            for arg in context.funcdef.arguments:
                if arg.mandatory and arg.name not in kw:
                    raise exc.MissingArgument(arg.name)

            result = context.func(**kw)

            if context.funcdef.protocol_specific \
                    and context.funcdef.return_type is None:
                return result
            # TODO make sure result type == a._wsme_definition.return_type
            return protocol.encode_result(context, result)

        except Exception as e:
            infos = self._format_exception(sys.exc_info())
            if isinstance(e, exc.ClientSideError):
                request.client_errorcount += 1
            else:
                request.server_errorcount += 1
            return protocol.encode_error(context, infos)

    def _handle_request(self, request):
        """Dispatch *request* to a protocol and build the webob response.

        Status is 500 when no protocol matches or an unexpected error
        occurs.  For a single call it comes from the protocol's
        ``get_response_status`` if defined, otherwise 400/500/200 depending
        on the client/server error counters.
        """
        def default_prepare_response_body(request, results):
            return '\n'.join(results)

        res = webob.Response()
        res_content_type = None
        # Bound before the ``try`` so the error handler can always refer
        # to them, even when protocol selection itself fails.
        protocol = None
        context = None
        try:
            protocol = self._select_protocol(request)
            if protocol is None:
                msg = ("None of the following protocols can handle this "
                       "request : %s" % ','.join(
                           p.name for p in self.protocols))
                res.status = 500
                res.content_type = 'text/plain'
                res.body = msg
                log.error(msg)
                return res

            request.calls = []
            request.client_errorcount = 0
            request.server_errorcount = 0

            prepare_response_body = getattr(
                protocol, 'prepare_response_body',
                default_prepare_response_body)

            body = prepare_response_body(request, (
                self._do_call(protocol, call_context)
                for call_context in protocol.iter_calls(request)))
            self._set_body(res, body)

            if len(request.calls) == 1:
                if hasattr(protocol, 'get_response_status'):
                    res.status = protocol.get_response_status(request)
                elif request.client_errorcount:
                    res.status = 400
                elif request.server_errorcount:
                    res.status = 500
                else:
                    res.status = 200
                funcdef = getattr(request.calls[0], 'funcdef', None)
                if funcdef:
                    res_content_type = funcdef.contenttype
            else:
                res.status = protocol.get_response_status(request)
                res_content_type = protocol.get_response_contenttype(request)
        except Exception:
            infos = self._format_exception(sys.exc_info())
            # The counter may not exist yet if the failure happened early.
            request.server_errorcount = getattr(
                request, 'server_errorcount', 0) + 1
            res.status = 500
            if protocol is None:
                # Protocol selection failed: there is no protocol able to
                # encode the error, so answer in plain text.
                faultstring = infos['faultstring']
                if isinstance(faultstring, unicode):
                    faultstring = faultstring.encode('utf8')
                res.headers['Content-Type'] = 'text/plain; charset=UTF-8'
                res.body = faultstring
                return res
            self._set_body(res, protocol.encode_error(context, infos))

        if res_content_type is None:
            res_content_type = self._guess_content_type(
                request, res, protocol)

        # TODO should we consider the encoding asked by
        # the web browser ?
        res.headers['Content-Type'] = "%s; charset=UTF-8" % res_content_type
        return res

    @staticmethod
    def _set_body(res, body):
        """Store *body* on *res*, via ``unicode_body`` for text bodies."""
        if isinstance(body, unicode):
            res.unicode_body = body
        else:
            res.body = body

    def _guess_content_type(self, request, res, protocol):
        """Choose a response content type when none was set explicitly.

        Prefers the best match between the client's Accept header and the
        protocol's content types.  Failing that, if the client accepts
        ``text/html`` the body of *res* is rewritten as an HTML page.
        Otherwise the protocol's first content type is used, so the
        header never reads ``None; charset=...``.
        """
        ctypes = [ct for ct in protocol.content_types if ct]
        if ctypes:
            best = request.accept.best_match(ctypes)
            if best is not None:
                return best

        # Not directly acceptable: try converting the body to an
        # accepted output format.
        if "text/html" in request.accept:
            res.body = self._html_format(res.body, protocol.content_types)
            return "text/html"

        return ctypes[0] if ctypes else 'text/plain'

    def _lookup_function(self, path):
        """Resolve *path* to an exposed function and its definition.

        A path starting with ``_protocol/<name>`` is looked up on that
        protocol instead of on the root.

        :raises exc.UnknownFunction: when no object carrying a
            ``_wsme_definition`` is found at *path* (including an empty
            path).
        """
        a = self
        # The length check turns a too-short path into UnknownFunction
        # instead of an IndexError (server error).
        isprotocol_specific = len(path) >= 2 and path[0] == '_protocol'

        if isprotocol_specific:
            a = self._get_protocol(path[1])
            path = path[2:]

        for name in path:
            a = getattr(a, name, None)
            if a is None:
                break

        if not hasattr(a, '_wsme_definition'):
            raise exc.UnknownFunction('/'.join(path))

        return a, a._wsme_definition

    def _format_exception(self, excinfo):
        """Extract informations that can be sent to the client.

        :param excinfo: a ``sys.exc_info()`` triple.
        :returns: dict with ``faultcode`` ("Client" or "Server"),
                  ``faultstring`` and ``debuginfo`` (the traceback for
                  server errors when ``self._debug`` is true, else None).
        """
        error = excinfo[1]
        if isinstance(error, exc.ClientSideError):
            r = dict(faultcode="Client", faultstring=error.faultstring)
            log.warning("Client-side error: %s", r['faultstring'])
            r['debuginfo'] = None
            return r

        faultstring = str(error)
        # format_exception lines already end with a newline.
        debuginfo = "".join(traceback.format_exception(*excinfo))

        log.error('Server-side error: "%s". Detail: \n%s',
                  faultstring, debuginfo)

        return dict(faultcode="Server",
                    faultstring=faultstring,
                    debuginfo=debuginfo if self._debug else None)

    def _html_format(self, content, content_types):
        """Render *content* as an HTML page.

        Uses pygments to highlight it with the lexer of the first content
        type pygments knows.  If pygments is missing or no lexer matches,
        falls back to the HTML-escaped content without styling.
        """
        try:
            from pygments import highlight
            from pygments.lexers import get_lexer_for_mimetype
            from pygments.formatters import HtmlFormatter
            from pygments.util import ClassNotFound

            lexer = None
            for ct in content_types:
                if not ct:
                    continue
                try:
                    lexer = get_lexer_for_mimetype(ct)
                    break
                except ClassNotFound:
                    continue
            if lexer is None:
                raise ValueError("No lexer found")
            formatter = HtmlFormatter()
            return html_body % dict(
                css=formatter.get_style_defs(),
                content=highlight(content, lexer, formatter).encode('utf8'))
        except Exception as e:
            log.warning(
                "Could not pygment the content because of the following "
                "error :\n%s", e)
            return html_body % dict(
                css='',
                content=_html_escape(content))