diff --git a/.gitignore b/.gitignore index aefdc6b4c..deca4277b 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,5 @@ subunit-output.txt test-report.xml twisted/plugins/dropin.cache twistd.log +.coverage +_trial_coverage/ diff --git a/Makefile b/Makefile index eefb23055..a24ec780a 100644 --- a/Makefile +++ b/Makefile @@ -15,6 +15,9 @@ else trial --random 0 ${UNITTESTS} endif +coverage: + coverage run --source=${CODEDIR} --branch `which trial` ${UNITTESTS} && coverage html -d _trial_coverage --omit="*/tests/*" + env: ./scripts/bootstrap-virtualenv.sh diff --git a/README.md b/README.md index d79095951..f297455b8 100644 --- a/README.md +++ b/README.md @@ -37,21 +37,21 @@ Fatal Error: * `log`: (Agent->Server) Log a structured message from the Agent. * `status`: (Server->Agent) Uptime, version, and other fields reported. -* `task_status`: (Agent->Server) Update status of a task. Task has a `.state`, which is `running`, `error` or `complete`. `running` will additionally contain `.eta` and `.percent`, a measure of how much work estimated to remain in seconds and how much work is done. Once `error` or `complete` is sent, no more updates will be sent. `error` state includes an additional human readable `.msg` field. +* `task_status`: (Agent->Server) Update status of a task. Task has an `.task_id` which was previously designated. Task has a `.state`, which is `running`, `error` or `complete`. `running` will additionally contain `.eta` and `.percent`, a measure of how much work estimated to remain in seconds and how much work is done. Once `error` or `complete` is sent, no more updates will be sent. `error` state includes an additional human readable `.msg` field. #### Decommission -* `decom.disk_erase`: (Server->Agent) Erase all attached block devices securely. Returns a Task ID. -* `decom.firmware_secure`: (Server->Agent) Update Firmwares/BIOS versions and settings. Returns a Task ID. -* `decom.qc`: (Server->Agent) Run quality control checks on chassis model. Includes sending specifications of chassis (cpu types, disks, etc). Returns a Task ID. +* `decom.disk_erase`: (Server->Agent) Erase all attached block devices securely. Takes a `task_id`. +* `decom.firmware_secure`: (Server->Agent) Update Firmwares/BIOS versions and settings. Takes a `task_id`. +* `decom.qc`: (Server->Agent) Run quality control checks on chassis model. Includes sending specifications of chassis (cpu types, disks, etc). Takes a `task_id`. #### Standbye -* `standbye.cache_images`: (Server->Agent) Cache an set of image UUID on local storage. Ordered in priority, chassis may only cache a subset depending on local storage. Returns a Task ID. -* `standbye.prepare_image`: (Server->Agent) Prepare a image UUID to be ran. Returns a Task ID. -* `standbye.run_image`: (Server->Agent) Run an image UUID. Must include Config Drive Settings. Agent will write config drive, and setup grub. If the Agent can detect a viable kexec target it will kexec into it, otherwise reboot. Returns a Task ID. +* `standbye.cache_images`: (Server->Agent) Cache an set of image UUID on local storage. Ordered in priority, chassis may only cache a subset depending on local storage. Takes a `task_id`. +* `standbye.prepare_image`: (Server->Agent) Prepare a image UUID to be ran. Takes a `task_id`. +* `standbye.run_image`: (Server->Agent) Run an image UUID. Must include Config Drive Settings. Agent will write config drive, and setup grub. If the Agent can detect a viable kexec target it will kexec into it, otherwise reboot. Takes a `task_id`. diff --git a/dev-requirements.txt b/dev-requirements.txt index 1d356a799..5aa4d7e7c 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -3,4 +3,6 @@ plumbum==1.3.0 pep8==1.4.6 pyflakes==0.7.3 junitxml==0.7 -python-subunit==0.0.15 \ No newline at end of file +python-subunit==0.0.15 +mock==1.0.1 +coverage==3.6 \ No newline at end of file diff --git a/teeth_agent/agent.py b/teeth_agent/agent.py index 16cddcf78..e7e3eca18 100644 --- a/teeth_agent/agent.py +++ b/teeth_agent/agent.py @@ -27,7 +27,7 @@ class StandbyAgent(TeethClient): def __init__(self, addrs): super(StandbyAgent, self).__init__(addrs) self._addHandler('v1', 'prepare_image', self.prepare_image) - log.info('Starting agent', addrs=addrs) + log.msg('Starting agent', addrs=addrs) def prepare_image(self, image_id): """Prepare an Image.""" diff --git a/teeth_agent/client.py b/teeth_agent/client.py index 3f6b8e395..8666db005 100644 --- a/teeth_agent/client.py +++ b/teeth_agent/client.py @@ -98,7 +98,7 @@ class TeethClient(MultiService, object): self._running = False dl = [] for client in self._clients: - dl.append(client.loseConnectionSoon(timeout=0.05)) + dl.append(client.abortConnection()) return DeferredList(dl) def remove_endpoint(self, host, port): diff --git a/teeth_agent/logging.py b/teeth_agent/logging.py index 610bac881..a5876a8ec 100644 --- a/teeth_agent/logging.py +++ b/teeth_agent/logging.py @@ -32,6 +32,8 @@ def configure(): structlog.configure( context_class=dict, + logger_factory=structlog.twisted.LoggerFactory(), + wrapper_class=structlog.twisted.BoundLogger, cache_logger_on_first_use=True) @@ -39,4 +41,6 @@ def get_logger(): """ Get a logger instance. """ + configure() + return structlog.get_logger() diff --git a/teeth_agent/protocol.py b/teeth_agent/protocol.py index 9b96b8fee..e8a77af6b 100644 --- a/teeth_agent/protocol.py +++ b/teeth_agent/protocol.py @@ -18,10 +18,8 @@ import simplejson as json import uuid from twisted.internet import defer -from twisted.internet import task -from twisted.internet import reactor +from twisted.protocols import policies from twisted.protocols.basic import LineReceiver -from twisted.python.failure import Failure from teeth_agent import __version__ as AGENT_VERSION from teeth_agent.events import EventEmitter from teeth_agent.logging import get_logger @@ -73,7 +71,9 @@ class RPCError(RPCMessage, RuntimeError): self._raw_message = message -class RPCProtocol(LineReceiver, EventEmitter): +class RPCProtocol(LineReceiver, + EventEmitter, + policies.TimeoutMixin): """ Twisted Protocol handler for the RPC Protocol of the Teeth Agent <-> Endpoint communication. @@ -96,29 +96,34 @@ class RPCProtocol(LineReceiver, EventEmitter): self._pending_command_deferreds = {} self._fatal_error = False self._log = log.bind(host=address.host, port=address.port) + self._timeOut = 60 - def loseConnectionSoon(self, timeout=10): - """Attempt to disconnect from the transport as 'nicely' as possible. """ - self._log.info('Trying to disconnect.') - self.transport.loseConnection() - return task.deferLater(reactor, timeout, self.transport.abortConnection) + def timeoutConnection(self): + """Action called when the connection has hit a timeout.""" + self.transport.abortConnection() def connectionMade(self): """TCP hard. We made it. Maybe.""" super(RPCProtocol, self).connectionMade() - self._log.info('Connection established.') + self._log.msg('Connection established.') self.transport.setTcpKeepAlive(True) self.transport.setTcpNoDelay(True) self.emit('connect') + def sendLine(self, line): + """Send a line of content to our peer.""" + self.resetTimeout() + super(RPCProtocol, self).sendLine(line) + def lineReceived(self, line): """Process a line of data.""" + self.resetTimeout() line = line.strip() if not line: return - self._log.debug('Got Line', line=line) + self._log.msg('Got Line', line=line) try: message = json.loads(line) @@ -127,16 +132,16 @@ class RPCProtocol(LineReceiver, EventEmitter): if 'fatal_error' in message: # TODO: Log what happened? - self.loseConnectionSoon() + self.transport.abortConnection() return - if not message.get('id', None): - return self.fatal_error("protocol violation: missing message id.") - if not message.get('version', None): return self.fatal_error("protocol violation: missing message version.") - elif 'method' in message: + if not message.get('id', None): + return self.fatal_error("protocol violation: missing message id.") + + if 'method' in message: if not message.get('params', None): return self.fatal_error("protocol violation: missing message params.") @@ -145,24 +150,24 @@ class RPCProtocol(LineReceiver, EventEmitter): elif 'error' in message: msg = RPCError(self, message) - self._handle_response(message) + self._handle_response(msg) elif 'result' in message: msg = RPCResponse(self, message) - self._handle_response(message) + self._handle_response(msg) else: return self.fatal_error('protocol error: malformed message.') def fatal_error(self, message): """Send a fatal error message, and disconnect.""" - self._log.error('sending a fatal error', message=message) + self._log.msg('sending a fatal error', message=message) if not self._fatal_error: self._fatal_error = True self.sendLine(self.encoder.encode({ 'fatal_error': message })) - self.loseConnectionSoon() + self.transport.abortConnection() def send_command(self, method, params, timeout=60): """Send a new command.""" @@ -196,11 +201,13 @@ class RPCProtocol(LineReceiver, EventEmitter): })) def _handle_response(self, message): - d = self.pending_command_deferreds.pop(message['id']) + d = self._pending_command_deferreds.pop(message.id, None) + + if not d: + return self.fatal_error("protocol violation: unknown message id referenced.") if isinstance(message, RPCError): - f = Failure(message) - d.errback(f) + d.errback(message) else: d.callback(message) @@ -208,7 +215,7 @@ class RPCProtocol(LineReceiver, EventEmitter): d = self.emit('command', message) if len(d) == 0: - return self.fatal_error("protocol violation: unsupported command") + return self.fatal_error("protocol violation: unsupported command.") # TODO: do we need to wait on anything here? pass @@ -224,9 +231,10 @@ class TeethAgentProtocol(RPCProtocol): self.encoder = encoder self.address = address self.parent = parent - self.on('connect', self._on_connect) + self.once('connect', self._once_connect) + + def _once_connect(self, event): - def _on_connect(self, event): def _response(result): self._log.msg('Handshake successful', connection_id=result['id']) diff --git a/teeth_agent/task.py b/teeth_agent/task.py new file mode 100644 index 000000000..1fd59d317 --- /dev/null +++ b/teeth_agent/task.py @@ -0,0 +1,100 @@ +""" +Copyright 2013 Rackspace, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from twisted.application.service import MultiService +from twisted.application.internet import TimerService +from teeth_agent.logging import get_logger +log = get_logger() + + +__all__ = ['Task', 'PrepareImageTask'] + + +class Task(MultiService, object): + """ + Task to execute, reporting status periodically to TeethClient instance. + """ + + task_name = 'task_undefined' + + def __init__(self, client, task_id, task_name, reporting_interval=10): + super(Task, self).__init__() + self.setName(self.task_name) + self._client = client + self._id = task_id + self._percent = 0 + self._reporting_interval = reporting_interval + self._state = 'starting' + self._timer = TimerService(self._reporting_interval, self._tick) + self._timer.setServiceParent(self) + self._error_msg = None + + def run(self): + """Run the Task.""" + # setServiceParent actually starts the task if it is already running + # so we run it in start. + self.setServiceParent(self._client) + + def _tick(self): + if not self.running: + # log.debug("_tick called while not running :()") + return + return self._client.update_task_status(self) + + def error(self, message): + """Error out running of the task.""" + self._error_msg = message + self._state = 'error' + self.stopService() + + def complete(self): + """Complete running of the task.""" + self._state = 'complete' + self.stopService() + + def startService(self): + """Start the Service.""" + super(Task, self).startService() + self._state = 'running' + + def stopService(self): + """Stop the Service.""" + super(Task, self).stopService() + + if not self._client.running: + return + + if self._state not in ['error', 'complete']: + log.err("told to shutdown before task could complete, marking as error.") + self._error_msg = 'service being shutdown' + self._state = 'error' + + self._client.finish_task(self) + + +class PrepareImageTask(Task): + + """Prepare an image to be ran on the machine.""" + + task_name = 'prepare_image' + + def __init__(self, client, task_id, image_info, reporting_interval=10): + super(PrepareImageTask, self).__init__(client, task_id) + self._image_info = image_info + + def run(): + """Run the Prepare Image task.""" + pass diff --git a/teeth_agent/tests/test_protocol.py b/teeth_agent/tests/test_protocol.py new file mode 100644 index 000000000..dda6afaf5 --- /dev/null +++ b/teeth_agent/tests/test_protocol.py @@ -0,0 +1,190 @@ +""" +Copyright 2013 Rackspace, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + + +from twisted.internet import defer +from twisted.internet import main +from twisted.internet.address import IPv4Address +from twisted.python import failure +from twisted.test.proto_helpers import StringTransportWithDisconnection +from twisted.trial import unittest +import simplejson as json +from mock import Mock + +from teeth_agent.protocol import RPCError, RPCProtocol, TeethAgentProtocol +from teeth_agent import __version__ as AGENT_VERSION + + +class FakeTCPTransport(StringTransportWithDisconnection, object): + _aborting = False + disconnected = False + + setTcpKeepAlive = Mock(return_value=None) + setTcpNoDelay = Mock(return_value=None) + setTcpNoDelay = Mock(return_value=None) + + def connectionLost(self, reason): + self.protocol.connectionLost(reason) + + def abortConnection(self): + if self.disconnected or self._aborting: + return + self._aborting = True + self.connectionLost(failure.Failure(main.CONNECTION_DONE)) + + +class RPCProtocolTest(unittest.TestCase): + """RPC Protocol tests.""" + + def setUp(self): + self.tr = FakeTCPTransport() + self.proto = RPCProtocol(json.JSONEncoder(), IPv4Address('TCP', '127.0.0.1', 0)) + self.proto.makeConnection(self.tr) + self.tr.protocol = self.proto + + def test_timeout(self): + d = defer.Deferred() + called = [] + orig = self.proto.connectionLost + + def lost(arg): + orig() + called.append(True) + d.callback(True) + + self.proto.connectionLost = lost + self.proto.timeoutConnection() + + def check(ignore): + self.assertEqual(called, [True]) + + d.addCallback(check) + return d + + def test_recv_command_no_params(self): + self.tr.clear() + self.proto.lineReceived(json.dumps({'id': '1', 'version': 'v1', 'method': 'BOGUS_STUFF'})) + obj = json.loads(self.tr.io.getvalue().strip()) + self.assertEqual(obj['fatal_error'], 'protocol violation: missing message params.') + + def test_recv_bogus_command(self): + self.tr.clear() + self.proto.lineReceived( + json.dumps({'id': '1', 'version': 'v1', 'method': 'BOGUS_STUFF', 'params': {'d': '1'}})) + obj = json.loads(self.tr.io.getvalue().strip()) + self.assertEqual(obj['fatal_error'], 'protocol violation: unsupported command.') + + def test_recv_valid_json_no_id(self): + self.tr.clear() + self.proto.lineReceived(json.dumps({'version': 'v913'})) + obj = json.loads(self.tr.io.getvalue().strip()) + self.assertEqual(obj['fatal_error'], 'protocol violation: missing message id.') + + def test_recv_valid_json_no_version(self): + self.tr.clear() + self.proto.lineReceived(json.dumps({'version': None, 'id': 'foo'})) + obj = json.loads(self.tr.io.getvalue().strip()) + self.assertEqual(obj['fatal_error'], 'protocol violation: missing message version.') + + def test_recv_invalid_data(self): + self.tr.clear() + self.proto.lineReceived('') + self.proto.lineReceived('invalid json!') + obj = json.loads(self.tr.io.getvalue().strip()) + self.assertEqual(obj['fatal_error'], 'protocol error: unable to decode message.') + + def test_recv_missing_key_parts(self): + self.tr.clear() + self.proto.lineReceived(json.dumps( + {'id': '1', 'version': 'v1'})) + obj = json.loads(self.tr.io.getvalue().strip()) + self.assertEqual(obj['fatal_error'], 'protocol error: malformed message.') + + def test_recv_error_to_unknown_id(self): + self.tr.clear() + self.proto.lineReceived(json.dumps( + {'id': '1', 'version': 'v1', 'error': {'msg': 'something is wrong'}})) + obj = json.loads(self.tr.io.getvalue().strip()) + self.assertEqual(obj['fatal_error'], 'protocol violation: unknown message id referenced.') + + def _send_command(self): + self.tr.clear() + d = self.proto.send_command('test_command', {'body': 42}) + req = json.loads(self.tr.io.getvalue().strip()) + self.tr.clear() + return (d, req) + + def test_recv_result(self): + dout = defer.Deferred() + d, req = self._send_command() + self.proto.lineReceived(json.dumps( + {'id': req['id'], 'version': 'v1', 'result': {'duh': req['params']['body']}})) + self.assertEqual(len(self.tr.io.getvalue()), 0) + + def check(resp): + self.assertEqual(resp.result['duh'], 42) + dout.callback(True) + + d.addCallback(check) + + return dout + + def test_recv_error(self): + d, req = self._send_command() + self.proto.lineReceived(json.dumps( + {'id': req['id'], 'version': 'v1', 'error': {'msg': 'something is wrong'}})) + self.assertEqual(len(self.tr.io.getvalue()), 0) + return self.assertFailure(d, RPCError) + + def test_recv_fatal_error(self): + d = defer.Deferred() + called = [] + orig = self.proto.connectionLost + + def lost(arg): + self.failUnless(isinstance(arg, failure.Failure)) + orig() + called.append(True) + d.callback(True) + + self.proto.connectionLost = lost + + def check(ignore): + self.assertEqual(called, [True]) + + d.addCallback(check) + + self.tr.clear() + self.proto.lineReceived(json.dumps({'fatal_error': 'you be broken'})) + return d + + +class TeethAgentProtocolTest(unittest.TestCase): + """Teeth Agent Protocol tests.""" + + def setUp(self): + self.tr = FakeTCPTransport() + self.proto = TeethAgentProtocol(json.JSONEncoder(), IPv4Address('TCP', '127.0.0.1', 0), None) + self.proto.makeConnection(self.tr) + self.tr.protocol = self.proto + + def test_on_connect(self): + obj = json.loads(self.tr.io.getvalue().strip()) + self.assertEqual(obj['version'], 'v1') + self.assertEqual(obj['method'], 'handshake') + self.assertEqual(obj['method'], 'handshake') + self.assertEqual(obj['params']['id'], 'a:b:c:d') + self.assertEqual(obj['params']['version'], AGENT_VERSION) diff --git a/teeth_agent/tests/test_task.py b/teeth_agent/tests/test_task.py new file mode 100644 index 000000000..1c8e29aee --- /dev/null +++ b/teeth_agent/tests/test_task.py @@ -0,0 +1,53 @@ +""" +Copyright 2013 Rackspace, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import uuid + +from twisted.trial import unittest +from teeth_agent.task import Task +from mock import Mock + + +class FakeClient(object): + addService = Mock(return_value=None) + running = Mock(return_value=0) + update_task_status = Mock(return_value=None) + finish_task = Mock(return_value=None) + + +class TaskTest(unittest.TestCase): + """Event Emitter tests.""" + + def setUp(self): + self.task_id = str(uuid.uuid4()) + self.client = FakeClient() + self.task = Task(self.client, self.task_id, 'test_task') + + def tearDown(self): + del self.task_id + del self.task + del self.client + + def test_run(self): + self.assertEqual(self.task._state, 'starting') + self.assertEqual(self.task._id, self.task_id) + self.task.run() + self.client.addService.assert_called_once_with(self.task) + self.task.startService() + self.client.update_task_status.assert_called_once_with(self.task) + self.task.complete() + self.assertEqual(self.task._state, 'complete') + self.client.finish_task.assert_called_once_with(self.task)