Add many test cases for the RPC protocol and start making a Task structure.
parent
7ead822716
commit
697f98b929
@ -0,0 +1,100 @@
|
||||
"""
|
||||
Copyright 2013 Rackspace, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from twisted.application.service import MultiService
|
||||
from twisted.application.internet import TimerService
|
||||
from teeth_agent.logging import get_logger
|
||||
log = get_logger()
|
||||
|
||||
|
||||
__all__ = ['Task', 'PrepareImageTask']
|
||||
|
||||
|
||||
class Task(MultiService, object):
|
||||
"""
|
||||
Task to execute, reporting status periodically to TeethClient instance.
|
||||
"""
|
||||
|
||||
task_name = 'task_undefined'
|
||||
|
||||
def __init__(self, client, task_id, task_name, reporting_interval=10):
|
||||
super(Task, self).__init__()
|
||||
self.setName(self.task_name)
|
||||
self._client = client
|
||||
self._id = task_id
|
||||
self._percent = 0
|
||||
self._reporting_interval = reporting_interval
|
||||
self._state = 'starting'
|
||||
self._timer = TimerService(self._reporting_interval, self._tick)
|
||||
self._timer.setServiceParent(self)
|
||||
self._error_msg = None
|
||||
|
||||
def run(self):
|
||||
"""Run the Task."""
|
||||
# setServiceParent actually starts the task if it is already running
|
||||
# so we run it in start.
|
||||
self.setServiceParent(self._client)
|
||||
|
||||
def _tick(self):
|
||||
if not self.running:
|
||||
# log.debug("_tick called while not running :()")
|
||||
return
|
||||
return self._client.update_task_status(self)
|
||||
|
||||
def error(self, message):
|
||||
"""Error out running of the task."""
|
||||
self._error_msg = message
|
||||
self._state = 'error'
|
||||
self.stopService()
|
||||
|
||||
def complete(self):
|
||||
"""Complete running of the task."""
|
||||
self._state = 'complete'
|
||||
self.stopService()
|
||||
|
||||
def startService(self):
|
||||
"""Start the Service."""
|
||||
super(Task, self).startService()
|
||||
self._state = 'running'
|
||||
|
||||
def stopService(self):
|
||||
"""Stop the Service."""
|
||||
super(Task, self).stopService()
|
||||
|
||||
if not self._client.running:
|
||||
return
|
||||
|
||||
if self._state not in ['error', 'complete']:
|
||||
log.err("told to shutdown before task could complete, marking as error.")
|
||||
self._error_msg = 'service being shutdown'
|
||||
self._state = 'error'
|
||||
|
||||
self._client.finish_task(self)
|
||||
|
||||
|
||||
class PrepareImageTask(Task):
|
||||
|
||||
"""Prepare an image to be ran on the machine."""
|
||||
|
||||
task_name = 'prepare_image'
|
||||
|
||||
def __init__(self, client, task_id, image_info, reporting_interval=10):
|
||||
super(PrepareImageTask, self).__init__(client, task_id)
|
||||
self._image_info = image_info
|
||||
|
||||
def run():
|
||||
"""Run the Prepare Image task."""
|
||||
pass
|
@ -0,0 +1,190 @@
|
||||
"""
|
||||
Copyright 2013 Rackspace, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet import main
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.python import failure
|
||||
from twisted.test.proto_helpers import StringTransportWithDisconnection
|
||||
from twisted.trial import unittest
|
||||
import simplejson as json
|
||||
from mock import Mock
|
||||
|
||||
from teeth_agent.protocol import RPCError, RPCProtocol, TeethAgentProtocol
|
||||
from teeth_agent import __version__ as AGENT_VERSION
|
||||
|
||||
|
||||
class FakeTCPTransport(StringTransportWithDisconnection, object):
|
||||
_aborting = False
|
||||
disconnected = False
|
||||
|
||||
setTcpKeepAlive = Mock(return_value=None)
|
||||
setTcpNoDelay = Mock(return_value=None)
|
||||
setTcpNoDelay = Mock(return_value=None)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.protocol.connectionLost(reason)
|
||||
|
||||
def abortConnection(self):
|
||||
if self.disconnected or self._aborting:
|
||||
return
|
||||
self._aborting = True
|
||||
self.connectionLost(failure.Failure(main.CONNECTION_DONE))
|
||||
|
||||
|
||||
class RPCProtocolTest(unittest.TestCase):
|
||||
"""RPC Protocol tests."""
|
||||
|
||||
def setUp(self):
|
||||
self.tr = FakeTCPTransport()
|
||||
self.proto = RPCProtocol(json.JSONEncoder(), IPv4Address('TCP', '127.0.0.1', 0))
|
||||
self.proto.makeConnection(self.tr)
|
||||
self.tr.protocol = self.proto
|
||||
|
||||
def test_timeout(self):
|
||||
d = defer.Deferred()
|
||||
called = []
|
||||
orig = self.proto.connectionLost
|
||||
|
||||
def lost(arg):
|
||||
orig()
|
||||
called.append(True)
|
||||
d.callback(True)
|
||||
|
||||
self.proto.connectionLost = lost
|
||||
self.proto.timeoutConnection()
|
||||
|
||||
def check(ignore):
|
||||
self.assertEqual(called, [True])
|
||||
|
||||
d.addCallback(check)
|
||||
return d
|
||||
|
||||
def test_recv_command_no_params(self):
|
||||
self.tr.clear()
|
||||
self.proto.lineReceived(json.dumps({'id': '1', 'version': 'v1', 'method': 'BOGUS_STUFF'}))
|
||||
obj = json.loads(self.tr.io.getvalue().strip())
|
||||
self.assertEqual(obj['fatal_error'], 'protocol violation: missing message params.')
|
||||
|
||||
def test_recv_bogus_command(self):
|
||||
self.tr.clear()
|
||||
self.proto.lineReceived(
|
||||
json.dumps({'id': '1', 'version': 'v1', 'method': 'BOGUS_STUFF', 'params': {'d': '1'}}))
|
||||
obj = json.loads(self.tr.io.getvalue().strip())
|
||||
self.assertEqual(obj['fatal_error'], 'protocol violation: unsupported command.')
|
||||
|
||||
def test_recv_valid_json_no_id(self):
|
||||
self.tr.clear()
|
||||
self.proto.lineReceived(json.dumps({'version': 'v913'}))
|
||||
obj = json.loads(self.tr.io.getvalue().strip())
|
||||
self.assertEqual(obj['fatal_error'], 'protocol violation: missing message id.')
|
||||
|
||||
def test_recv_valid_json_no_version(self):
|
||||
self.tr.clear()
|
||||
self.proto.lineReceived(json.dumps({'version': None, 'id': 'foo'}))
|
||||
obj = json.loads(self.tr.io.getvalue().strip())
|
||||
self.assertEqual(obj['fatal_error'], 'protocol violation: missing message version.')
|
||||
|
||||
def test_recv_invalid_data(self):
|
||||
self.tr.clear()
|
||||
self.proto.lineReceived('')
|
||||
self.proto.lineReceived('invalid json!')
|
||||
obj = json.loads(self.tr.io.getvalue().strip())
|
||||
self.assertEqual(obj['fatal_error'], 'protocol error: unable to decode message.')
|
||||
|
||||
def test_recv_missing_key_parts(self):
|
||||
self.tr.clear()
|
||||
self.proto.lineReceived(json.dumps(
|
||||
{'id': '1', 'version': 'v1'}))
|
||||
obj = json.loads(self.tr.io.getvalue().strip())
|
||||
self.assertEqual(obj['fatal_error'], 'protocol error: malformed message.')
|
||||
|
||||
def test_recv_error_to_unknown_id(self):
|
||||
self.tr.clear()
|
||||
self.proto.lineReceived(json.dumps(
|
||||
{'id': '1', 'version': 'v1', 'error': {'msg': 'something is wrong'}}))
|
||||
obj = json.loads(self.tr.io.getvalue().strip())
|
||||
self.assertEqual(obj['fatal_error'], 'protocol violation: unknown message id referenced.')
|
||||
|
||||
def _send_command(self):
|
||||
self.tr.clear()
|
||||
d = self.proto.send_command('test_command', {'body': 42})
|
||||
req = json.loads(self.tr.io.getvalue().strip())
|
||||
self.tr.clear()
|
||||
return (d, req)
|
||||
|
||||
def test_recv_result(self):
|
||||
dout = defer.Deferred()
|
||||
d, req = self._send_command()
|
||||
self.proto.lineReceived(json.dumps(
|
||||
{'id': req['id'], 'version': 'v1', 'result': {'duh': req['params']['body']}}))
|
||||
self.assertEqual(len(self.tr.io.getvalue()), 0)
|
||||
|
||||
def check(resp):
|
||||
self.assertEqual(resp.result['duh'], 42)
|
||||
dout.callback(True)
|
||||
|
||||
d.addCallback(check)
|
||||
|
||||
return dout
|
||||
|
||||
def test_recv_error(self):
|
||||
d, req = self._send_command()
|
||||
self.proto.lineReceived(json.dumps(
|
||||
{'id': req['id'], 'version': 'v1', 'error': {'msg': 'something is wrong'}}))
|
||||
self.assertEqual(len(self.tr.io.getvalue()), 0)
|
||||
return self.assertFailure(d, RPCError)
|
||||
|
||||
def test_recv_fatal_error(self):
|
||||
d = defer.Deferred()
|
||||
called = []
|
||||
orig = self.proto.connectionLost
|
||||
|
||||
def lost(arg):
|
||||
self.failUnless(isinstance(arg, failure.Failure))
|
||||
orig()
|
||||
called.append(True)
|
||||
d.callback(True)
|
||||
|
||||
self.proto.connectionLost = lost
|
||||
|
||||
def check(ignore):
|
||||
self.assertEqual(called, [True])
|
||||
|
||||
d.addCallback(check)
|
||||
|
||||
self.tr.clear()
|
||||
self.proto.lineReceived(json.dumps({'fatal_error': 'you be broken'}))
|
||||
return d
|
||||
|
||||
|
||||
class TeethAgentProtocolTest(unittest.TestCase):
|
||||
"""Teeth Agent Protocol tests."""
|
||||
|
||||
def setUp(self):
|
||||
self.tr = FakeTCPTransport()
|
||||
self.proto = TeethAgentProtocol(json.JSONEncoder(), IPv4Address('TCP', '127.0.0.1', 0), None)
|
||||
self.proto.makeConnection(self.tr)
|
||||
self.tr.protocol = self.proto
|
||||
|
||||
def test_on_connect(self):
|
||||
obj = json.loads(self.tr.io.getvalue().strip())
|
||||
self.assertEqual(obj['version'], 'v1')
|
||||
self.assertEqual(obj['method'], 'handshake')
|
||||
self.assertEqual(obj['method'], 'handshake')
|
||||
self.assertEqual(obj['params']['id'], 'a:b:c:d')
|
||||
self.assertEqual(obj['params']['version'], AGENT_VERSION)
|
@ -0,0 +1,53 @@
|
||||
"""
|
||||
Copyright 2013 Rackspace, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
from twisted.trial import unittest
|
||||
from teeth_agent.task import Task
|
||||
from mock import Mock
|
||||
|
||||
|
||||
class FakeClient(object):
|
||||
addService = Mock(return_value=None)
|
||||
running = Mock(return_value=0)
|
||||
update_task_status = Mock(return_value=None)
|
||||
finish_task = Mock(return_value=None)
|
||||
|
||||
|
||||
class TaskTest(unittest.TestCase):
|
||||
"""Event Emitter tests."""
|
||||
|
||||
def setUp(self):
|
||||
self.task_id = str(uuid.uuid4())
|
||||
self.client = FakeClient()
|
||||
self.task = Task(self.client, self.task_id, 'test_task')
|
||||
|
||||
def tearDown(self):
|
||||
del self.task_id
|
||||
del self.task
|
||||
del self.client
|
||||
|
||||
def test_run(self):
|
||||
self.assertEqual(self.task._state, 'starting')
|
||||
self.assertEqual(self.task._id, self.task_id)
|
||||
self.task.run()
|
||||
self.client.addService.assert_called_once_with(self.task)
|
||||
self.task.startService()
|
||||
self.client.update_task_status.assert_called_once_with(self.task)
|
||||
self.task.complete()
|
||||
self.assertEqual(self.task._state, 'complete')
|
||||
self.client.finish_task.assert_called_once_with(self.task)
|
Loading…
Reference in New Issue