# Copyright (C) 2013 Nippon Telegraph and Telephone Corporation.
# Copyright (C) 2013 YAMAMOTO Takashi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Specification:
# - msgpack
#   https://github.com/msgpack/msgpack/blob/master/spec.md
# - msgpack-rpc
#   https://github.com/msgpack-rpc/msgpack-rpc/blob/master/spec.md

from collections import deque
import select

import msgpack


class MessageType(object):
    """Values of the leading "type" field of a msgpack-rpc message."""

    REQUEST = 0
    RESPONSE = 1
    NOTIFY = 2


class MessageEncoder(object):
    """msgpack-rpc encoder/decoder.

    Intended to be transport-agnostic: it only converts between messages
    and packed bytes and never touches a socket itself.
    """

    def __init__(self):
        super(MessageEncoder, self).__init__()
        # NOTE(ralonsoh): msgpack>=1.0.0
        self._packer = msgpack.Packer()
        # The strict_map_key=False option is required to use int keys in
        # maps; it is disabled by default to prevent hash collision denial
        # of service attacks (hashdos) in scenarios where an attacker can
        # control the keys to be hashed.
        self._unpacker = msgpack.Unpacker(strict_map_key=False)
        self._next_msgid = 0

    def _create_msgid(self):
        """Return the next request id and advance the internal counter.

        Ids cycle through the whole range accepted by create_response(),
        i.e. 0 .. 0xffffffff inclusive, and then wrap back to 0.
        """
        this_id = self._next_msgid
        # Mask (rather than "% 0xffffffff") so that 0xffffffff itself is
        # also handed out before wrapping.
        self._next_msgid = (self._next_msgid + 1) & 0xffffffff
        return this_id

    def create_request(self, method, params):
        """Pack a request message.

        *method* must be a str or bytes and *params* a list.

        Returns a ``(packed_message, msgid)`` tuple, where *msgid* is the
        freshly allocated id to match the response against.
        """
        assert isinstance(method, (str, bytes))
        assert isinstance(params, list)
        msgid = self._create_msgid()
        packed = self._packer.pack(
            [MessageType.REQUEST, msgid, method, params])
        return (packed, msgid)

    def create_response(self, msgid, error=None, result=None):
        """Pack a response message for the request identified by *msgid*.

        *msgid* must be an int in 0 .. 0xffffffff.  At most one of *error*
        and *result* may be something other than None.

        Returns the packed message.
        """
        assert isinstance(msgid, int)
        assert 0 <= msgid <= 0xffffffff
        assert error is None or result is None
        return self._packer.pack([MessageType.RESPONSE, msgid, error, result])

    def create_notification(self, method, params):
        """Pack a notification message (a request that carries no msgid).

        *method* must be a str or bytes and *params* a list.

        Returns the packed message.
        """
        assert isinstance(method, (str, bytes))
        assert isinstance(params, list)
        return self._packer.pack([MessageType.NOTIFY, method, params])

    def get_and_dispatch_messages(self, data, disp_table):
        """dissect messages from a raw stream data.

        *data* is fed to the streaming unpacker, so it may contain partial
        messages; every message completed so far is dispatched.

        disp_table[type] should be a callable for the corresponding
        MessageType.  It is called with the message minus its type field.
        """
        self._unpacker.feed(data)
        for m in self._unpacker:
            self._dispatch_message(m, disp_table)

    @staticmethod
    def _dispatch_message(m, disp_table):
        """Hand one decoded message *m* to its handler in *disp_table*.

        The message comes straight from the peer, so its shape is not
        trusted: anything that is not a non-empty list/tuple, or whose type
        field cannot be looked up in *disp_table*, is silently ignored.
        """
        if not isinstance(m, (list, tuple)) or not m:
            # not a msgpack-rpc message at all; ignore it
            return
        try:
            f = disp_table[m[0]]
        except (KeyError, TypeError):
            # ignore messages with unknown type (TypeError: the type field
            # is not even hashable)
            return
        f(m[1:])


class EndPoint(object):
    """An endpoint

    *sock* is a socket-like.  it can be either blocking or non-blocking.

    Outgoing messages are appended to a send buffer and written with
    sock.send(); incoming data is read with sock.recv(), decoded by the
    encoder and sorted into per-kind queues that the get_xxx() methods
    drain.
    """

    def __init__(self, sock, encoder=None, disp_table=None):
        """Set up the endpoint.

        *encoder* defaults to a fresh MessageEncoder.  *disp_table*
        defaults to a table that queues requests, responses and
        notifications on this endpoint; a caller-supplied table replaces
        that queuing entirely.
        """
        if encoder is None:
            encoder = MessageEncoder()
        self._encoder = encoder
        self._sock = sock
        if disp_table is None:
            self._table = {
                MessageType.REQUEST: self._enqueue_incoming_request,
                MessageType.RESPONSE: self._enqueue_incoming_response,
                MessageType.NOTIFY: self._enqueue_incoming_notification
            }
        else:
            self._table = disp_table
        self._send_buffer = bytearray()
        # msgids for which we sent a request but have not received a response
        self._pending_requests = set()
        # queues for incoming messages
        self._requests = deque()
        self._notifications = deque()
        self._responses = {}
        self._incoming = 0  # number of incoming messages in our queues
        self._closed_by_peer = False

    def selectable(self):
        """Return ``(rlist, wlist)`` suitable for select.select().

        The socket is always in *rlist*; it is in *wlist* only while there
        is unsent data in the send buffer.
        """
        rlist = [self._sock]
        wlist = []
        if self._send_buffer:
            wlist.append(self._sock)
        return rlist, wlist

    def process_outgoing(self):
        """Try once to write the send buffer to the socket.

        Whatever sock.send() reports as sent is dropped from the buffer.
        An IOError from send() is treated as "nothing sent"; the data stays
        buffered for a later attempt.
        """
        try:
            sent_bytes = self._sock.send(self._send_buffer)
        except IOError:
            sent_bytes = 0
        del self._send_buffer[:sent_bytes]

    def process_incoming(self):
        """Read and queue incoming messages until recv() yields nothing.

        This calls receive_messages(all=True), whose loop only ends when
        recv() raises IOError or returns no data.
        """
        self.receive_messages(all=True)

    def process(self):
        """Do one round of outgoing and then incoming processing."""
        self.process_outgoing()
        self.process_incoming()

    def block(self):
        """Wait in select() until the socket needs attention."""
        rlist, wlist = self.selectable()
        select.select(rlist, wlist, rlist + wlist)

    def serve(self):
        """Alternate block() and process() until the peer closes."""
        while not self._closed_by_peer:
            self.block()
            self.process()

    def _send_message(self, msg):
        """Buffer the packed message *msg* and attempt to send right away."""
        self._send_buffer += msg
        self.process_outgoing()

    def send_request(self, method, params):
        """Send a request

        Returns the msgid of the request, to be passed to get_response().
        """
        msg, msgid = self._encoder.create_request(method, params)
        self._send_message(msg)
        # remember the id so that the matching response is accepted
        self._pending_requests.add(msgid)
        return msgid

    def send_response(self, msgid, error=None, result=None):
        """Send a response
        """
        msg = self._encoder.create_response(msgid, error, result)
        self._send_message(msg)

    def send_notification(self, method, params):
        """Send a notification
        """
        msg = self._encoder.create_notification(method, params)
        self._send_message(msg)

    def receive_messages(self, all=False):
        """Try to receive some messages.

        Received messages are put on the internal queues.
        They can be retrieved using get_xxx() methods.

        With *all* false, reading stops as soon as at least one message is
        queued; with *all* true, it continues until recv() raises IOError
        or returns no data.  recv() returning no data marks the endpoint
        as closed by the peer.

        Returns True if there's something queued for get_xxx() methods.
        """
        while all or self._incoming == 0:
            try:
                packet = self._sock.recv(4096)  # XXX the size is arbitrary
            except IOError:
                packet = None
            if not packet:
                if packet is not None:
                    # socket closed by peer
                    self._closed_by_peer = True
                break
            self._encoder.get_and_dispatch_messages(packet, self._table)
        return self._incoming > 0

    def _enqueue_incoming_request(self, m):
        """Queue the body *m* of an incoming request for get_request()."""
        self._requests.append(m)
        self._incoming += 1

    def _enqueue_incoming_response(self, m):
        """Record the body *m* of an incoming response for get_response().

        *m* is expected to be ``[msgid, error, result]``.  Responses that do
        not have that shape, or whose msgid does not belong to a request we
        are still waiting on, are dropped.
        """
        try:
            msgid, error, result = m
        except ValueError:
            # malformed response from the peer
            # XXXwarn
            return
        try:
            self._pending_requests.remove(msgid)
        except (KeyError, TypeError):
            # bogus msgid (TypeError: it is not even hashable)
            # XXXwarn
            return
        assert msgid not in self._responses
        self._responses[msgid] = (error, result)
        self._incoming += 1

    def _enqueue_incoming_notification(self, m):
        """Queue the body *m* of an incoming notification."""
        self._notifications.append(m)
        self._incoming += 1

    def _get_message(self, q):
        """Pop the oldest message from the deque *q*.

        Returns the message, or None if *q* is empty.
        """
        try:
            m = q.popleft()
            assert self._incoming > 0
            self._incoming -= 1
            return m
        except IndexError:
            return None

    def get_request(self):
        """Return the oldest queued request body, or None if there is none."""
        return self._get_message(self._requests)

    def get_response(self, msgid):
        """Return the response to the request *msgid*, if it has arrived.

        Returns a ``(result, error)`` tuple, or None if no response for
        *msgid* is queued.
        """
        try:
            m = self._responses.pop(msgid)
            assert self._incoming > 0
            self._incoming -= 1
        except KeyError:
            return None
        # stored as (error, result); handed out as (result, error)
        error, result = m
        return result, error

    def get_notification(self):
        """Return the oldest queued notification body, or None."""
        return self._get_message(self._notifications)


class RPCError(Exception):
    """an error from server
    """

    def __init__(self, error):
        super(RPCError, self).__init__()
        self._error = error

    def get_value(self):
        """Return the error value that the peer sent."""
        return self._error

    def __str__(self):
        return str(self._error)


class Client(object):
    """a convenient class for a pure rpc client

    *sock* is a socket-like.  it should be blocking.
    """

    def __init__(self, sock, encoder=None, notification_callback=None):
        """Wrap *sock* in an EndPoint.

        *notification_callback*, if given, is called with each incoming
        notification body; by default notifications are discarded.
        """
        self._endpoint = EndPoint(sock, encoder)
        if notification_callback is None:
            # ignore notifications by default
            self._notification_callback = lambda n: None
        else:
            self._notification_callback = notification_callback

    def _process_input_notification(self):
        """Deliver one queued notification, if any, to the callback."""
        n = self._endpoint.get_notification()
        if n:
            self._notification_callback(n)

    def _process_input_request(self):
        """Discard one queued request, if any."""
        # ignore requests as we are a pure client
        # XXXwarn
        self._endpoint.get_request()

    def call(self, method, params):
        """synchronous call.

        send a request and wait for a response.
        return a result.  or raise RPCError exception if the peer
        sends us an error.

        Notifications and requests that arrive while waiting are handled
        the same way as in receive_notification().  Raises EOFError if
        receiving ends with nothing queued.
        """
        msgid = self._endpoint.send_request(method, params)
        while True:
            if not self._endpoint.receive_messages():
                raise EOFError("EOF")
            res = self._endpoint.get_response(msgid)
            if res is not None:
                result, error = res
                if error is None:
                    return result
                raise RPCError(error)
            # not our response yet; consume whatever else came in
            self._process_input_notification()
            self._process_input_request()

    def send_notification(self, method, params):
        """send a notification to the peer.
        """
        self._endpoint.send_notification(method, params)

    def receive_notification(self):
        """wait for the next incoming message.

        intended to be used when we have nothing to send but want to receive
        notifications.

        Raises EOFError if receiving ends with nothing queued.
        """
        if not self._endpoint.receive_messages():
            raise EOFError("EOF")
        self._process_input_notification()
        self._process_input_request()

    def peek_notification(self):
        """Process incoming messages for as long as the socket is readable.

        Polls the socket with a zero select() timeout, so this returns as
        soon as select() reports nothing to read.
        """
        while True:
            rlist, _wlist = self._endpoint.selectable()
            rlist, _wlist, _xlist = select.select(rlist, [], [], 0)
            if not rlist:
                break
            self.receive_notification()