Add many test cases for the RPC protocol and start making a Task structure.

This commit is contained in:
Paul Querna 2013-09-25 00:53:23 +00:00
parent 7ead822716
commit 697f98b929
11 changed files with 398 additions and 36 deletions

2
.gitignore vendored
View File

@ -11,3 +11,5 @@ subunit-output.txt
test-report.xml
twisted/plugins/dropin.cache
twistd.log
.coverage
_trial_coverage/

View File

@ -15,6 +15,9 @@ else
trial --random 0 ${UNITTESTS}
endif
coverage:
coverage run --source=${CODEDIR} --branch `which trial` ${UNITTESTS} && coverage html -d _trial_coverage --omit="*/tests/*"
env:
./scripts/bootstrap-virtualenv.sh

View File

@ -37,21 +37,21 @@ Fatal Error:
* `log`: (Agent->Server) Log a structured message from the Agent.
* `status`: (Server->Agent) Uptime, version, and other fields reported.
* `task_status`: (Agent->Server) Update status of a task. Task has a `.state`, which is `running`, `error` or `complete`. `running` will additionally contain `.eta` and `.percent`, a measure of how much work estimated to remain in seconds and how much work is done. Once `error` or `complete` is sent, no more updates will be sent. `error` state includes an additional human readable `.msg` field.
* `task_status`: (Agent->Server) Update status of a task. Task has an `.task_id` which was previously designated. Task has a `.state`, which is `running`, `error` or `complete`. `running` will additionally contain `.eta` and `.percent`, a measure of how much work estimated to remain in seconds and how much work is done. Once `error` or `complete` is sent, no more updates will be sent. `error` state includes an additional human readable `.msg` field.
#### Decommission
* `decom.disk_erase`: (Server->Agent) Erase all attached block devices securely. Returns a Task ID.
* `decom.firmware_secure`: (Server->Agent) Update Firmwares/BIOS versions and settings. Returns a Task ID.
* `decom.qc`: (Server->Agent) Run quality control checks on chassis model. Includes sending specifications of chassis (cpu types, disks, etc). Returns a Task ID.
* `decom.disk_erase`: (Server->Agent) Erase all attached block devices securely. Takes a `task_id`.
* `decom.firmware_secure`: (Server->Agent) Update Firmwares/BIOS versions and settings. Takes a `task_id`.
* `decom.qc`: (Server->Agent) Run quality control checks on chassis model. Includes sending specifications of chassis (cpu types, disks, etc). Takes a `task_id`.
#### Standbye
* `standbye.cache_images`: (Server->Agent) Cache an set of image UUID on local storage. Ordered in priority, chassis may only cache a subset depending on local storage. Returns a Task ID.
* `standbye.prepare_image`: (Server->Agent) Prepare a image UUID to be ran. Returns a Task ID.
* `standbye.run_image`: (Server->Agent) Run an image UUID. Must include Config Drive Settings. Agent will write config drive, and setup grub. If the Agent can detect a viable kexec target it will kexec into it, otherwise reboot. Returns a Task ID.
* `standbye.cache_images`: (Server->Agent) Cache an set of image UUID on local storage. Ordered in priority, chassis may only cache a subset depending on local storage. Takes a `task_id`.
* `standbye.prepare_image`: (Server->Agent) Prepare a image UUID to be ran. Takes a `task_id`.
* `standbye.run_image`: (Server->Agent) Run an image UUID. Must include Config Drive Settings. Agent will write config drive, and setup grub. If the Agent can detect a viable kexec target it will kexec into it, otherwise reboot. Takes a `task_id`.

View File

@ -3,4 +3,6 @@ plumbum==1.3.0
pep8==1.4.6
pyflakes==0.7.3
junitxml==0.7
python-subunit==0.0.15
python-subunit==0.0.15
mock==1.0.1
coverage==3.6

View File

@ -27,7 +27,7 @@ class StandbyAgent(TeethClient):
def __init__(self, addrs):
super(StandbyAgent, self).__init__(addrs)
self._addHandler('v1', 'prepare_image', self.prepare_image)
log.info('Starting agent', addrs=addrs)
log.msg('Starting agent', addrs=addrs)
def prepare_image(self, image_id):
"""Prepare an Image."""

View File

@ -98,7 +98,7 @@ class TeethClient(MultiService, object):
self._running = False
dl = []
for client in self._clients:
dl.append(client.loseConnectionSoon(timeout=0.05))
dl.append(client.abortConnection())
return DeferredList(dl)
def remove_endpoint(self, host, port):

View File

@ -32,6 +32,8 @@ def configure():
structlog.configure(
context_class=dict,
logger_factory=structlog.twisted.LoggerFactory(),
wrapper_class=structlog.twisted.BoundLogger,
cache_logger_on_first_use=True)
@ -39,4 +41,6 @@ def get_logger():
"""
Get a logger instance.
"""
configure()
return structlog.get_logger()

View File

@ -18,10 +18,8 @@ import simplejson as json
import uuid
from twisted.internet import defer
from twisted.internet import task
from twisted.internet import reactor
from twisted.protocols import policies
from twisted.protocols.basic import LineReceiver
from twisted.python.failure import Failure
from teeth_agent import __version__ as AGENT_VERSION
from teeth_agent.events import EventEmitter
from teeth_agent.logging import get_logger
@ -73,7 +71,9 @@ class RPCError(RPCMessage, RuntimeError):
self._raw_message = message
class RPCProtocol(LineReceiver, EventEmitter):
class RPCProtocol(LineReceiver,
EventEmitter,
policies.TimeoutMixin):
"""
Twisted Protocol handler for the RPC Protocol of the Teeth
Agent <-> Endpoint communication.
@ -96,29 +96,34 @@ class RPCProtocol(LineReceiver, EventEmitter):
self._pending_command_deferreds = {}
self._fatal_error = False
self._log = log.bind(host=address.host, port=address.port)
self._timeOut = 60
def loseConnectionSoon(self, timeout=10):
"""Attempt to disconnect from the transport as 'nicely' as possible. """
self._log.info('Trying to disconnect.')
self.transport.loseConnection()
return task.deferLater(reactor, timeout, self.transport.abortConnection)
def timeoutConnection(self):
"""Action called when the connection has hit a timeout."""
self.transport.abortConnection()
def connectionMade(self):
"""TCP hard. We made it. Maybe."""
super(RPCProtocol, self).connectionMade()
self._log.info('Connection established.')
self._log.msg('Connection established.')
self.transport.setTcpKeepAlive(True)
self.transport.setTcpNoDelay(True)
self.emit('connect')
def sendLine(self, line):
"""Send a line of content to our peer."""
self.resetTimeout()
super(RPCProtocol, self).sendLine(line)
def lineReceived(self, line):
"""Process a line of data."""
self.resetTimeout()
line = line.strip()
if not line:
return
self._log.debug('Got Line', line=line)
self._log.msg('Got Line', line=line)
try:
message = json.loads(line)
@ -127,16 +132,16 @@ class RPCProtocol(LineReceiver, EventEmitter):
if 'fatal_error' in message:
# TODO: Log what happened?
self.loseConnectionSoon()
self.transport.abortConnection()
return
if not message.get('id', None):
return self.fatal_error("protocol violation: missing message id.")
if not message.get('version', None):
return self.fatal_error("protocol violation: missing message version.")
elif 'method' in message:
if not message.get('id', None):
return self.fatal_error("protocol violation: missing message id.")
if 'method' in message:
if not message.get('params', None):
return self.fatal_error("protocol violation: missing message params.")
@ -145,24 +150,24 @@ class RPCProtocol(LineReceiver, EventEmitter):
elif 'error' in message:
msg = RPCError(self, message)
self._handle_response(message)
self._handle_response(msg)
elif 'result' in message:
msg = RPCResponse(self, message)
self._handle_response(message)
self._handle_response(msg)
else:
return self.fatal_error('protocol error: malformed message.')
def fatal_error(self, message):
"""Send a fatal error message, and disconnect."""
self._log.error('sending a fatal error', message=message)
self._log.msg('sending a fatal error', message=message)
if not self._fatal_error:
self._fatal_error = True
self.sendLine(self.encoder.encode({
'fatal_error': message
}))
self.loseConnectionSoon()
self.transport.abortConnection()
def send_command(self, method, params, timeout=60):
"""Send a new command."""
@ -196,11 +201,13 @@ class RPCProtocol(LineReceiver, EventEmitter):
}))
def _handle_response(self, message):
d = self.pending_command_deferreds.pop(message['id'])
d = self._pending_command_deferreds.pop(message.id, None)
if not d:
return self.fatal_error("protocol violation: unknown message id referenced.")
if isinstance(message, RPCError):
f = Failure(message)
d.errback(f)
d.errback(message)
else:
d.callback(message)
@ -208,7 +215,7 @@ class RPCProtocol(LineReceiver, EventEmitter):
d = self.emit('command', message)
if len(d) == 0:
return self.fatal_error("protocol violation: unsupported command")
return self.fatal_error("protocol violation: unsupported command.")
# TODO: do we need to wait on anything here?
pass
@ -224,9 +231,10 @@ class TeethAgentProtocol(RPCProtocol):
self.encoder = encoder
self.address = address
self.parent = parent
self.on('connect', self._on_connect)
self.once('connect', self._once_connect)
def _once_connect(self, event):
def _on_connect(self, event):
def _response(result):
self._log.msg('Handshake successful', connection_id=result['id'])

100
teeth_agent/task.py Normal file
View File

@ -0,0 +1,100 @@
"""
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from twisted.application.service import MultiService
from twisted.application.internet import TimerService
from teeth_agent.logging import get_logger
log = get_logger()
__all__ = ['Task', 'PrepareImageTask']
class Task(MultiService, object):
"""
Task to execute, reporting status periodically to TeethClient instance.
"""
task_name = 'task_undefined'
def __init__(self, client, task_id, task_name, reporting_interval=10):
super(Task, self).__init__()
self.setName(self.task_name)
self._client = client
self._id = task_id
self._percent = 0
self._reporting_interval = reporting_interval
self._state = 'starting'
self._timer = TimerService(self._reporting_interval, self._tick)
self._timer.setServiceParent(self)
self._error_msg = None
def run(self):
"""Run the Task."""
# setServiceParent actually starts the task if it is already running
# so we run it in start.
self.setServiceParent(self._client)
def _tick(self):
if not self.running:
# log.debug("_tick called while not running :()")
return
return self._client.update_task_status(self)
def error(self, message):
"""Error out running of the task."""
self._error_msg = message
self._state = 'error'
self.stopService()
def complete(self):
"""Complete running of the task."""
self._state = 'complete'
self.stopService()
def startService(self):
"""Start the Service."""
super(Task, self).startService()
self._state = 'running'
def stopService(self):
"""Stop the Service."""
super(Task, self).stopService()
if not self._client.running:
return
if self._state not in ['error', 'complete']:
log.err("told to shutdown before task could complete, marking as error.")
self._error_msg = 'service being shutdown'
self._state = 'error'
self._client.finish_task(self)
class PrepareImageTask(Task):
"""Prepare an image to be ran on the machine."""
task_name = 'prepare_image'
def __init__(self, client, task_id, image_info, reporting_interval=10):
super(PrepareImageTask, self).__init__(client, task_id)
self._image_info = image_info
def run():
"""Run the Prepare Image task."""
pass

View File

@ -0,0 +1,190 @@
"""
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from twisted.internet import defer
from twisted.internet import main
from twisted.internet.address import IPv4Address
from twisted.python import failure
from twisted.test.proto_helpers import StringTransportWithDisconnection
from twisted.trial import unittest
import simplejson as json
from mock import Mock
from teeth_agent.protocol import RPCError, RPCProtocol, TeethAgentProtocol
from teeth_agent import __version__ as AGENT_VERSION
class FakeTCPTransport(StringTransportWithDisconnection, object):
_aborting = False
disconnected = False
setTcpKeepAlive = Mock(return_value=None)
setTcpNoDelay = Mock(return_value=None)
setTcpNoDelay = Mock(return_value=None)
def connectionLost(self, reason):
self.protocol.connectionLost(reason)
def abortConnection(self):
if self.disconnected or self._aborting:
return
self._aborting = True
self.connectionLost(failure.Failure(main.CONNECTION_DONE))
class RPCProtocolTest(unittest.TestCase):
"""RPC Protocol tests."""
def setUp(self):
self.tr = FakeTCPTransport()
self.proto = RPCProtocol(json.JSONEncoder(), IPv4Address('TCP', '127.0.0.1', 0))
self.proto.makeConnection(self.tr)
self.tr.protocol = self.proto
def test_timeout(self):
d = defer.Deferred()
called = []
orig = self.proto.connectionLost
def lost(arg):
orig()
called.append(True)
d.callback(True)
self.proto.connectionLost = lost
self.proto.timeoutConnection()
def check(ignore):
self.assertEqual(called, [True])
d.addCallback(check)
return d
def test_recv_command_no_params(self):
self.tr.clear()
self.proto.lineReceived(json.dumps({'id': '1', 'version': 'v1', 'method': 'BOGUS_STUFF'}))
obj = json.loads(self.tr.io.getvalue().strip())
self.assertEqual(obj['fatal_error'], 'protocol violation: missing message params.')
def test_recv_bogus_command(self):
self.tr.clear()
self.proto.lineReceived(
json.dumps({'id': '1', 'version': 'v1', 'method': 'BOGUS_STUFF', 'params': {'d': '1'}}))
obj = json.loads(self.tr.io.getvalue().strip())
self.assertEqual(obj['fatal_error'], 'protocol violation: unsupported command.')
def test_recv_valid_json_no_id(self):
self.tr.clear()
self.proto.lineReceived(json.dumps({'version': 'v913'}))
obj = json.loads(self.tr.io.getvalue().strip())
self.assertEqual(obj['fatal_error'], 'protocol violation: missing message id.')
def test_recv_valid_json_no_version(self):
self.tr.clear()
self.proto.lineReceived(json.dumps({'version': None, 'id': 'foo'}))
obj = json.loads(self.tr.io.getvalue().strip())
self.assertEqual(obj['fatal_error'], 'protocol violation: missing message version.')
def test_recv_invalid_data(self):
self.tr.clear()
self.proto.lineReceived('')
self.proto.lineReceived('invalid json!')
obj = json.loads(self.tr.io.getvalue().strip())
self.assertEqual(obj['fatal_error'], 'protocol error: unable to decode message.')
def test_recv_missing_key_parts(self):
self.tr.clear()
self.proto.lineReceived(json.dumps(
{'id': '1', 'version': 'v1'}))
obj = json.loads(self.tr.io.getvalue().strip())
self.assertEqual(obj['fatal_error'], 'protocol error: malformed message.')
def test_recv_error_to_unknown_id(self):
self.tr.clear()
self.proto.lineReceived(json.dumps(
{'id': '1', 'version': 'v1', 'error': {'msg': 'something is wrong'}}))
obj = json.loads(self.tr.io.getvalue().strip())
self.assertEqual(obj['fatal_error'], 'protocol violation: unknown message id referenced.')
def _send_command(self):
self.tr.clear()
d = self.proto.send_command('test_command', {'body': 42})
req = json.loads(self.tr.io.getvalue().strip())
self.tr.clear()
return (d, req)
def test_recv_result(self):
dout = defer.Deferred()
d, req = self._send_command()
self.proto.lineReceived(json.dumps(
{'id': req['id'], 'version': 'v1', 'result': {'duh': req['params']['body']}}))
self.assertEqual(len(self.tr.io.getvalue()), 0)
def check(resp):
self.assertEqual(resp.result['duh'], 42)
dout.callback(True)
d.addCallback(check)
return dout
def test_recv_error(self):
d, req = self._send_command()
self.proto.lineReceived(json.dumps(
{'id': req['id'], 'version': 'v1', 'error': {'msg': 'something is wrong'}}))
self.assertEqual(len(self.tr.io.getvalue()), 0)
return self.assertFailure(d, RPCError)
def test_recv_fatal_error(self):
d = defer.Deferred()
called = []
orig = self.proto.connectionLost
def lost(arg):
self.failUnless(isinstance(arg, failure.Failure))
orig()
called.append(True)
d.callback(True)
self.proto.connectionLost = lost
def check(ignore):
self.assertEqual(called, [True])
d.addCallback(check)
self.tr.clear()
self.proto.lineReceived(json.dumps({'fatal_error': 'you be broken'}))
return d
class TeethAgentProtocolTest(unittest.TestCase):
"""Teeth Agent Protocol tests."""
def setUp(self):
self.tr = FakeTCPTransport()
self.proto = TeethAgentProtocol(json.JSONEncoder(), IPv4Address('TCP', '127.0.0.1', 0), None)
self.proto.makeConnection(self.tr)
self.tr.protocol = self.proto
def test_on_connect(self):
obj = json.loads(self.tr.io.getvalue().strip())
self.assertEqual(obj['version'], 'v1')
self.assertEqual(obj['method'], 'handshake')
self.assertEqual(obj['method'], 'handshake')
self.assertEqual(obj['params']['id'], 'a:b:c:d')
self.assertEqual(obj['params']['version'], AGENT_VERSION)

View File

@ -0,0 +1,53 @@
"""
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import uuid
from twisted.trial import unittest
from teeth_agent.task import Task
from mock import Mock
class FakeClient(object):
addService = Mock(return_value=None)
running = Mock(return_value=0)
update_task_status = Mock(return_value=None)
finish_task = Mock(return_value=None)
class TaskTest(unittest.TestCase):
"""Event Emitter tests."""
def setUp(self):
self.task_id = str(uuid.uuid4())
self.client = FakeClient()
self.task = Task(self.client, self.task_id, 'test_task')
def tearDown(self):
del self.task_id
del self.task
del self.client
def test_run(self):
self.assertEqual(self.task._state, 'starting')
self.assertEqual(self.task._id, self.task_id)
self.task.run()
self.client.addService.assert_called_once_with(self.task)
self.task.startService()
self.client.update_task_status.assert_called_once_with(self.task)
self.task.complete()
self.assertEqual(self.task._state, 'complete')
self.client.finish_task.assert_called_once_with(self.task)