out with the old, in with the new

This commit is contained in:
Russell Haering 2013-12-16 15:38:02 -08:00
parent b3ee1d828c
commit 27fdc13e49
26 changed files with 91 additions and 1808 deletions

18
.gitignore vendored
View File

@ -1,15 +1,9 @@
*.py[co]
bin/*
lib/*
include/*
local/*
src/*
build/*
.ve
_trial_temp/
subunit-output.txt
test-report.xml
twisted/plugins/dropin.cache
twistd.log
.coverage
_trial_coverage/
AUTHORS
ChangeLog
*.egg-info
*.egg
.tox/*
devenv/*

View File

@ -1,35 +0,0 @@
CODEDIR=teeth_agent
SCRIPTSDIR=scripts
UNITTESTS ?= ${CODEDIR}
PYTHONLINT=${SCRIPTSDIR}/python-lint.py
PYDIRS=${CODEDIR} ${SCRIPTSDIR}
default: lint test
test: unit
unit:
ifneq ($(JENKINS_URL), )
trial --random 0 --reporter=subunit ${UNITTESTS} | tee subunit-output.txt
tail -n +3 subunit-output.txt | subunit2junitxml > test-report.xml
else
trial --random 0 ${UNITTESTS}
endif
coverage:
coverage run --source=${CODEDIR} --branch `which trial` ${UNITTESTS} && coverage html -d _trial_coverage --omit="*/tests/*"
env:
./scripts/bootstrap-virtualenv.sh
lint:
${PYTHONLINT} ${PYDIRS}
clean:
find . -name '*.pyc' -delete
find . -name '.coverage' -delete
find . -name '_trial_coverage' -print0 | xargs --null rm -rf
find . -name '_trial_temp' -print0 | xargs --null rm -rf
rm -rf dist build *.egg-info twisted/plugins/dropin.cache

View File

@ -1,68 +1,3 @@
# teeth-agent
An agent for rebuilding and controlling Teeth chassis.
## Protocol
JSON. Line Delimitated. Bi-directional. Most messages contain:
* `version` Message Version - String
* `id` Message ID - String
Commands contain:
* `method` Method - String
* `params` Params - Hash of parameters to Method.
Success Responses contain:
* `id` Original Message ID from Command - String
* `result` Result from Command - Hash
Error Responses contain:
* `id` Original Message ID from Command - String
* `error` Result from Command - Hash, `.msg` contains human readable error. Other fields might come later.
Fatal Error:
* `fatal_error` - String - Fatal error message; Connection should be closed.
## Builders
Teeth Agent master builder: https://jenkins.t.k1k.me/job/teeth-agent-master/
Teeth Agent PR builder: https://jenkins.t.k1k.me/job/teeth-agent-pr/
Builds are automatically triggered on pushes to master, or upon opening a PR.
### Commands
#### All Protocol Implementations.
* `ping`: (All) Params are echo'ed back as results.
#### Agent specific Commands
* `log`: (Agent->Server) Log a structured message from the Agent.
* `status`: (Server->Agent) Uptime, version, and other fields reported.
* `task_status`: (Agent->Server) Update status of a task. Task has an `.task_id` which was previously designated. Task has a `.state`, which is `running`, `error` or `complete`. `running` will additionally contain `.eta` and `.percent`, a measure of how much work estimated to remain in seconds and how much work is done. Once `error` or `complete` is sent, no more updates will be sent. `error` state includes an additional human readable `.msg` field.
#### Decommission
* `decom.disk_erase`: (Server->Agent) Erase all attached block devices securely. Takes a `task_id`.
* `decom.firmware_secure`: (Server->Agent) Update Firmwares/BIOS versions and settings. Takes a `task_id`.
* `decom.qc`: (Server->Agent) Run quality control checks on chassis model. Includes sending specifications of chassis (cpu types, disks, etc). Takes a `task_id`.
#### Standbye
* `standbye.cache_images`: (Server->Agent) Cache an set of image UUID on local storage. Ordered in priority, chassis may only cache a subset depending on local storage. Takes a `task_id`.
* `standbye.prepare_image`: (Server->Agent) Prepare a image UUID to be ran. Takes a `task_id`.
* `standbye.run_image`: (Server->Agent) Run an image UUID. Must include Config Drive Settings. Agent will write config drive, and setup grub. If the Agent can detect a viable kexec target it will kexec into it, otherwise reboot. Takes a `task_id`.

View File

@ -1,7 +1,3 @@
setuptools==1.1.6
Twisted==13.1.0
argparse==1.2.1
wsgiref==0.1.2
zope.interface==4.0.5
structlog==0.3.0
treq==0.2.0
Werkzeug==0.9.4
requests==2.0.0
-e git+git@github.com:racker/teeth-rest.git@4d5d81833bad3c4235bdea6cf6d53a0a4defb275#egg=teeth_rest-master

View File

@ -1,23 +0,0 @@
#!/bin/bash
#
# Create an initial virtualenv based on the VE_DIR environment variable (.ve)
# by default. This is used by the Makefile `make env` to allow bootstrapping in
# environments where virtualenvwrapper is unavailable or unappropriate. Such
# as on Jenkins.
#
set -e
VE_DIR=${VE_DIR:=.ve}
if [[ -d ${VE_DIR} ]]; then
echo "Skipping build virtualenv"
else
echo "Building complete virtualenv"
virtualenv ${VE_DIR}
fi
source ${VE_DIR}/bin/activate
pip install --upgrade -r requirements.txt -r dev-requirements.txt

View File

@ -1,139 +0,0 @@
#!/usr/bin/env python
"""
Enforces Python coding standards via pep8, pyflakes and pylint
Installation:
pip install pep8 - style guide
pip install pep257 - for docstrings
pip install pyflakes - unused imports and variable declarations
pip install plumbum - used for executing shell commands
This script can be called from the git pre-commit hook with a
--git-precommit option
"""
import os
import pep257
import re
import sys
from plumbum import local, cli, commands
pep8_options = [
'--max-line-length=105'
]
def lint(to_lint):
"""
Run all linters against a list of files.
:param to_lint: a list of files to lint.
"""
exit_code = 0
for linter, options in (('pyflakes', []), ('pep8', pep8_options)):
try:
output = local[linter](*(options + to_lint))
except commands.ProcessExecutionError as e:
output = e.stdout
if output:
exit_code = 1
print "{0} Errors:".format(linter)
print output
output = hacked_pep257(to_lint)
if output:
exit_code = 1
print "Docstring Errors:".format(linter.upper())
print output
sys.exit(exit_code)
def hacked_pep257(to_lint):
"""
Check for the presence of docstrings, but ignore some of the options
"""
def ignore(*args, **kwargs):
pass
pep257.check_blank_before_after_class = ignore
pep257.check_blank_after_last_paragraph = ignore
pep257.check_blank_after_summary = ignore
pep257.check_ends_with_period = ignore
pep257.check_one_liners = ignore
pep257.check_imperative_mood = ignore
original_check_return_type = pep257.check_return_type
def better_check_return_type(def_docstring, context, is_script):
"""
Ignore private methods
"""
def_name = context.split()[1]
if def_name.startswith('_') and not def_name.endswith('__'):
original_check_return_type(def_docstring, context, is_script)
pep257.check_return_type = better_check_return_type
errors = []
for filename in to_lint:
with open(filename) as f:
source = f.read()
if source:
errors.extend(pep257.check_source(source, filename))
return '\n'.join([str(error) for error in sorted(errors)])
class Lint(cli.Application):
"""
Command line app for VmrunWrapper
"""
DESCRIPTION = "Lints python with pep8, pep257, and pyflakes"
git = cli.Flag("--git-precommit", help="Lint only modified git files",
default=False)
def main(self, *directories):
"""
The actual logic that runs the linters
"""
if not self.git and len(directories) == 0:
print ("ERROR: At least one directory must be provided (or the "
"--git-precommit flag must be passed.\n")
self.help()
return
if len(directories) > 0:
find = local['find']
files = []
for directory in directories:
real = os.path.expanduser(directory)
if not os.path.exists(real):
raise ValueError("{0} does not exist".format(directory))
files.extend(find(real, '-not', '-name', '._*', '-name', '*.py').strip().split('\n'))
else:
status = local['git']('status', '--porcelain', '-uno')
root = local['git']('rev-parse', '--show-toplevel').strip()
# get all modified or added python files
modified = re.findall(r"^\s[AM]\s+(\S+\.py)$", status, re.MULTILINE)
# now just get the path part, which all should be relative to the
# root
files = [os.path.join(root, line.split(' ', 1)[-1].strip())
for line in modified]
if len(files) > 0:
print "Linting {0} python files.\n".format(len(files))
lint(files)
else:
print "No python files found to lint.\n"
if __name__ == "__main__":
Lint.run()

19
setup.cfg Normal file
View File

@ -0,0 +1,19 @@
[metadata]
name = teeth-agent
author = Rackspace
author-email = teeth-dev@lists.rackspace.com
summary = Teeth Host Agent
license = Apache-2
classifier =
Development Status :: 4 - Beta
Intended Audience :: Developers
License :: OSI Approved :: Apache Software License
Operating System :: OS Independent
Programming Language :: Python
[files]
packages =
teeth_agent
[entry_points]
console_scripts =
teeth-agent = teeth_agent.cmd.agent:run

View File

@ -16,28 +16,9 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from setuptools import setup, find_packages
import codecs
import os
import re
here = os.path.abspath(os.path.dirname(__file__))
def read(*parts):
return codecs.open(os.path.join(here, *parts), 'r').read()
def find_version(*file_paths):
version_file = read(*file_paths)
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
version_file, re.M)
if version_match:
return version_match.group(1)
raise RuntimeError("Unable to find version string.")
from setuptools import setup
setup(
name='teeth-agent',
version=find_version('teeth_agent', '__init__.py'),
packages=find_packages(),
setup_requires=['pbr'],
pbr=True,
)

View File

@ -13,7 +13,3 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
__all__ = ["__version__"]
__version__ = "0.1-dev"

View File

@ -1,65 +0,0 @@
"""
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from teeth_agent.client import TeethClient
from teeth_agent.logging import get_logger
from teeth_agent.protocol import require_parameters, CommandValidationError
from teeth_agent.task import PrepareImageTask
log = get_logger()
class StandbyAgent(TeethClient):
"""
Agent to perform standbye operations.
"""
AGENT_MODE = 'STANDBY'
def __init__(self, addrs):
super(StandbyAgent, self).__init__(addrs)
self._add_handler('v1', 'standby.cache_images', self.cache_images)
self._add_handler('v1', 'standby.prepare_image', self.prepare_image)
self._add_handler('v1', 'standby.run_image', self.run_image)
log.msg('Starting agent', addrs=addrs)
@require_parameters('task_id', 'image_ids')
def cache_images(self, command):
"""
Cache a set of images. Ordered in priority, we may only cache a
subset depending on storage availability.
"""
if not isinstance(command.params['image_ids'], list):
raise CommandValidationError('"image_ids" must be a list')
pass
@require_parameters('task_id', 'image_id')
def prepare_image(self, command):
"""Prepare an Image."""
task_id = command.params['task_id']
image_id = command.params['image_id']
t = PrepareImageTask(self, task_id, image_id)
t.run()
@require_parameters('task_id', 'image_id')
def run_image(self, command):
"""
Run the specified image.
"""
def update_task_status(self, task):
"""Send an updated task status to the agent endpoint."""
pass

View File

@ -1,126 +0,0 @@
"""
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from twisted.application.service import MultiService
from twisted.application.internet import TimerService
from twisted.internet import defer
from teeth_agent.logging import get_logger
__all__ = ['BaseTask', 'MultiTask']
class BaseTask(MultiService, object):
"""
Task to execute, reporting status periodically to TeethClient instance.
"""
task_name = 'task_undefined'
def __init__(self, client, task_id, reporting_interval=10):
super(BaseTask, self).__init__()
self.log = get_logger(task_id=task_id, task_name=self.task_name)
self.setName(self.task_name + '.' + task_id)
self._client = client
self._id = task_id
self._percent = 0
self._reporting_interval = reporting_interval
self._state = 'starting'
self._timer = TimerService(self._reporting_interval, self._tick)
self._timer.setServiceParent(self)
self._error_msg = None
self._done = False
self._d = defer.Deferred()
def _run(self):
"""Do the actual work here."""
def run(self):
"""Run the Task."""
# setServiceParent actually starts the task if it is already running
# so we run it in start.
if not self.parent:
self.setServiceParent(self._client)
self._run()
return self._d
def _tick(self):
if not self.running:
# log.debug("_tick called while not running :()")
return
if self._state in ['error', 'complete']:
self.stopService()
return self._client.update_task_status(self)
def error(self, message, *args, **kwargs):
"""Error out running of the task."""
self._error_msg = message
self._state = 'error'
self.stopService()
def complete(self, *args, **kwargs):
"""Complete running of the task."""
self._state = 'complete'
self.stopService()
def startService(self):
"""Start the Service."""
self._state = 'running'
super(BaseTask, self).startService()
def stopService(self):
"""Stop the Service."""
super(BaseTask, self).stopService()
if self._state not in ['error', 'complete']:
self.log.err("told to shutdown before task could complete, marking as error.")
self._error_msg = 'service being shutdown'
self._state = 'error'
if self._done is False:
self._done = True
self._d.callback(None)
self._client.finish_task(self)
class MultiTask(BaseTask):
"""Run multiple tasks in parallel."""
def __init__(self, client, task_id, reporting_interval=10):
super(MultiTask, self).__init__(client, task_id, reporting_interval=reporting_interval)
self._tasks = []
def _tick(self):
if len(self._tasks):
percents = [t._percent for t in self._tasks]
self._percent = sum(percents)/float(len(percents))
else:
self._percent = 0
super(MultiTask, self)._tick()
def _run(self):
ds = []
for t in self._tasks:
ds.append(t.run())
dl = defer.DeferredList(ds)
dl.addBoth(self.complete, self.error)
def add_task(self, task):
"""Add a task to be ran."""
task.setServiceParent(self)
self._tasks.append(task)

View File

@ -1,54 +0,0 @@
"""
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from teeth_agent.base_task import BaseTask
import treq
class ImageDownloaderTask(BaseTask):
"""Download image to cache. """
task_name = 'image_download'
def __init__(self, client, task_id, image_info, destination_filename, reporting_interval=10):
super(ImageDownloaderTask, self).__init__(client, task_id, reporting_interval=reporting_interval)
self._destination_filename = destination_filename
self._image_id = image_info.id
self._image_hashes = image_info.hashes
self._iamge_urls = image_info.urls
self._destination_filename = destination_filename
def _run(self):
# TODO: pick by protocol priority.
url = self._iamge_urls[0]
# TODO: more than just download, sha1 it.
return self._download_image_to_file(url)
def _tick(self):
# TODO: get file download percentages.
self.percent = 0
super(ImageDownloaderTask, self)._tick()
def _download_image_to_file(self, url):
destination = open(self._destination_filename, 'wb')
def push(data):
if self.running:
destination.write(data)
d = treq.get(url)
d.addCallback(treq.collect, push)
d.addBoth(lambda _: destination.close())
return d

View File

@ -1,170 +0,0 @@
"""
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import time
import json
import random
import tempfile
from twisted.application.service import MultiService
from twisted.application.internet import TCPClient
from twisted.internet.protocol import ReconnectingClientFactory
from twisted.internet.defer import maybeDeferred
from twisted.python.failure import Failure
from teeth_agent import __version__ as AGENT_VERSION
from teeth_agent.protocol import TeethAgentProtocol
from teeth_agent.logging import get_logger
log = get_logger()
__all__ = ["TeethClientFactory", "TeethClient"]
class TeethClientFactory(ReconnectingClientFactory, object):
"""
Protocol Factory for the Teeth Client.
"""
protocol = TeethAgentProtocol
initialDelay = 1.0
maxDelay = 120
def __init__(self, encoder, parent):
super(TeethClientFactory, self).__init__()
self._encoder = encoder
self._parent = parent
def buildProtocol(self, addr):
"""Create protocol for an address."""
self.resetDelay()
proto = self.protocol(self._encoder, addr, self._parent)
self._parent.add_protocol_instance(proto)
return proto
def clientConnectionFailed(self, connector, reason):
"""clientConnectionFailed"""
log.err('Failed to connect, re-trying', delay=self.delay)
super(TeethClientFactory, self).clientConnectionFailed(connector, reason)
def clientConnectionLost(self, connector, reason):
"""clientConnectionLost"""
log.err('Lost connection, re-connecting', delay=self.delay)
super(TeethClientFactory, self).clientConnectionLost(connector, reason)
class TeethClient(MultiService, object):
"""
High level Teeth Client.
"""
client_factory_cls = TeethClientFactory
client_encoder_cls = json.JSONEncoder
def __init__(self, addrs):
super(TeethClient, self).__init__()
self.setName('teeth-agent')
self._client_encoder = self.client_encoder_cls()
self._client_factory = self.client_factory_cls(self._client_encoder, self)
self._log = get_logger()
self._start_time = time.time()
self._protocols = []
self._outmsg = []
self._connectaddrs = addrs
self._handlers = {
'v1': {
'status': self._handle_status,
}
}
@property
def conf_image_cache_path(self):
"""Path to iamge cache."""
# TODO: improve:
return tempfile.gettempdir()
def startService(self):
"""Start the Service."""
super(TeethClient, self).startService()
for host, port in self._connectaddrs:
service = TCPClient(host, port, self._client_factory)
service.setName("teeth-agent[%s:%d]".format(host, port))
self.addService(service)
self._connectaddrs = []
def remove_endpoint(self, host, port):
"""Remove an Agent Endpoint from the active list."""
def op(protocol):
if protocol.address.host == host and protocol.address.port == port:
protocol.loseConnectionSoon()
return True
return False
self._protocols[:] = [protocol for protocol in self._protocols if not op(protocol)]
def add_endpoint(self, host, port):
"""Add an agent endpoint to the """
self._connectaddrs.append([host, port])
self.start()
def add_protocol_instance(self, protocol):
"""Add a running protocol to the parent."""
protocol.on('command', self._on_command)
self._protocols.append(protocol)
def _on_command(self, topic, message):
if message.version not in self._handlers:
message.protocol.fatal_error('unknown message version')
return
if message.method not in self._handlers[message.version]:
message.protocol.fatal_error('unknown message method')
return
handler = self._handlers[message.version][message.method]
d = maybeDeferred(handler, message)
d.addBoth(self._send_response, message)
def _send_response(self, result, message):
"""Send a response to a message."""
if isinstance(result, Failure):
self._log.err(result)
message.protocol.send_error_response(result.value.message, message)
else:
message.protocol.send_response(result, message)
def _handle_status(self, command):
return {
'mode': self.AGENT_MODE,
'uptime': time.time() - self._start_time,
'version': AGENT_VERSION,
}
def _add_handler(self, version, command, func):
self._handlers[version][command] = func
def _send_command(self, method, params):
protocol = random.choice(self._protocols)
return protocol.send_command(method, params)
def send_log(self, message, **kwargs):
"""
Send a log message to the endpoint.
"""
event = {}
event.update(kwargs)
event['message'] = message
event['time'] = time.time()
return self._send_command('log', event)

View File

@ -0,0 +1,15 @@
"""
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

View File

@ -15,15 +15,5 @@ limitations under the License.
"""
"""
Teeth Agent Twisted Application Plugin.
"""
from twisted.application.service import ServiceMaker
TeethAgent = ServiceMaker(
"Teeth Agent Client Application",
"teeth_agent.service",
"Teeth Agent for decomissioning and standbye",
"teeth-agent"
)
def run():
pass

View File

@ -1,81 +0,0 @@
"""
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from collections import defaultdict
from twisted.internet.defer import maybeDeferred
__all__ = ['EventEmitter', 'EventEmitterUnhandledError']
class EventEmitterUnhandledError(RuntimeError):
"""
Error caused by no subscribers to an `error` event.
"""
pass
class EventEmitter(object):
"""
Extremely simple pubsub style things in-process.
Styled after the Node.js EventEmitter class
"""
__slots__ = ['_subs']
def __init__(self):
self._subs = defaultdict(list)
def emit(self, topic, *args):
"""
Emit an event to a specific topic with a payload.
"""
ds = []
if topic == "error":
if len(self._subs[topic]) == 0:
raise EventEmitterUnhandledError("No Subscribers to an error event found")
for s in self._subs[topic]:
ds.append(maybeDeferred(s, topic, *args))
return ds
def on(self, topic, callback):
"""
Add a handler for a specific topic.
"""
self.emit("newListener", topic, callback)
self._subs[topic].append(callback)
def once(self, topic, callback):
"""
Execute a specific handler just once.
"""
def oncecb(*args):
self.removeListener(topic, oncecb)
callback(*args)
self.on(topic, oncecb)
def removeListener(self, topic, callback):
"""
Remove a handler from a topic.
"""
self._subs[topic] = filter(lambda x: x != callback, self._subs[topic])
def removeAllListeners(self, topic):
"""
Remove all listeners from a specific topic.
"""
del self._subs[topic]

View File

@ -1,46 +0,0 @@
"""
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import structlog
CONFIGURED_LOGGING = False
def configure():
"""
Configure logging subsystem.
"""
global CONFIGURED_LOGGING
if CONFIGURED_LOGGING:
return
CONFIGURED_LOGGING = True
structlog.configure(
context_class=dict,
logger_factory=structlog.twisted.LoggerFactory(),
wrapper_class=structlog.twisted.BoundLogger,
cache_logger_on_first_use=True)
def get_logger(*args, **kwargs):
"""
Get a logger instance.
"""
configure()
return structlog.get_logger(*args, **kwargs)

View File

@ -1,330 +0,0 @@
"""
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from functools import wraps
import json
import uuid
import time
from twisted.internet import defer, task
from twisted.protocols import policies
from twisted.protocols.basic import LineReceiver
from teeth_agent import __version__ as AGENT_VERSION
from teeth_agent.events import EventEmitter
from teeth_agent.logging import get_logger
log = get_logger()
DEFAULT_PROTOCOL_VERSION = 'v1'
__all__ = ['RPCMessage', 'RPCCommand', 'RPCProtocol', 'TeethAgentProtocol']
class RPCMessage(object):
"""
Wraps all RPC messages.
"""
def __init__(self, protocol, message):
super(RPCMessage, self).__init__()
self.protocol = protocol
self.id = message['id']
self.version = message['version']
class RPCCommand(RPCMessage):
"""
Wraps incoming RPC Commands.
"""
def __init__(self, protocol, message):
super(RPCCommand, self).__init__(protocol, message)
self.method = message['method']
self.params = message['params']
class RPCResponse(RPCMessage):
"""
Wraps incoming RPC Responses.
"""
def __init__(self, protocol, message):
super(RPCResponse, self).__init__(protocol, message)
self.result = message.get('result', None)
class RPCError(RPCMessage, RuntimeError):
"""
Wraps incoming RPC Errors Responses.
"""
def __init__(self, protocol, message):
super(RPCError, self).__init__(protocol, message)
self.error = message.get('error', 'unknown error')
self._raw_message = message
class ImageInfo(object):
"""
Metadata about a machine image.
"""
def __init__(self, image_id, image_urls, image_hashes):
super(ImageInfo, self).__init__()
self.id = image_id
self.urls = image_urls
self.hashes = image_hashes
class CommandValidationError(RuntimeError):
"""
Exception class which can be used to return an error when the
opposite party attempts a command with invalid parameters.
"""
def __init__(self, message, fatal=False):
super(CommandValidationError, self).__init__(message)
self.fatal = fatal
def require_parameters(*parameters, **kwargs):
"""
Return a decorator which wraps a function, `fn`, and verifies that
when `fn` is called, each of the parameters passed to
`require_parameters` has been passed to `fn` as a keyword argument.
For example::
@require_parameters('foo')
def my_handler(self, command):
return command.params['foo']
If a parameter is missing, an error will be returned to the opposite
party. If `fatal=True`, a fatal error will be returned and the
connection terminated.
"""
fatal = kwargs.get('fatal', False)
def deco(fn):
@wraps(fn)
def decorated(instance, command):
for parameter in parameters:
if parameter not in command.params:
message = 'missing parameter "{}" in "{}" command'.format(parameter, command.method)
raise CommandValidationError(message, fatal=fatal)
return fn(instance, command)
return decorated
return deco
class RPCProtocol(LineReceiver,
EventEmitter,
policies.TimeoutMixin):
"""
Twisted Protocol handler for the RPC Protocol of the Teeth
Agent <-> Endpoint communication.
The protocol is a simple JSON newline based system. Client or server
can request methods with a message id. The recieving party
responds to this message id.
The low level details are in C{RPCProtocol} while the higher level
functions are in C{TeethAgentProtocol}
"""
timeOut = 60
delimiter = '\n'
MAX_LENGTH = 1024 * 512
def __init__(self, encoder, address):
super(RPCProtocol, self).__init__()
self.encoder = encoder
self.address = address
self._pending_command_deferreds = {}
self._fatal_error = False
self._log = log.bind(host=address.host, port=address.port)
self.setTimeout(self.timeOut)
def timeoutConnection(self):
"""Action called when the connection has hit a timeout."""
self._log.msg('connection timed out', timeout=self.timeOut)
self.transport.abortConnection()
def connectionLost(self, *args, **kwargs):
"""
Handle connection loss. Don't try to re-connect, that is handled
at the factory level.
"""
super(RPCProtocol, self).connectionLost(*args, **kwargs)
self.setTimeout(None)
self.emit('end')
def connectionMade(self):
"""TCP hard. We made it. Maybe."""
super(RPCProtocol, self).connectionMade()
self._log.msg('Connection established.')
self.transport.setTcpKeepAlive(True)
self.transport.setTcpNoDelay(True)
self.emit('connect')
def sendLine(self, line):
"""Send a line of content to our peer."""
super(RPCProtocol, self).sendLine(line)
def lineReceived(self, line):
"""Process a line of data."""
self.resetTimeout()
line = line.strip()
if not line:
return
self._log.msg('Got Line', line=line)
try:
message = json.loads(line)
except Exception:
return self.fatal_error('protocol error: unable to decode message.')
if 'fatal_error' in message:
self._log.err('fatal transport error occurred', error_msg=message['fatal_error'])
self.transport.abortConnection()
return
if not message.get('version', None):
return self.fatal_error("protocol violation: missing message version.")
if not message.get('id', None):
return self.fatal_error("protocol violation: missing message id.")
if 'method' in message:
if 'params' not in message:
return self.fatal_error("protocol violation: missing message params.")
if not isinstance(message['params'], dict):
return self.fatal_error("protocol violation: message params must be an object.")
msg = RPCCommand(self, message)
self._handle_command(msg)
elif 'error' in message:
msg = RPCError(self, message)
self._handle_response(msg)
elif 'result' in message:
msg = RPCResponse(self, message)
self._handle_response(msg)
else:
return self.fatal_error('protocol error: malformed message.')
def fatal_error(self, message):
"""Send a fatal error message, and disconnect."""
self._log.msg('sending a fatal error', message=message)
if not self._fatal_error:
self._fatal_error = True
self.sendLine(self.encoder.encode({
'fatal_error': message
}))
self.transport.loseConnection()
def send_command(self, method, params):
"""Send a new command."""
message_id = str(uuid.uuid4())
d = defer.Deferred()
self._pending_command_deferreds[message_id] = d
self.sendLine(self.encoder.encode({
'id': message_id,
'version': DEFAULT_PROTOCOL_VERSION,
'method': method,
'params': params,
}))
return d
def send_response(self, result, responding_to):
"""Send a result response."""
self.sendLine(self.encoder.encode({
'id': responding_to.id,
'version': responding_to.version,
'result': result,
}))
def send_error_response(self, error, responding_to):
"""Send an error response."""
self.sendLine(self.encoder.encode({
'id': responding_to.id,
'version': responding_to.version,
'error': error,
}))
def _handle_response(self, message):
d = self._pending_command_deferreds.pop(message.id, None)
if not d:
return self.fatal_error("protocol violation: unknown message id referenced.")
if isinstance(message, RPCError):
d.errback(message)
else:
d.callback(message)
def _handle_command(self, message):
d = self.emit('command', message)
if len(d) == 0:
return self.fatal_error("protocol violation: unsupported command.")
# TODO: do we need to wait on anything here?
pass
class TeethAgentProtocol(RPCProtocol):
"""
Handles higher level logic of the RPC protocol like authentication and handshakes.
"""
def __init__(self, encoder, address, parent):
super(TeethAgentProtocol, self).__init__(encoder, address)
self.encoder = encoder
self.address = address
self.parent = parent
self.ping_interval = self.timeOut / 3
self.pinger = task.LoopingCall(self.ping_endpoint)
self.once('connect', self._once_connect)
self.once('end', self._once_disconnect)
def _once_connect(self, event):
def _response(result):
self._log.msg('Handshake successful', connection_id=result.id)
self._log.msg('beginning pinging endpoint', ping_interval=self.ping_interval)
self.pinger.start(self.ping_interval)
return self.send_command('handshake',
{'id': 'a:b:c:d', 'version': AGENT_VERSION}).addCallback(_response)
def _once_disconnect(self, event):
if self.pinger.running:
self.pinger.stop()
def ping_endpoint(self):
"""
Send a ping command to the agent endpoint.
"""
sent_at = time.time()
def _log_ping_response(response):
seconds = time.time() - sent_at
self._log.msg('received ping response', response_time=seconds)
self._log.msg('pinging agent endpoint')
self.send_command('ping', {}).addCallback(_log_ping_response)

View File

@ -1,47 +0,0 @@
"""
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import sys
from twisted.python import usage
from twisted.application.service import MultiService
from teeth_agent.logging import configure as configureLogging
from teeth_agent.agent import StandbyAgent
class Options(usage.Options):
"""Additional options for the Teeth Agent"""
synopsis = """%s [options]
""" % (
os.path.basename(sys.argv[0]),)
optParameters = [["mode", "m", "standbye", "Mode to run Agent in, standbye or decom."]]
def makeService(config):
"""Create an instance of the Teeth-Agent service."""
configureLogging()
s = MultiService()
if config['mode'] == "standbye":
agent = StandbyAgent([['localhost', 8081]])
agent.setServiceParent(s)
else:
raise SystemExit("Invalid mode")
return s

View File

@ -1,57 +0,0 @@
"""
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from teeth_agent.base_task import MultiTask, BaseTask
from teeth_agent.cache_image import ImageDownloaderTask
__all__ = ['CacheImagesTask', 'PrepareImageTask']
class CacheImagesTask(MultiTask):
"""Cache an array of images on a machine."""
task_name = 'cache_images'
def __init__(self, client, task_id, images, reporting_interval=10):
super(CacheImagesTask, self).__init__(client, task_id, reporting_interval=reporting_interval)
self._images = images
for image in self._images:
image_path = os.path.join(client.conf_image_cache_path, image.id + '.img')
t = ImageDownloaderTask(client,
task_id, image,
image_path,
reporting_interval=reporting_interval)
self.add_task(t)
class PrepareImageTask(BaseTask):
"""Prepare an image to be ran on the machine."""
task_name = 'prepare_image'
def __init__(self, client, task_id, image_info, reporting_interval=10):
super(PrepareImageTask, self).__init__(client, task_id)
self._image_info = image_info
def _run(self):
"""Run the Prepare Image task."""
self.log.msg('running prepare_image', image_info=self._image_info)
pass

View File

@ -1,3 +1,15 @@
"""
Unit Tests for the Teeth Agent.
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

View File

@ -1,103 +0,0 @@
"""
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from twisted.trial import unittest
from teeth_agent.events import EventEmitter, EventEmitterUnhandledError
class EventEmitterTest(unittest.TestCase):
"""Event Emitter tests."""
def setUp(self):
self.ee = EventEmitter()
def tearDown(self):
del self.ee
def test_empty_emit(self):
self.ee.emit("nothing.here", "some args")
self.ee.emit("nothing.here2")
def test_single_event(self):
self.count = 0
def got_it(topic):
self.assertEqual(topic, "test")
self.count += 1
self.ee.on("test", got_it)
self.ee.emit("test")
self.ee.emit("other_test")
self.assertEqual(self.count, 1)
def test_multicb(self):
self.count = 0
def got_it(topic):
self.assertEqual(topic, "test")
self.count += 1
self.ee.on("test", got_it)
self.ee.on("test", got_it)
self.ee.emit("test")
self.assertEqual(self.count, 2)
def test_once(self):
self.count = 0
def got_it(topic):
self.assertEqual(topic, "test")
self.count += 1
self.ee.once("test", got_it)
self.ee.emit("test")
self.ee.emit("test")
self.assertEqual(self.count, 1)
def test_removeAllListeners(self):
self.count = 0
def got_it(topic):
self.assertEqual(topic, "test")
self.count += 1
self.ee.on("test", got_it)
self.ee.emit("test")
self.ee.removeAllListeners("test")
self.ee.emit("test")
self.assertEqual(self.count, 1)
def test_error(self):
self.count = 0
try:
self.ee.emit("error")
except EventEmitterUnhandledError:
self.count += 1
self.assertEqual(self.count, 1)

View File

@ -1,216 +0,0 @@
"""
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
from twisted.internet import defer
from twisted.internet import main
from twisted.internet.address import IPv4Address
from twisted.python import failure
from twisted.test.proto_helpers import StringTransportWithDisconnection
from twisted.trial import unittest
from mock import Mock
from teeth_agent.protocol import RPCCommand, RPCProtocol, RPCError, \
CommandValidationError, TeethAgentProtocol, require_parameters
from teeth_agent import __version__ as AGENT_VERSION
class FakeTCPTransport(StringTransportWithDisconnection, object):
_aborting = False
disconnected = False
setTcpKeepAlive = Mock(return_value=None)
setTcpNoDelay = Mock(return_value=None)
setTcpNoDelay = Mock(return_value=None)
def connectionLost(self, reason):
self.protocol.connectionLost(reason)
def abortConnection(self):
if self.disconnected or self._aborting:
return
self._aborting = True
self.connectionLost(failure.Failure(main.CONNECTION_DONE))
class RPCProtocolTest(unittest.TestCase):
"""RPC Protocol tests."""
def setUp(self):
self.tr = FakeTCPTransport()
self.proto = RPCProtocol(json.JSONEncoder(), IPv4Address('TCP', '127.0.0.1', 0))
self.proto.makeConnection(self.tr)
self.tr.protocol = self.proto
def tearDown(self):
self.tr.abortConnection()
def test_timeout(self):
d = defer.Deferred()
called = []
orig = self.proto.connectionLost
def lost(arg):
orig()
called.append(True)
d.callback(True)
self.proto.connectionLost = lost
self.proto.timeoutConnection()
def check(ignore):
self.assertEqual(called, [True])
d.addCallback(check)
return d
def test_recv_command_no_params(self):
self.tr.clear()
self.proto.lineReceived(json.dumps({'id': '1', 'version': 'v1', 'method': 'BOGUS_STUFF'}))
obj = json.loads(self.tr.io.getvalue().strip())
self.assertEqual(obj['fatal_error'], 'protocol violation: missing message params.')
def test_recv_bogus_command(self):
self.tr.clear()
self.proto.lineReceived(
json.dumps({'id': '1', 'version': 'v1', 'method': 'BOGUS_STUFF', 'params': {'d': '1'}}))
obj = json.loads(self.tr.io.getvalue().strip())
self.assertEqual(obj['fatal_error'], 'protocol violation: unsupported command.')
def test_recv_valid_json_no_id(self):
self.tr.clear()
self.proto.lineReceived(json.dumps({'version': 'v913'}))
obj = json.loads(self.tr.io.getvalue().strip())
self.assertEqual(obj['fatal_error'], 'protocol violation: missing message id.')
def test_recv_valid_json_no_version(self):
self.tr.clear()
self.proto.lineReceived(json.dumps({'version': None, 'id': 'foo'}))
obj = json.loads(self.tr.io.getvalue().strip())
self.assertEqual(obj['fatal_error'], 'protocol violation: missing message version.')
def test_recv_invalid_data(self):
self.tr.clear()
self.proto.lineReceived('')
self.proto.lineReceived('invalid json!')
obj = json.loads(self.tr.io.getvalue().strip())
self.assertEqual(obj['fatal_error'], 'protocol error: unable to decode message.')
def test_recv_missing_key_parts(self):
self.tr.clear()
self.proto.lineReceived(json.dumps(
{'id': '1', 'version': 'v1'}))
obj = json.loads(self.tr.io.getvalue().strip())
self.assertEqual(obj['fatal_error'], 'protocol error: malformed message.')
def test_recv_error_to_unknown_id(self):
self.tr.clear()
self.proto.lineReceived(json.dumps(
{'id': '1', 'version': 'v1', 'error': {'msg': 'something is wrong'}}))
obj = json.loads(self.tr.io.getvalue().strip())
self.assertEqual(obj['fatal_error'], 'protocol violation: unknown message id referenced.')
def _send_command(self):
self.tr.clear()
d = self.proto.send_command('test_command', {'body': 42})
req = json.loads(self.tr.io.getvalue().strip())
self.tr.clear()
return (d, req)
def test_recv_result(self):
dout = defer.Deferred()
d, req = self._send_command()
self.proto.lineReceived(json.dumps(
{'id': req['id'], 'version': 'v1', 'result': {'duh': req['params']['body']}}))
self.assertEqual(len(self.tr.io.getvalue()), 0)
def check(resp):
self.assertEqual(resp.result['duh'], 42)
dout.callback(True)
d.addCallback(check)
return dout
def test_recv_error(self):
d, req = self._send_command()
self.proto.lineReceived(json.dumps(
{'id': req['id'], 'version': 'v1', 'error': {'msg': 'something is wrong'}}))
self.assertEqual(len(self.tr.io.getvalue()), 0)
return self.assertFailure(d, RPCError)
def test_recv_fatal_error(self):
d = defer.Deferred()
called = []
orig = self.proto.connectionLost
def lost(arg):
self.failUnless(isinstance(arg, failure.Failure))
orig()
called.append(True)
d.callback(True)
self.proto.connectionLost = lost
def check(ignore):
self.assertEqual(called, [True])
d.addCallback(check)
self.tr.clear()
self.proto.lineReceived(json.dumps({'fatal_error': 'you be broken'}))
return d
class TeethAgentProtocolTest(unittest.TestCase):
"""Teeth Agent Protocol tests."""
def setUp(self):
self.tr = FakeTCPTransport()
self.proto = TeethAgentProtocol(json.JSONEncoder(), IPv4Address('TCP', '127.0.0.1', 0), None)
self.proto.makeConnection(self.tr)
self.tr.protocol = self.proto
def tearDown(self):
self.tr.abortConnection()
def test_on_connect(self):
obj = json.loads(self.tr.io.getvalue().strip())
self.assertEqual(obj['version'], 'v1')
self.assertEqual(obj['method'], 'handshake')
self.assertEqual(obj['method'], 'handshake')
self.assertEqual(obj['params']['id'], 'a:b:c:d')
self.assertEqual(obj['params']['version'], AGENT_VERSION)
class RequiresParamsTest(unittest.TestCase):
@require_parameters('foo')
def _test_for_foo(self, command):
self._called = True
def test_require_parameters_there(self):
self._called = False
m = RPCCommand(self, {'version': 1, 'id': 'a', 'method': 'test', 'params': {'foo': 1}})
self._test_for_foo(m)
self.assertEqual(self._called, True)
def test_require_parameters_not_there(self):
self._called = False
m = RPCCommand(self, {'version': 1, 'id': 'a', 'method': 'test', 'params': {}})
self.assertRaises(CommandValidationError, self._test_for_foo, m)
self.assertEqual(self._called, False)

View File

@ -1,193 +0,0 @@
"""
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import uuid
import shutil
import tempfile
import hashlib
import os
from mock import Mock, patch
from twisted.internet import defer
from twisted.trial import unittest
from twisted.web.client import ResponseDone
from twisted.python.failure import Failure
from teeth_agent.protocol import ImageInfo
from teeth_agent.base_task import BaseTask, MultiTask
from teeth_agent.cache_image import ImageDownloaderTask
class FakeClient(object):
def __init__(self):
self.addService = Mock(return_value=None)
self.running = Mock(return_value=0)
self.update_task_status = Mock(return_value=None)
self.finish_task = Mock(return_value=None)
class TestTask(BaseTask):
task_name = 'test_task'
class TaskTest(unittest.TestCase):
"""Basic tests of the Task API."""
def setUp(self):
self.task_id = str(uuid.uuid4())
self.client = FakeClient()
self.task = TestTask(self.client, self.task_id)
def tearDown(self):
del self.task_id
del self.task
del self.client
def test_error(self):
self.task.run()
self.client.addService.assert_called_once_with(self.task)
self.task.startService()
self.client.update_task_status.assert_called_once_with(self.task)
self.task.error('chaos monkey attack')
self.assertEqual(self.task._state, 'error')
self.client.finish_task.assert_called_once_with(self.task)
def test_run(self):
self.assertEqual(self.task._state, 'starting')
self.assertEqual(self.task._id, self.task_id)
self.task.run()
self.client.addService.assert_called_once_with(self.task)
self.task.startService()
self.client.update_task_status.assert_called_once_with(self.task)
self.task.complete()
self.assertEqual(self.task._state, 'complete')
self.client.finish_task.assert_called_once_with(self.task)
def test_fast_shutdown(self):
self.task.run()
self.client.addService.assert_called_once_with(self.task)
self.task.startService()
self.client.update_task_status.assert_called_once_with(self.task)
self.task.stopService()
self.assertEqual(self.task._state, 'error')
self.client.finish_task.assert_called_once_with(self.task)
class MultiTestTask(MultiTask):
task_name = 'test_multitask'
class MultiTaskTest(unittest.TestCase):
"""Basic tests of the Multi Task API."""
def setUp(self):
self.task_id = str(uuid.uuid4())
self.client = FakeClient()
self.task = MultiTestTask(self.client, self.task_id)
def tearDown(self):
del self.task_id
del self.task
del self.client
def test_tasks(self):
t = TestTask(self.client, self.task_id)
self.task.add_task(t)
self.assertEqual(self.task._state, 'starting')
self.assertEqual(self.task._id, self.task_id)
self.task.run()
self.client.addService.assert_called_once_with(self.task)
self.task.startService()
self.client.update_task_status.assert_any_call(self.task)
t.complete()
self.assertEqual(self.task._state, 'complete')
self.client.finish_task.assert_any_call(t)
self.client.finish_task.assert_any_call(self.task)
class StubResponse(object):
def __init__(self, code, headers, body):
self.version = ('HTTP', 1, 1)
self.code = code
self.status = "ima teapot"
self.headers = headers
self.body = body
self.length = reduce(lambda x, y: x + len(y), body, 0)
self.protocol = None
def deliverBody(self, protocol):
self.protocol = protocol
def run(self):
self.protocol.connectionMade()
for data in self.body:
self.protocol.dataReceived(data)
self.protocol.connectionLost(Failure(ResponseDone("Response body fully received")))
class ImageDownloaderTaskTest(unittest.TestCase):
def setUp(self):
get_patcher = patch('treq.get', autospec=True)
self.TreqGet = get_patcher.start()
self.addCleanup(get_patcher.stop)
self.tmpdir = tempfile.mkdtemp('image_download_test')
self.task_id = str(uuid.uuid4())
self.image_data = str(uuid.uuid4())
self.image_md5 = hashlib.md5(self.image_data).hexdigest()
self.cache_path = os.path.join(self.tmpdir, 'a1234.img')
self.client = FakeClient()
self.image_info = ImageInfo('a1234',
['http://127.0.0.1/images/a1234.img'], {'md5': self.image_md5})
self.task = ImageDownloaderTask(self.client,
self.task_id,
self.image_info,
self.cache_path)
def tearDown(self):
shutil.rmtree(self.tmpdir)
def assertFileHash(self, hash_type, path, value):
file_hash = hashlib.new(hash_type)
with open(path, 'r') as fp:
file_hash.update(fp.read())
self.assertEqual(value, file_hash.hexdigest())
def test_download_success(self):
resp = StubResponse(200, [], [self.image_data])
d = defer.Deferred()
self.TreqGet.return_value = d
self.task.run()
self.client.addService.assert_called_once_with(self.task)
self.TreqGet.assert_called_once_with('http://127.0.0.1/images/a1234.img')
self.task.startService()
d.callback(resp)
resp.run()
self.client.update_task_status.assert_called_once_with(self.task)
self.assertFileHash('md5', self.cache_path, self.image_md5)
self.task.stopService()
self.assertEqual(self.task._state, 'error')
self.client.finish_task.assert_called_once_with(self.task)

10
test-requirements.txt Normal file
View File

@ -0,0 +1,10 @@
pep257
plumbum
pep8
pyflakes
junitxml
python-subunit
mock
coverage
nose
flake8

20
tox.ini Normal file
View File

@ -0,0 +1,20 @@
[tox]
envlist = flake8, unit
[flake8]
max-line-length = 105
[testenv]
deps =
-rrequirements.txt
-rtest-requirements.txt
[testenv:flake8]
commands = flake8 teeth_agent
[testenv:unit]
commands = nosetests --all-modules teeth_agent/tests
[testenv:devenv]
envdir = devenv
usedevelop = True