diff --git a/tests/unit/test_streaming.py b/tests/unit/test_streaming.py index 30231d1feb..b0573664bd 100644 --- a/tests/unit/test_streaming.py +++ b/tests/unit/test_streaming.py @@ -33,6 +33,30 @@ from tests.base import iterate_timeout, ZuulWebFixture from ws4py.client import WebSocketBaseClient +class WSClient(WebSocketBaseClient): + def __init__(self, port, build_uuid): + self.port = port + self.build_uuid = build_uuid + self.results = '' + self.event = threading.Event() + uri = 'ws://[::1]:%s/api/tenant/tenant-one/console-stream' % port + super(WSClient, self).__init__(uri) + + self.thread = threading.Thread(target=self.run) + self.thread.start() + + def received_message(self, message): + if message.is_text: + self.results += message.data.decode('utf-8') + + def run(self): + self.connect() + req = {'uuid': self.build_uuid, 'logfile': None} + self.send(json.dumps(req)) + self.event.set() + super(WSClient, self).run() + + class TestLogStreamer(tests.base.BaseTestCase): def startStreamer(self, host, port, root=None): @@ -173,24 +197,10 @@ class TestStreaming(tests.base.AnsibleZuulTestCase): self.log.debug("\n\nStreamed: %s\n\n", self.streaming_data) self.assertEqual(file_contents, self.streaming_data) - def runWSClient(self, port, build_uuid, event): - class TestWSClient(WebSocketBaseClient): - def __init__(self, *args, **kw): - super(TestWSClient, self).__init__(*args, **kw) - self.results = '' - - def received_message(self, message): - if message.is_text: - self.results += message.data.decode('utf-8') - - uri = 'ws://[::1]:%s/api/tenant/tenant-one/console-stream' % port - ws = TestWSClient(uri) - ws.connect() - req = {'uuid': build_uuid, 'logfile': None} - ws.send(json.dumps(req)) - event.set() - ws.run() - self.ws_client_results += ws.results + def runWSClient(self, port, build_uuid): + client = WSClient(port, build_uuid) + client.event.wait() + return client def runFingerClient(self, build_uuid, gateway_address, event): # Wait until the gateway is started @@ -286,14 +296,8 @@ class TestStreaming(tests.base.AnsibleZuulTestCase): self.addCleanup(logfile.close) # Start a thread with the websocket client - ws_client_event = threading.Event() - self.ws_client_results = '' - ws_client_thread = threading.Thread( - target=self.runWSClient, args=(web.port, build.uuid, - ws_client_event) - ) - ws_client_thread.start() - ws_client_event.wait() + client1 = self.runWSClient(web.port, build.uuid) + client1.event.wait() # Allow the job to complete flag_file = os.path.join(build_dir, 'test_wait') @@ -301,15 +305,15 @@ class TestStreaming(tests.base.AnsibleZuulTestCase): # Wait for the websocket client to complete, which it should when # it's received the full log. - ws_client_thread.join() + client1.thread.join() self.waitUntilSettled() file_contents = logfile.read() logfile.close() self.log.debug("\n\nFile contents: %s\n\n", file_contents) - self.log.debug("\n\nStreamed: %s\n\n", self.ws_client_results) - self.assertEqual(file_contents, self.ws_client_results) + self.log.debug("\n\nStreamed: %s\n\n", client1.results) + self.assertEqual(file_contents, client1.results) def test_websocket_streaming(self): # Start the web server @@ -361,14 +365,10 @@ class TestStreaming(tests.base.AnsibleZuulTestCase): self.addCleanup(logfile.close) # Start a thread with the websocket client - ws_client_event = threading.Event() - self.ws_client_results = '' - ws_client_thread = threading.Thread( - target=self.runWSClient, args=(web.port, build.uuid, - ws_client_event) - ) - ws_client_thread.start() - ws_client_event.wait() + client1 = self.runWSClient(web.port, build.uuid) + client1.event.wait() + client2 = self.runWSClient(web.port, build.uuid) + client2.event.wait() # Allow the job to complete flag_file = os.path.join(build_dir, 'test_wait') @@ -376,14 +376,17 @@ class TestStreaming(tests.base.AnsibleZuulTestCase): # Wait for the websocket client to complete, which it should when # it's received the full log. - ws_client_thread.join() + client1.thread.join() + client2.thread.join() self.waitUntilSettled() file_contents = logfile.read() self.log.debug("\n\nFile contents: %s\n\n", file_contents) - self.log.debug("\n\nStreamed: %s\n\n", self.ws_client_results) - self.assertEqual(file_contents, self.ws_client_results) + self.log.debug("\n\nStreamed: %s\n\n", client1.results) + self.assertEqual(file_contents, client1.results) + self.log.debug("\n\nStreamed: %s\n\n", client2.results) + self.assertEqual(file_contents, client2.results) def test_finger_gateway(self): # Start the finger streamer daemon diff --git a/zuul/web/__init__.py b/zuul/web/__init__.py index 6e3857000d..464ad84c19 100755 --- a/zuul/web/__init__.py +++ b/zuul/web/__init__.py @@ -25,6 +25,8 @@ import json import logging import os import time +import select +import threading import zuul.model import zuul.rpcclient @@ -80,13 +82,29 @@ class ChangeFilter(object): class LogStreamHandler(WebSocket): log = logging.getLogger("zuul.web") + def __init__(self, *args, **kw): + super(LogStreamHandler, self).__init__(*args, **kw) + self.streamer = None + def received_message(self, message): if message.is_text: req = json.loads(message.data.decode('utf-8')) self.log.debug("Websocket request: %s", req) - code, msg = self._streamLog(req) - self.log.debug("close Websocket request: %s %s", code, msg) + if self.streamer: + self.log.debug("Ignoring request due to existing streamer") + return + try: + self._streamLog(req) + except Exception: + self.log.exception("Error processing websocket message:") + raise + + def logClose(self, code, msg): + self.log.debug("Websocket close: %s %s", code, msg) + try: self.close(code, msg) + except Exception: + self.log.exception("Error closing websocket:") def _streamLog(self, request): """ @@ -96,20 +114,26 @@ class LogStreamHandler(WebSocket): """ for key in ('uuid', 'logfile'): if key not in request: - return (4000, "'{key}' missing from request payload".format( + return self.logClose( + 4000, + "'{key}' missing from request payload".format( key=key)) - port_location = self.rpc.get_job_log_stream_address(request['uuid']) + port_location = self.zuulweb.rpc.get_job_log_stream_address( + request['uuid']) if not port_location: - return (4011, "Error with Gearman") + return self.logClose(4011, "Error with Gearman") - self._fingerClient( + self.streamer = LogStreamer( + self.zuulweb, self, port_location['server'], port_location['port'], request['uuid']) - return (1000, "No more data") - def _fingerClient(self, server, port, build_uuid): +class LogStreamer(object): + log = logging.getLogger("zuul.web") + + def __init__(self, zuulweb, websocket, server, port, build_uuid): """ Create a client to connect to the finger streamer and pull results. @@ -119,25 +143,40 @@ class LogStreamHandler(WebSocket): """ self.log.debug("Connecting to finger server %s:%s", server, port) Decoder = codecs.getincrementaldecoder('utf8') - decoder = Decoder() - with socket.create_connection((server, port), timeout=10) as s: - # timeout only on the connection, let recv() wait forever - s.settimeout(None) - msg = "%s\n" % build_uuid # Must have a trailing newline! - s.sendall(msg.encode('utf-8')) - while True: - data = s.recv(1024) + self.decoder = Decoder() + self.finger_socket = socket.create_connection( + (server, port), timeout=10) + self.finger_socket.settimeout(None) + self.websocket = websocket + self.zuulweb = zuulweb + self.uuid = build_uuid + msg = "%s\n" % build_uuid # Must have a trailing newline! + self.finger_socket.sendall(msg.encode('utf-8')) + self.zuulweb.stream_manager.registerStreamer(self) + + def __repr__(self): + return '' % (self.websocket, self.uuid) + + def errorClose(self): + self.websocket.logClose(4011, "Unknown error") + + def handle(self, event): + if event & select.POLLIN: + data = self.finger_socket.recv(1024) + if data: + data = self.decoder.decode(data) if data: - data = decoder.decode(data) - if data: - self.send(data, False) - else: - # Make sure we flush anything left in the decoder - data = decoder.decode(b'', final=True) - if data: - self.send(data, False) - self.close() - return + self.websocket.send(data, False) + else: + # Make sure we flush anything left in the decoder + data = self.decoder.decode(b'', final=True) + if data: + self.websocket.send(data, False) + self.zuulweb.stream_manager.unregisterStreamer(self) + return self.websocket.logClose(1000, "No more data") + else: + self.zuulweb.stream_manager.unregisterStreamer(self) + return self.websocket.logClose(1000, "Remote error") class ZuulWebAPI(object): @@ -271,7 +310,7 @@ class ZuulWebAPI(object): @cherrypy.tools.save_params() @cherrypy.tools.websocket(handler_cls=LogStreamHandler) def console_stream(self, tenant): - cherrypy.request.ws_handler.rpc = self.rpc + cherrypy.request.ws_handler.zuulweb = self.zuulweb class TenantStaticHandler(object): @@ -292,6 +331,69 @@ class RootStaticHandler(object): } +class StreamManager(object): + log = logging.getLogger("zuul.web") + + def __init__(self): + self.streamers = {} + self.poll = select.poll() + self.bitmask = (select.POLLIN | select.POLLERR | + select.POLLHUP | select.POLLNVAL) + self.wake_read, self.wake_write = os.pipe() + self.poll.register(self.wake_read, self.bitmask) + + def start(self): + self._stopped = False + self.thread = threading.Thread( + target=self.run, + name='StreamManager') + self.thread.start() + + def stop(self): + self._stopped = True + os.write(self.wake_write, b'\n') + self.thread.join() + + def run(self): + while True: + for fd, event in self.poll.poll(): + if self._stopped: + return + if fd == self.wake_read: + os.read(self.wake_read, 1024) + continue + streamer = self.streamers.get(fd) + if streamer: + try: + streamer.handle(event) + except Exception: + self.log.exception("Error in streamer:") + streamer.errorClose() + self.unregisterStreamer(streamer) + else: + try: + self.poll.unregister(fd) + except KeyError: + pass + + def registerStreamer(self, streamer): + self.log.debug("Registering streamer %s", streamer) + self.streamers[streamer.finger_socket.fileno()] = streamer + self.poll.register(streamer.finger_socket.fileno(), self.bitmask) + os.write(self.wake_write, b'\n') + + def unregisterStreamer(self, streamer): + self.log.debug("Unregistering streamer %s", streamer) + try: + self.poll.unregister(streamer.finger_socket) + except KeyError: + pass + try: + del self.streamers[streamer.finger_socket.fileno()] + except KeyError: + pass + + class ZuulWeb(object): log = logging.getLogger("zuul.web.ZuulWeb") @@ -315,6 +417,7 @@ class ZuulWeb(object): self.rpc = zuul.rpcclient.RPCClient(gear_server, gear_port, ssl_key, ssl_cert, ssl_ca) self.connections = connections + self.stream_manager = StreamManager() route_map = cherrypy.dispatch.RoutesDispatcher() api = ZuulWebAPI(self) @@ -373,6 +476,7 @@ class ZuulWeb(object): def start(self): self.log.debug("ZuulWeb starting") + self.stream_manager.start() self.wsplugin = WebSocketPlugin(cherrypy.engine) self.wsplugin.subscribe() cherrypy.engine.start() @@ -386,6 +490,7 @@ class ZuulWeb(object): # same host/port settings. cherrypy.server.httpserver = None self.wsplugin.unsubscribe() + self.stream_manager.stop() if __name__ == "__main__":