diff --git a/main.py b/main.py index 791919e43..446c2973c 100644 --- a/main.py +++ b/main.py @@ -21,9 +21,9 @@ import sys from twisted.internet import reactor from twisted.python import log -from teeth_agent.agent import TeethAgent +from teeth_agent.agent import StandbyeAgent log.startLogging(sys.stdout) -agent = TeethAgent() -agent.start('localhost', 8081) +agent = StandbyeAgent([['localhost', 8081]]) +agent.start() reactor.run() diff --git a/teeth_agent/__init__.py b/teeth_agent/__init__.py index ecff05203..13c610a20 100644 --- a/teeth_agent/__init__.py +++ b/teeth_agent/__init__.py @@ -1,3 +1,18 @@ +""" +Copyright 2013 Rackspace, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" __all__ = ["__version__"] diff --git a/teeth_agent/agent.py b/teeth_agent/agent.py index 6af563bb8..f63d40840 100644 --- a/teeth_agent/agent.py +++ b/teeth_agent/agent.py @@ -14,53 +14,20 @@ See the License for the specific language governing permissions and limitations under the License. """ -import simplejson as json -from twisted.internet.protocol import ReconnectingClientFactory -from twisted.internet import reactor +from teeth_agent.client import TeethClient from twisted.python import log -from teeth_agent import __version__ as AGENT_VERSION -from teeth_agent.protocol import TeethAgentProtocol -class AgentClientHandler(TeethAgentProtocol): - def __init__(self): - TeethAgentProtocol.__init__(self, json.JSONEncoder()) - self.handlers['v1'] = { - 'prepare_image': self.prepare_image, - } +class StandbyeAgent(TeethClient): + """ + Agent to perform standbye operations. + """ - def connectionMade(self): - def _response(result): - log.msg(format='Handshake successful, connection ID is %(connection_id)s', - connection_id=result['id']) - - self.send_command('handshake', 'a:b:c:d', AGENT_VERSION).addCallback(_response) + def __init__(self, addrs): + super(StandbyeAgent, self).__init__(addrs) + self._addHandler('v1', 'prepare_image', self.prepare_image) def prepare_image(self, image_id): + """Prepare an Image.""" log.msg(format='Preparing image %(image_id)s', image_id=image_id) return {'image_id': image_id, 'status': 'PREPARED'} - - -class AgentClientFactory(ReconnectingClientFactory): - protocol = AgentClientHandler - initialDelay = 1.0 - maxDelay = 120 - - def buildProtocol(self, addr): - self.resetDelay() - return self.protocol() - - def clientConnectionFailed(self, connector, reason): - log.err('Failed to connect, re-trying', delay=self.delay) - ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) - - def clientConnectionLost(self, connector, reason): - log.err('Lost connection, re-connecting', delay=self.delay) - ReconnectingClientFactory.clientConnectionLost(self, connector, reason) - - -class TeethAgent(object): - client_factory = AgentClientFactory() - - def start(self, host, port): - reactor.connectTCP(host, port, self.client_factory) diff --git a/teeth_agent/client.py b/teeth_agent/client.py new file mode 100644 index 000000000..30b38a720 --- /dev/null +++ b/teeth_agent/client.py @@ -0,0 +1,134 @@ +""" +Copyright 2013 Rackspace, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import time +import simplejson as json +from teeth_agent.protocol import TeethAgentProtocol +from twisted.internet.protocol import ReconnectingClientFactory +from twisted.python.failure import Failure +from twisted.python import log +from twisted.internet import reactor +from twisted.internet.defer import maybeDeferred + + +__all__ = ["TeethClientFactory", "TeethClient"] + + +class TeethClientFactory(ReconnectingClientFactory, object): + """ + Protocol Factory for the Teeth Client. + """ + protocol = TeethAgentProtocol + initialDelay = 1.0 + maxDelay = 120 + + def __init__(self, encoder, parent): + super(TeethClientFactory, self).__init__() + self._encoder = encoder + self._parent = parent + + def buildProtocol(self, addr): + """Create protocol for an address.""" + self.resetDelay() + proto = self.protocol(self._encoder, addr, self._parent) + self._parent.add_protocol_instance(proto) + return proto + + def clientConnectionFailed(self, connector, reason): + """clientConnectionFailed""" + log.err('Failed to connect, re-trying', delay=self.delay) + super(TeethClientFactory, self).clientConnectionFailed(connector, reason) + + def clientConnectionLost(self, connector, reason): + """clientConnectionLost""" + log.err('Lost connection, re-connecting', delay=self.delay) + super(TeethClientFactory, self).clientConnectionLost(connector, reason) + + +class TeethClient(object): + """ + High level Teeth Client. + """ + client_factory_cls = TeethClientFactory + client_encoder_cls = json.JSONEncoder + + def __init__(self, addrs): + super(TeethClient, self).__init__() + self._client_encoder = self.client_encoder_cls() + self._client_factory = self.client_factory_cls(self._client_encoder, self) + self._start_time = time.time() + self._clients = [] + self._outmsg = [] + self._connectaddrs = addrs + self._handlers = { + 'v1': { + 'status': self._handle_status, + } + } + + def remove_endpoint(self, host, port): + """Remove an Agent Endpoint from the active list.""" + + def op(client): + if client.address.host == host and client.address.port == port: + client.loseConnectionSoon() + return True + return False + self._clients[:] = [client for client in self._clients if not op(client)] + + def add_endpoint(self, host, port): + """Add an agent endpoint to the """ + self._connectaddrs.append([host, port]) + self.start() + + def add_protocol_instance(self, client): + """Add a running protocol to the parent.""" + client.on('command', self._on_command) + self._clients.append(client) + + def start(self): + """Start the agent.""" + for host, port in self._connectaddrs: + reactor.connectTCP(host, port, self._client_factory) + self._connectaddrs = [] + + def _on_command(self, topic, message): + if message.version not in self._handlers: + message.protocol.fatal_error('unknown message version') + return + + if message.method not in self._handlers[message.version]: + message.protocol.fatal_error('unknown message method') + return + + handler = self._handlers[message.version][message.method] + d = maybeDeferred(handler, message=message) + d.addBoth(self._send_response, message) + + def send_response(self, result, message): + """Send a response to a message.""" + if isinstance(result, Failure): + # TODO: log, cleanup + message.protocol.send_error_response('error running command', message) + else: + message.protocol.send_response(result, message) + + def _handle_status(self, message): + running = time.time() - self._start_time + return {'running': running} + + def _addHandler(self, version, command, func): + self._handlers[version][command] = func diff --git a/teeth_agent/events.py b/teeth_agent/events.py new file mode 100644 index 000000000..bef940cd4 --- /dev/null +++ b/teeth_agent/events.py @@ -0,0 +1,81 @@ +""" +Copyright 2013 Rackspace, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from collections import defaultdict +from twisted.internet.defer import maybeDeferred + +__all__ = ['EventEmitter', 'EventEmitterUnhandledError'] + + +class EventEmitterUnhandledError(RuntimeError): + """ + Error caused by no subscribers to an `error` event. + """ + pass + + +class EventEmitter(object): + """ + + Extremely simple pubsub style things in-process. + + Styled after the Node.js EventEmitter class + + """ + __slots__ = ['_subs'] + + def __init__(self): + self._subs = defaultdict(list) + + def emit(self, topic, *args): + """ + Emit an event to a specific topic with a payload. + """ + ds = [] + if topic == "error": + if len(self._subs[topic]) == 0: + raise EventEmitterUnhandledError("No Subscribers to an error event found") + for s in self._subs[topic]: + ds.append(maybeDeferred(s, topic, *args)) + return ds + + def on(self, topic, callback): + """ + Add a handler for a specific topic. + """ + self.emit("newListener", topic, callback) + self._subs[topic].append(callback) + + def once(self, topic, callback): + """ + Execute a specific handler just once. + """ + def oncecb(*args): + self.removeListener(topic, oncecb) + callback(*args) + self.on(topic, oncecb) + + def removeListener(self, topic, callback): + """ + Remove a handler from a topic. + """ + self._subs[topic] = filter(lambda x: x != callback, self._subs[topic]) + + def removeAllListeners(self, topic): + """ + Remove all listeners from a specific topic. + """ + del self._subs[topic] diff --git a/teeth_agent/protocol.py b/teeth_agent/protocol.py index 41f3a76ec..78c658d55 100644 --- a/teeth_agent/protocol.py +++ b/teeth_agent/protocol.py @@ -17,63 +17,206 @@ limitations under the License. import simplejson as json import uuid -from twisted.protocols.basic import LineReceiver -from twisted.internet.defer import maybeDeferred from twisted.internet import defer - +from twisted.internet import reactor +from twisted.protocols.basic import LineReceiver +from twisted.python.failure import Failure +from twisted.python import log +from teeth_agent import __version__ as AGENT_VERSION +from teeth_agent.events import EventEmitter DEFAULT_PROTOCOL_VERSION = 'v1' +__all__ = ['RPCMessage', 'RPCCommand', 'RPCProtocol', 'TeethAgentProtocol'] -class TeethAgentProtocol(LineReceiver): - def __init__(self, encoder): + +class RPCMessage(object): + """ + Wraps all RPC messages. + """ + def __init__(self, protocol, message): + super(RPCMessage, self).__init__() + self.protocol = protocol + self.id = message['id'] + self.version = message['version'] + + +class RPCCommand(RPCMessage): + """ + Wraps incoming RPC Commands. + """ + def __init__(self, protocol, message): + super(RPCCommand, self).__init__(protocol, message) + self.method = message['method'] + self.params = message['params'] + + +class RPCResponse(RPCMessage): + """ + Wraps incoming RPC Responses. + """ + def __init__(self, protocol, message): + super(RPCResponse, self).__init__(protocol, message) + self.result = message.get('result', None) + + +class RPCError(RPCMessage, RuntimeError): + """ + Wraps incoming RPC Errors Responses. + """ + def __init__(self, protocol, message): + super(RPCError, self).__init__(protocol, message) + self.error = message.get('error', 'unknown error') + self._raw_message = message + + +class RPCProtocol(LineReceiver, EventEmitter): + """ + Twisted Protocol handler for the RPC Protocol of the Teeth + Agent <-> Endpoint communication. + + The protocol is a simple JSON newline based system. Client or server + can request methods with a message id. The recieving party + responds to this message id. + + The low level details are in C{RPCProtocol} while the higher level + functions are in C{TeethAgentProtocol} + """ + + def __init__(self, encoder, address): + super(RPCProtocol, self).__init__() self.encoder = encoder - self.handlers = {} - self.pending_command_deferreds = {} + self.address = address + self._pending_command_deferreds = {} + self._fatal_error = False + + def loseConnectionSoon(self, timeout=10): + """Attempt to disconnect from the transport as 'nicely' as possible. """ + self.loseConnection() + reactor.callLater(timeout, self.abortConnection) + + def connectionMade(self): + """TCP hard. We made it. Maybe.""" + super(RPCProtocol, self).connectionMade() + self.transport.setTcpKeepAlive(True) + self.transport.setTcpNoDelay(True) + self.emit('connect') def lineReceived(self, line): + """Process a line of data.""" line = line.strip() + if not line: return - message = json.loads(line) - if 'method' in message: - self.handle_command(message) - elif 'result' in message: - self.handle_response(message) + try: + message = json.loads(line) + except Exception: + return self.fatal_error('protocol error: unable to decode message.') - def send_command(self, method, *args, **kwargs): + if 'fatal_error' in message: + # TODO: Log what happened? + self.loseConnectionSoon() + return + + if not message.get('id', None): + return self.fatal_error("protocol violation: missing message id.") + + if not message.get('version', None): + return self.fatal_error("protocol violation: missing message version.") + + elif 'method' in message: + if not message.get('params', None): + return self.fatal_error("protocol violation: missing message params.") + + msg = RPCCommand(self, message) + self._handle_command(msg) + + elif 'error' in message: + msg = RPCError(self, message) + self._handle_response(message) + + elif 'result' in message: + + msg = RPCResponse(self, message) + self._handle_response(message) + else: + return self.fatal_error('protocol error: malformed message.') + + def fatal_error(self, message): + """Send a fatal error message, and disconnect.""" + if not self._fatal_error: + self._fatal_error = True + self.sendLine(self.encoder.encode({ + 'fatal_error': message + })) + self.loseConnectionSoon() + + def send_command(self, method, params, timeout=60): + """Send a new command.""" message_id = str(uuid.uuid4()) d = defer.Deferred() - self.pending_command_deferreds[message_id] = d + # d.setTimeout(timeout) + # TODO: cleanup _pending_command_deferreds on timeout. + self._pending_command_deferreds[message_id] = d self.sendLine(self.encoder.encode({ 'id': message_id, 'version': DEFAULT_PROTOCOL_VERSION, 'method': method, - 'args': args, - 'kwargs': kwargs, + 'params': params, })) return d - def handle_command(self, message): - message_id = message['id'] - version = message['version'] - args = message.get('args', []) - kwargs = message.get('kwargs', {}) - d = maybeDeferred(self.handlers[version][message['method']], *args, **kwargs) - d.addCallback(self.send_response, version, message_id) - - def send_response(self, result, version, message_id): + def send_response(self, result, responding_to): + """Send a result response.""" self.sendLine(self.encoder.encode({ - 'id': message_id, - 'version': version, + 'id': responding_to.id, + 'version': responding_to.version, 'result': result, })) - def handle_response(self, message): + def send_error_response(self, error, responding_to): + """Send an error response.""" + self.sendLine(self.encoder.encode({ + 'id': responding_to.id, + 'version': responding_to.version, + 'error': error, + })) + + def _handle_response(self, message): d = self.pending_command_deferreds.pop(message['id']) - error = message.get('error', None) - if error: - d.errback(error) + + if isinstance(message, RPCError): + f = Failure(message) + d.errback(f) else: - d.callback(message.get('result', None)) + d.callback(message) + + def _handle_command(self, message): + d = self.emit('command', message) + + if len(d) == 0: + return self.fatal_error("protocol violation: unsupported command") + + # TODO: do we need to wait on anything here? + pass + + +class TeethAgentProtocol(RPCProtocol): + """ + Handles higher level logic of the RPC protocol like authentication and handshakes. + """ + + def __init__(self, encoder, address, parent): + super(TeethAgentProtocol, self).__init__(encoder, address) + self.encoder = encoder + self.address = address + self.parent = parent + self.on('connect', self._on_connect) + + def _on_connect(self, event): + def _response(result): + log.msg(format='Handshake successful, connection ID is %(connection_id)s', + connection_id=result['id']) + self.send_command('handshake', + {'id': 'a:b:c:d', 'version': AGENT_VERSION}).addCallback(_response)