diff --git a/.travis.yml b/.travis.yml index 2acf19f42..d5a7a07fb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,3 +4,10 @@ install: - pip install tox script: - tox +notifications: + irc: + channels: + - "chat.freenode.net#teeth-dev" + use_notice: true + skip_join: true + email: false diff --git a/README.md b/README.md index 5f7eb62b6..15e3a8a4f 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ # teeth-agent +[![Build Status](https://travis-ci.org/rackerlabs/teeth-agent.png?branch=master)](https://travis-ci.org/rackerlabs/teeth-agent) + An agent for rebuilding and controlling Teeth chassis. diff --git a/requirements.txt b/requirements.txt index 3ac744fbd..b0528704f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ Werkzeug==0.9.4 requests==2.0.0 cherrypy==3.2.4 --e git+https://github.com/racker/teeth-rest.git@c62ac56cd4273e54592768ad94bb72c7c5e92508#egg=teeth_rest +stevedore==0.13 +-e git+https://github.com/racker/teeth-rest.git@e876c0fddd5ce2f5223ab16936f711b0d57e19c4#egg=teeth_rest diff --git a/setup.cfg b/setup.cfg index 869d5d767..539e217d4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,5 +16,8 @@ packages = [entry_points] console_scripts = - teeth-standby-agent = teeth_agent.cmd.standby:run - teeth-decom-agent = teeth_agent.cmd.decom:run + teeth-agent = teeth_agent.cmd.agent:run + +teeth_agent.modes = + standby = teeth_agent.standby:StandbyMode + decom = teeth_agent.decom:DecomMode diff --git a/teeth_agent/agent.py b/teeth_agent/agent.py new file mode 100644 index 000000000..70e8f4f97 --- /dev/null +++ b/teeth_agent/agent.py @@ -0,0 +1,253 @@ +""" +Copyright 2013 Rackspace, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import collections +import random +import socket +import threading +import time +import urlparse + +from cherrypy import wsgiserver +import pkg_resources +from stevedore import driver +import structlog +from teeth_rest import encoding +from teeth_rest import errors as rest_errors + +from teeth_agent import api +from teeth_agent import base +from teeth_agent import errors +from teeth_agent import hardware +from teeth_agent import overlord_agent_api + + +class TeethAgentStatus(encoding.Serializable): + def __init__(self, mode, started_at, version): + self.mode = mode + self.started_at = started_at + self.version = version + + def serialize(self, view): + """Turn the status into a dict.""" + return collections.OrderedDict([ + ('mode', self.mode), + ('started_at', self.started_at), + ('version', self.version), + ]) + + +class TeethAgentHeartbeater(threading.Thread): + # If we could wait at most N seconds between heartbeats (or in case of an + # error) we will instead wait r x N seconds, where r is a random value + # between these multipliers. + min_jitter_multiplier = 0.3 + max_jitter_multiplier = 0.6 + + # Exponential backoff values used in case of an error. In reality we will + # only wait a portion of either of these delays based on the jitter + # multipliers. + initial_delay = 1.0 + max_delay = 300.0 + backoff_factor = 2.7 + + def __init__(self, agent): + super(TeethAgentHeartbeater, self).__init__() + self.agent = agent + self.api = overlord_agent_api.APIClient(agent.api_url) + self.log = structlog.get_logger(api_url=agent.api_url) + self.stop_event = threading.Event() + self.error_delay = self.initial_delay + + def run(self): + # The first heartbeat happens now + self.log.info('starting heartbeater') + interval = 0 + + while not self.stop_event.wait(interval): + next_heartbeat_by = self.do_heartbeat() + interval_multiplier = random.uniform(self.min_jitter_multiplier, + self.max_jitter_multiplier) + interval = (next_heartbeat_by - time.time()) * interval_multiplier + self.log.info('sleeping before next heartbeat', interval=interval) + + def do_heartbeat(self): + try: + deadline = self.api.heartbeat( + mac_addr=self.agent.get_agent_mac_addr(), + url=self.agent.get_agent_url(), + version=self.agent.version, + mode=self.agent.mode_implementation.name) + self.error_delay = self.initial_delay + self.log.info('heartbeat successful') + except Exception as e: + self.log.error('error sending heartbeat', exception=e) + deadline = time.time() + self.error_delay + self.error_delay = min(self.error_delay * self.backoff_factor, + self.max_delay) + pass + + return deadline + + def stop(self): + self.log.info('stopping heartbeater') + self.stop_event.set() + return self.join() + + +class TeethAgent(object): + def __init__(self, api_url, listen_address, advertise_address, mode_impl): + self.api_url = api_url + self.listen_address = listen_address + self.advertise_address = advertise_address + self.mode_implementation = mode_impl + self.version = pkg_resources.get_distribution('teeth-agent').version + self.api = api.TeethAgentAPIServer(self) + self.command_results = collections.OrderedDict() + self.heartbeater = TeethAgentHeartbeater(self) + self.hardware = hardware.HardwareInspector() + self.command_lock = threading.Lock() + self.log = structlog.get_logger() + self.started_at = None + + def get_status(self): + """Retrieve a serializable status.""" + return TeethAgentStatus( + mode=self.mode_implementation.name, + started_at=self.started_at, + version=self.version + ) + + def get_agent_url(self): + # If we put this behind any sort of proxy (ie, stunnel) we're going to + # need to (re)think this. + return 'http://{host}:{port}/'.format(host=self.advertise_address[0], + port=self.advertise_address[1]) + + def get_agent_mac_addr(self): + return self.hardware.get_primary_mac_address() + + def list_command_results(self): + return self.command_results.values() + + def get_command_result(self, result_id): + try: + return self.command_results[result_id] + except KeyError: + raise errors.RequestedObjectNotFoundError('Command Result', + result_id) + + def execute_command(self, command_name, **kwargs): + """Execute an agent command.""" + with self.command_lock: + if len(self.command_results) > 0: + last_command = self.command_results.values()[-1] + if not last_command.is_done(): + raise errors.CommandExecutionError('agent is busy') + + try: + result = self.mode_implementation.execute(command_name, + **kwargs) + except rest_errors.InvalidContentError as e: + # Any command may raise a InvalidContentError which will be + # returned to the caller directly. + raise e + except Exception as e: + # Other errors are considered command execution errors, and are + # recorded as an + result = base.SyncCommandResult(command_name, kwargs, False, e) + + self.command_results[result.id] = result + return result + + def run(self): + """Run the Teeth Agent.""" + self.started_at = time.time() + self.heartbeater.start() + server = wsgiserver.CherryPyWSGIServer(self.listen_address, self.api) + + try: + server.start() + except BaseException as e: + self.log.error('shutting down', exception=e) + server.stop() + + self.heartbeater.stop() + + +def _get_api_facing_ip_address(api_url): + """Note: this will raise an exception if anything goes wrong. That is + expected to be fine, if we can't get to the agent API there isn't much + point in starting up. Just crash and rely on the process manager to + restart us in a sane fashion. + """ + api_addr = urlparse.urlparse(api_url) + + if api_addr.scheme not in ('http', 'https'): + raise RuntimeError('API URL scheme must be one of \'http\' or ' + '\'https\'.') + + api_port = api_addr.port or {'http': 80, 'https': 443}[api_addr.scheme] + api_host = api_addr.hostname + + conn = socket.create_connection((api_host, api_port)) + listen_ip = conn.getsockname()[0] + conn.close() + + return listen_ip + + +def _load_mode_implementation(mode_name): + mgr = driver.DriverManager( + namespace='teeth_agent.modes', + name=mode_name.lower(), + invoke_on_load=True, + invoke_args=[], + ) + return mgr.driver + + +def build_agent(api_url, + listen_host, + listen_port, + advertise_host, + advertise_port): + log = structlog.get_logger() + + if not advertise_host: + log.info('resolving API-facing IP address') + advertise_host = _get_api_facing_ip_address(api_url) + log.info('resolved API-facing IP address', ip_address=advertise_host) + + if not listen_host: + listen_host = advertise_host + + mac_addr = hardware.HardwareInspector().get_primary_mac_address() + api_client = overlord_agent_api.APIClient(api_url) + + log.info('fetching agent configuration from API', + api_url=api_url, + mac_addr=mac_addr) + config = api_client.get_configuration(mac_addr) + mode_name = config['mode'] + + log.info('loading mode implementation', mode=mode_name) + mode_implementation = _load_mode_implementation(mode_name) + + return TeethAgent(api_url, + (listen_host, listen_port), + (advertise_host, advertise_port), + mode_implementation) diff --git a/teeth_agent/base.py b/teeth_agent/base.py index c9ff69beb..af61ed2a5 100644 --- a/teeth_agent/base.py +++ b/teeth_agent/base.py @@ -16,38 +16,14 @@ limitations under the License. import abc import collections -import random -import socket import threading -import time -import urlparse import uuid -from cherrypy import wsgiserver -import pkg_resources import structlog from teeth_rest import encoding from teeth_rest import errors as rest_errors -from teeth_agent import api from teeth_agent import errors -from teeth_agent import hardware -from teeth_agent import overlord_agent_api - - -class TeethAgentStatus(encoding.Serializable): - def __init__(self, mode, started_at, version): - self.mode = mode - self.started_at = started_at - self.version = version - - def serialize(self, view): - """Turn the status into a dict.""" - return collections.OrderedDict([ - ('mode', self.mode), - ('started_at', self.started_at), - ('version', self.version), - ]) class AgentCommandStatus(object): @@ -142,190 +118,23 @@ class AsyncCommandResult(BaseCommandResult): pass -class TeethAgentHeartbeater(threading.Thread): - # If we could wait at most N seconds between heartbeats (or in case of an - # error) we will instead wait r x N seconds, where r is a random value - # between these multipliers. - min_jitter_multiplier = 0.3 - max_jitter_multiplier = 0.6 - - # Exponential backoff values used in case of an error. In reality we will - # only wait a portion of either of these delays based on the jitter - # multipliers. - initial_delay = 1.0 - max_delay = 300.0 - backoff_factor = 2.7 - - def __init__(self, agent): - super(TeethAgentHeartbeater, self).__init__() - self.agent = agent - self.api = overlord_agent_api.APIClient(agent.api_url) - self.log = structlog.get_logger(api_url=agent.api_url) - self.stop_event = threading.Event() - self.error_delay = self.initial_delay - - def run(self): - # The first heartbeat happens now - self.log.info('starting heartbeater') - interval = 0 - - while not self.stop_event.wait(interval): - next_heartbeat_by = self.do_heartbeat() - interval_multiplier = random.uniform(self.min_jitter_multiplier, - self.max_jitter_multiplier) - interval = (next_heartbeat_by - time.time()) * interval_multiplier - self.log.info('sleeping before next heartbeat', interval=interval) - - def do_heartbeat(self): - try: - deadline = self.api.heartbeat( - mac_addr=self.agent.get_agent_mac_addr(), - url=self.agent.get_agent_url(), - version=self.agent.version, - mode=self.agent.mode) - self.error_delay = self.initial_delay - self.log.info('heartbeat successful') - except Exception as e: - self.log.error('error sending heartbeat', exception=e) - deadline = time.time() + self.error_delay - self.error_delay = min(self.error_delay * self.backoff_factor, - self.max_delay) - pass - - return deadline - - def stop(self): - self.log.info('stopping heartbeater') - self.stop_event.set() - return self.join() - - -class BaseTeethAgent(object): - def __init__(self, - listen_host, - listen_port, - advertise_host, - advertise_port, - api_url, - mode): - self.listen_host = listen_host - self.listen_port = listen_port - self.advertise_host = advertise_host - self.advertise_port = advertise_port - self.api_url = api_url - self.started_at = None - self.mode = mode - self.version = pkg_resources.get_distribution('teeth-agent').version - self.api = api.TeethAgentAPIServer(self) - self.command_results = collections.OrderedDict() +class BaseAgentMode(object): + def __init__(self, name): + super(BaseAgentMode, self).__init__() + self.log = structlog.get_logger(agent_mode=name) + self.name = name self.command_map = {} - self.heartbeater = TeethAgentHeartbeater(self) - self.hardware = hardware.HardwareInspector() - self.command_lock = threading.Lock() - self.log = structlog.get_logger() - def get_status(self): - """Retrieve a serializable status.""" - return TeethAgentStatus( - mode=self.mode, - started_at=self.started_at, - version=self.version - ) + def execute(self, command_name, **kwargs): + if command_name not in self.command_map: + raise errors.InvalidCommandError(command_name) - def get_agent_url(self): - # If we put this behind any sort of proxy (ie, stunnel) we're going to - # need to (re)think this. - return 'http://{host}:{port}/'.format(host=self.advertise_host, - port=self.advertise_port) + result = self.command_map[command_name](command_name, **kwargs) - def get_api_facing_ip_address(self): - """Note: this will raise an exception if anything goes wrong. That is - expected to be fine, if we can't get to the agent API there isn't much - point in starting up. Just crash and rely on the process manager to - restart us in a sane fashion. - """ - api_addr = urlparse.urlparse(self.api_url) + # In order to enable extremely succinct synchronous commands, we allow + # them to return a value directly, and we'll handle wrapping it up in a + # SyncCommandResult + if not isinstance(result, BaseCommandResult): + result = SyncCommandResult(command_name, kwargs, True, result) - if api_addr.scheme not in ('http', 'https'): - raise RuntimeError('API URL scheme must be one of \'http\' or ' - '\'https\'.') - - api_port = api_addr.port or {'http': 80, 'https': 443}[api_addr.scheme] - api_host = api_addr.hostname - - self.log.info('attempting to resolve listen IP', - api_host=api_host, - api_port=api_port) - - conn = socket.create_connection((api_host, api_port)) - listen_ip = conn.getsockname()[0] - conn.close() - self.log.info('resolved listen IP', listen_ip=listen_ip) - - return listen_ip - - def get_agent_mac_addr(self): - return self.hardware.get_primary_mac_address() - - def list_command_results(self): - return self.command_results.values() - - def get_command_result(self, result_id): - try: - return self.command_results[result_id] - except KeyError: - raise errors.RequestedObjectNotFoundError('Command Result', - result_id) - - def execute_command(self, command_name, **kwargs): - """Execute an agent command.""" - with self.command_lock: - if len(self.command_results) > 0: - last_command = self.command_results.values()[-1] - if not last_command.is_done(): - raise errors.CommandExecutionError('agent is busy') - - if command_name not in self.command_map: - raise errors.InvalidCommandError(command_name) - - try: - result = self.command_map[command_name](command_name, **kwargs) - if not isinstance(result, BaseCommandResult): - result = SyncCommandResult(command_name, - kwargs, - True, - result) - except rest_errors.InvalidContentError as e: - # Any command may raise a InvalidContentError which will be - # returned to the caller directly. - raise e - except Exception as e: - # Other errors are considered command execution errors, and are - # recorded as an - result = SyncCommandResult(command_name, kwargs, False, e) - - self.command_results[result.id] = result - return result - - def run(self): - """Run the Teeth Agent.""" - self.started_at = time.time() - - if not self.advertise_host: - self.advertise_host = self.get_api_facing_ip_address() - - if not self.listen_host: - self.listen_host = self.advertise_host - - self.heartbeater.start() - - listen_address = (self.listen_host, self.listen_port) - server = wsgiserver.CherryPyWSGIServer(listen_address, self.api) - - try: - server.start() - except BaseException as e: - self.log.error('shutting down', exception=e) - server.stop() - - self.heartbeater.stop() + return result diff --git a/teeth_agent/cmd/decom.py b/teeth_agent/cmd/agent.py similarity index 84% rename from teeth_agent/cmd/decom.py rename to teeth_agent/cmd/agent.py index c6df33ead..bfba9ccaa 100644 --- a/teeth_agent/cmd/decom.py +++ b/teeth_agent/cmd/agent.py @@ -16,13 +16,14 @@ limitations under the License. import argparse -from teeth_agent import decom +from teeth_agent import agent from teeth_agent import logging def run(): parser = argparse.ArgumentParser( - description='Run the teeth-agent in decom mode') + description=('An agent that handles decomissioning and provisioning' + ' on behalf of teeth-overlord.')) parser.add_argument('--api-url', required=True, @@ -55,8 +56,8 @@ def run(): args = parser.parse_args() logging.configure() advertise_port = args.advertise_port or args.listen_port - decom.DecomAgent(args.listen_host, - args.listen_port, - args.advertise_host, - advertise_port, - args.api_url).run() + agent.build_agent(args.api_url, + args.listen_host, + args.listen_port, + args.advertise_host, + advertise_port).run() diff --git a/teeth_agent/cmd/standby.py b/teeth_agent/cmd/standby.py deleted file mode 100644 index 8c59b67a9..000000000 --- a/teeth_agent/cmd/standby.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -Copyright 2013 Rackspace, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import argparse - -from teeth_agent import logging -from teeth_agent import standby - - -def run(): - parser = argparse.ArgumentParser( - description='Run the teeth-agent in standby mode') - - parser.add_argument('--api-url', - required=True, - help='URL of the Teeth agent API') - - parser.add_argument('--listen-host', - type=str, - help=('The IP address to listen on. Leave this blank' - ' to auto-detect. A common use-case would be to' - ' override this with \'localhost\', in order to' - ' run behind a proxy, while leaving' - ' advertise-host unspecified.')) - - parser.add_argument('--listen-port', - default=9999, - type=int, - help='The port to listen on') - - parser.add_argument('--advertise-host', - type=str, - help=('The IP address to advertise. Leave this blank' - ' to auto-detect by calling \'getsockname()\' on' - ' a connection to the agent API.')) - - parser.add_argument('--advertise-port', - type=int, - help=('The port to advertise. Defaults to listen-port.' - ' Useful when running behind a proxy.')) - - args = parser.parse_args() - logging.configure() - advertise_port = args.advertise_port or args.listen_port - standby.StandbyAgent(args.listen_host, - args.listen_port, - args.advertise_host, - advertise_port, - args.api_url).run() diff --git a/teeth_agent/decom.py b/teeth_agent/decom.py index 849acb0c5..b1fe1f102 100644 --- a/teeth_agent/decom.py +++ b/teeth_agent/decom.py @@ -17,16 +17,6 @@ limitations under the License. from teeth_agent import base -class DecomAgent(base.BaseTeethAgent): - def __init__(self, - listen_host, - listen_port, - advertise_host, - advertise_port, - api_url): - super(DecomAgent, self).__init__(listen_host, - listen_port, - advertise_host, - advertise_port, - api_url, - 'DECOM') +class DecomMode(base.BaseAgentMode): + def __init__(self): + super(DecomMode, self).__init__('DECOM') diff --git a/teeth_agent/errors.py b/teeth_agent/errors.py index 07e4e8467..996b57af2 100644 --- a/teeth_agent/errors.py +++ b/teeth_agent/errors.py @@ -53,14 +53,23 @@ class RequestedObjectNotFoundError(errors.NotFound): self.details = details -class HeartbeatError(errors.RESTError): +class OverlordAPIError(errors.RESTError): + """Error raised when a call to the agent API fails.""" + + message = 'Error in call to teeth-agent-api.' + + def __init__(self, details): + super(OverlordAPIError, self).__init__(details) + self.details = details + + +class HeartbeatError(OverlordAPIError): """Error raised when a heartbeat to the agent API fails.""" message = 'Error heartbeating to agent API.' def __init__(self, details): - super(HeartbeatError, self).__init__() - self.details = details + super(HeartbeatError, self).__init__(details) class ImageDownloadError(errors.RESTError): diff --git a/teeth_agent/overlord_agent_api.py b/teeth_agent/overlord_agent_api.py index 4d0c7f65b..209ec8920 100644 --- a/teeth_agent/overlord_agent_api.py +++ b/teeth_agent/overlord_agent_api.py @@ -14,8 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. """ -import requests +import json +import requests from teeth_rest import encoding from teeth_agent import errors @@ -72,3 +73,19 @@ class APIClient(object): raise errors.HeartbeatError('Missing Heartbeat-Before header') except Exception: raise errors.HeartbeatError('Invalid Heartbeat-Before header') + + def get_configuration(self, mac_addr): + path = '/{api_version}/agents/{mac_addr}/configuration'.format( + api_version=self.api_version, + mac_addr=mac_addr) + + response = self._request('GET', path) + + if response.status_code != requests.codes.OK: + msg = 'Invalid status code: {}'.format(response.status_code) + raise errors.OverlordAPIError(msg) + + try: + return json.loads(response.content) + except Exception as e: + raise errors.OverlordAPIError('Error decoding response: ' + str(e)) diff --git a/teeth_agent/standby.py b/teeth_agent/standby.py index 7628451c9..d7a6f3136 100644 --- a/teeth_agent/standby.py +++ b/teeth_agent/standby.py @@ -126,25 +126,12 @@ class RunImageCommand(base.AsyncCommandResult): _run_image() -class StandbyAgent(base.BaseTeethAgent): - def __init__(self, - listen_host, - listen_port, - advertise_host, - advertise_port, - api_url): - super(StandbyAgent, self).__init__(listen_host, - listen_port, - advertise_host, - advertise_port, - api_url, - 'STANDBY') - - self.command_map = { - 'cache_images': self.cache_images, - 'prepare_image': self.prepare_image, - 'run_image': self.run_image, - } +class StandbyMode(base.BaseAgentMode): + def __init__(self): + super(StandbyMode, self).__init__('STANDBY') + self.command_map['cache_images'] = self.cache_images + self.command_map['prepare_image'] = self.prepare_image + self.command_map['run_image'] = self.run_image def _validate_image_info(self, image_info): for field in ['id', 'urls', 'hashes']: diff --git a/teeth_agent/tests/base_agent.py b/teeth_agent/tests/agent.py similarity index 89% rename from teeth_agent/tests/base_agent.py rename to teeth_agent/tests/agent.py index 92bbb03aa..e906397cf 100644 --- a/teeth_agent/tests/base_agent.py +++ b/teeth_agent/tests/agent.py @@ -23,6 +23,7 @@ import pkg_resources from teeth_rest import encoding +from teeth_agent import agent from teeth_agent import base from teeth_agent import errors @@ -38,10 +39,15 @@ class FooTeethAgentCommandResult(base.AsyncCommandResult): return 'command execution succeeded' +class FakeMode(base.BaseAgentMode): + def __init__(self): + super(FakeMode, self).__init__('FAKE') + + class TestHeartbeater(unittest.TestCase): def setUp(self): self.mock_agent = mock.Mock() - self.heartbeater = base.TeethAgentHeartbeater(self.mock_agent) + self.heartbeater = agent.TeethAgentHeartbeater(self.mock_agent) self.heartbeater.api = mock.Mock() self.heartbeater.stop_event = mock.Mock() @@ -108,17 +114,15 @@ class TestHeartbeater(unittest.TestCase): self.assertEqual(self.heartbeater.error_delay, 2.7) -class TestBaseTeethAgent(unittest.TestCase): +class TestBaseAgent(unittest.TestCase): def setUp(self): self.encoder = encoding.RESTJSONEncoder( encoding.SerializationViews.PUBLIC, indent=4) - self.agent = base.BaseTeethAgent(None, - 9999, - None, - 9999, - 'https://fake_api.example.org:8081/', - 'TEST_MODE') + self.agent = agent.TeethAgent('https://fake_api.example.org:8081/', + ('localhost', 9999), + ('localhost', 9999), + FakeMode()) def assertEqualEncoded(self, a, b): # Evidently JSONEncoder.default() can't handle None (??) so we have to @@ -133,17 +137,15 @@ class TestBaseTeethAgent(unittest.TestCase): self.agent.started_at = started_at status = self.agent.get_status() - self.assertIsInstance(status, base.TeethAgentStatus) - self.assertEqual(status.mode, 'TEST_MODE') + self.assertIsInstance(status, agent.TeethAgentStatus) self.assertEqual(status.started_at, started_at) self.assertEqual(status.version, pkg_resources.get_distribution('teeth-agent').version) def test_execute_command(self): do_something_impl = mock.Mock() - self.agent.command_map = { - 'do_something': do_something_impl, - } + command_map = self.agent.mode_implementation.command_map + command_map['do_something'] = do_something_impl self.agent.execute_command('do_something', foo='bar') do_something_impl.assert_called_once_with('do_something', foo='bar') @@ -159,12 +161,10 @@ class TestBaseTeethAgent(unittest.TestCase): wsgi_server = wsgi_server_cls.return_value wsgi_server.start.side_effect = KeyboardInterrupt() - self.agent.get_api_facing_ip_address = mock.Mock() - self.agent.get_api_facing_ip_address.return_value = '1.2.3.4' self.agent.heartbeater = mock.Mock() self.agent.run() - listen_addr = ('1.2.3.4', 9999) + listen_addr = ('localhost', 9999) wsgi_server_cls.assert_called_once_with(listen_addr, self.agent.api) wsgi_server.start.assert_called_once_with() wsgi_server.stop.assert_called_once_with() diff --git a/teeth_agent/tests/api.py b/teeth_agent/tests/api.py index 7e27350de..fc0686640 100644 --- a/teeth_agent/tests/api.py +++ b/teeth_agent/tests/api.py @@ -24,6 +24,7 @@ from werkzeug import wrappers from teeth_rest import encoding +from teeth_agent import agent from teeth_agent import api from teeth_agent import base @@ -44,7 +45,7 @@ class TestTeethAPI(unittest.TestCase): return client.open(self._get_env_builder(method, path, data, query)) def test_get_agent_status(self): - status = base.TeethAgentStatus('TEST_MODE', time.time(), 'v72ac9') + status = agent.TeethAgentStatus('TEST_MODE', time.time(), 'v72ac9') mock_agent = mock.MagicMock() mock_agent.get_status.return_value = status api_server = api.TeethAgentAPIServer(mock_agent) @@ -120,7 +121,7 @@ class TestTeethAPI(unittest.TestCase): True, {'test': 'result'}) - mock_agent = mock.create_autospec(base.BaseTeethAgent) + mock_agent = mock.create_autospec(agent.TeethAgent) mock_agent.list_command_results.return_value = [ cmd_result, ] @@ -144,7 +145,7 @@ class TestTeethAPI(unittest.TestCase): serialized_cmd_result = cmd_result.serialize( encoding.SerializationViews.PUBLIC) - mock_agent = mock.create_autospec(base.BaseTeethAgent) + mock_agent = mock.create_autospec(agent.TeethAgent) mock_agent.get_command_result.return_value = cmd_result api_server = api.TeethAgentAPIServer(mock_agent) diff --git a/teeth_agent/tests/decom_agent.py b/teeth_agent/tests/decom.py similarity index 79% rename from teeth_agent/tests/decom_agent.py rename to teeth_agent/tests/decom.py index c6e0f7b2c..920f2b97b 100644 --- a/teeth_agent/tests/decom_agent.py +++ b/teeth_agent/tests/decom.py @@ -19,9 +19,9 @@ import unittest from teeth_agent import decom -class TestBaseTeethAgent(unittest.TestCase): +class TestDecomMode(unittest.TestCase): def setUp(self): - self.agent = decom.DecomAgent(None, 9999, None, 9999, 'fake_api') + self.agent_mode = decom.DecomMode() def test_decom_mode(self): - self.assertEqual(self.agent.mode, 'DECOM') + self.assertEqual(self.agent_mode.name, 'DECOM') diff --git a/teeth_agent/tests/standby_agent.py b/teeth_agent/tests/standby.py similarity index 91% rename from teeth_agent/tests/standby_agent.py rename to teeth_agent/tests/standby.py index 96de673b8..fa7c49422 100644 --- a/teeth_agent/tests/standby_agent.py +++ b/teeth_agent/tests/standby.py @@ -21,12 +21,12 @@ from teeth_agent import errors from teeth_agent import standby -class TestBaseTeethAgent(unittest.TestCase): +class TestStandbyMode(unittest.TestCase): def setUp(self): - self.agent = standby.StandbyAgent(None, 9999, None, 9999, 'fake_api') + self.agent_mode = standby.StandbyMode() def test_standby_mode(self): - self.assertEqual(self.agent.mode, 'STANDBY') + self.assertEqual(self.agent_mode.name, 'STANDBY') def _build_fake_image_info(self): return { @@ -40,7 +40,7 @@ class TestBaseTeethAgent(unittest.TestCase): } def test_validate_image_info_success(self): - self.agent._validate_image_info(self._build_fake_image_info()) + self.agent_mode._validate_image_info(self._build_fake_image_info()) def test_validate_image_info_missing_field(self): for field in ['id', 'urls', 'hashes']: @@ -48,7 +48,7 @@ class TestBaseTeethAgent(unittest.TestCase): del invalid_info[field] self.assertRaises(errors.InvalidCommandParamsError, - self.agent._validate_image_info, + self.agent_mode._validate_image_info, invalid_info) def test_validate_image_info_invalid_urls(self): @@ -56,7 +56,7 @@ class TestBaseTeethAgent(unittest.TestCase): invalid_info['urls'] = 'this_is_not_a_list' self.assertRaises(errors.InvalidCommandParamsError, - self.agent._validate_image_info, + self.agent_mode._validate_image_info, invalid_info) def test_validate_image_info_empty_urls(self): @@ -64,7 +64,7 @@ class TestBaseTeethAgent(unittest.TestCase): invalid_info['urls'] = [] self.assertRaises(errors.InvalidCommandParamsError, - self.agent._validate_image_info, + self.agent_mode._validate_image_info, invalid_info) def test_validate_image_info_invalid_hashes(self): @@ -72,7 +72,7 @@ class TestBaseTeethAgent(unittest.TestCase): invalid_info['hashes'] = 'this_is_not_a_dict' self.assertRaises(errors.InvalidCommandParamsError, - self.agent._validate_image_info, + self.agent_mode._validate_image_info, invalid_info) def test_validate_image_info_empty_hashes(self): @@ -80,17 +80,17 @@ class TestBaseTeethAgent(unittest.TestCase): invalid_info['hashes'] = {} self.assertRaises(errors.InvalidCommandParamsError, - self.agent._validate_image_info, + self.agent_mode._validate_image_info, invalid_info) def test_cache_images_success(self): - result = self.agent.cache_images('cache_images', - [self._build_fake_image_info()]) + result = self.agent_mode.cache_images('cache_images', + [self._build_fake_image_info()]) result.join() def test_cache_images_invalid_image_list(self): self.assertRaises(errors.InvalidCommandParamsError, - self.agent.cache_images, + self.agent_mode.cache_images, 'cache_images', {'foo': 'bar'})