233 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			233 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #    Copyright 2016 Mirantis, Inc.
 | |
| #
 | |
| #    Licensed under the Apache License, Version 2.0 (the "License"); you may
 | |
| #    not use this file except in compliance with the License. You may obtain
 | |
| #    a copy of the License at
 | |
| #
 | |
| #         http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| #    Unless required by applicable law or agreed to in writing, software
 | |
| #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | |
| #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | |
| #    License for the specific language governing permissions and limitations
 | |
| #    under the License.
 | |
| 
 | |
| import logging
 | |
| import logging.handlers
 | |
| import multiprocessing
 | |
| import os
 | |
| import sys
 | |
| import threading
 | |
| import time
 | |
| import uuid
 | |
| 
 | |
| from oslo_config import cfg
 | |
| 
 | |
| import oslo_messaging
 | |
| from oslo_messaging._drivers.zmq_driver import zmq_async
 | |
| from oslo_messaging.tests.functional import utils
 | |
| 
 | |
| 
 | |
# Module logger; server processes route their records to the parent's log
# listener through a queue (see server_configurer / QueueHandler).
LOG = logging.getLogger(__name__)

# ZeroMQ module object as resolved by the zmq driver's import helper.
zmq = zmq_async.import_zmq()
 | |
| 
 | |
| 
 | |
class QueueHandler(logging.Handler):
    """Logging handler that forwards every record to a queue.

    This mirrors ``logging.handlers.QueueHandler`` (only available on
    Python 3.2+); a local copy is presumably kept so the helper also runs
    on older interpreters.  Records are made transport-safe by
    :meth:`prepare` before being put on the queue, so a
    ``multiprocessing.Queue`` can ship them to another process.
    """

    def __init__(self, queue):
        """Initialise the handler with the queue records are sent to."""
        logging.Handler.__init__(self)
        self.queue = queue

    def prepare(self, record):
        """Strip the parts of *record* that cannot cross a process boundary.

        The traceback object is replaced by its cached text and the message
        arguments are merged into ``record.msg``.  ``record.exc_text`` is
        kept so a formatter on the receiving side still renders the
        traceback.

        :returns: the (mutated) record, ready to be queued.
        """
        if record.exc_info:
            # Formatting caches the traceback text in record.exc_text, so
            # the traceback object itself is no longer needed.
            self.format(record)
            record.exc_info = None
        # Arbitrary args may be unpicklable; a multiprocessing.Queue pickles
        # in a background thread, where such a failure would bypass
        # handleError().  Merging them here keeps the record picklable.
        record.msg = record.getMessage()
        record.args = None
        return record

    def emit(self, record):
        """Prepare *record* and put it on the queue without blocking."""
        try:
            self.queue.put_nowait(self.prepare(record))
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception:
            self.handleError(record)
 | |
| 
 | |
| 
 | |
def listener_configurer(conf):
    """Attach stdout and file handlers to the root logger of the listener.

    Both handlers share one formatter that includes the originating process
    name, so records forwarded from server processes stay distinguishable.
    The log file ``zmq_multiproc.log`` is created (truncated) inside
    ``conf.rpc_zmq_ipc_dir``.

    :param conf: configuration object exposing ``rpc_zmq_ipc_dir``.
    """
    formatter = logging.Formatter('%(asctime)s %(processName)-10s %(name)s '
                                  '%(levelname)-8s %(message)s')
    root = logging.getLogger()

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    root.addHandler(console_handler)

    log_path = os.path.join(conf.rpc_zmq_ipc_dir, "zmq_multiproc.log")
    # FileHandler owns its stream, so logging closes it on shutdown instead
    # of leaking a bare open() file object.
    file_handler = logging.FileHandler(log_path, mode='w')
    file_handler.setFormatter(formatter)
    root.addHandler(file_handler)
 | |
| 
 | |
| 
 | |
def server_configurer(queue):
    """Route all logging of the current (server) process through *queue*.

    Installs a :class:`QueueHandler` as the only root handler and lowers the
    root level to DEBUG so every record is forwarded to the log listener.

    :param queue: queue drained by the listener (``listener_thread``).
    """
    root = logging.getLogger()
    # When this process was forked, it may have inherited the parent's root
    # handlers (e.g. the stdout/file handlers added by listener_configurer).
    # Drop them so each record reaches the listener exactly once, via the
    # queue, instead of also being written directly from this process.
    for handler in list(root.handlers):
        root.removeHandler(handler)
    root.addHandler(QueueHandler(queue))
    root.setLevel(logging.DEBUG)
 | |
| 
 | |
| 
 | |
def listener_thread(queue, configurer, conf):
    """Drain log records from *queue* and dispatch them to local loggers.

    ``configurer(conf)`` is called once first to install the handlers.  The
    loop ends when ``None`` (the stop sentinel) is received or the queue
    becomes unusable.  Errors raised while handling a single record are
    reported on stderr and do not stop the listener.

    :param queue: queue fed by :class:`QueueHandler` instances.
    :param configurer: callable that sets up logging for this process.
    :param conf: argument passed through to *configurer*.
    """
    import traceback

    configurer(conf)
    while True:
        try:
            # get() blocks until a record arrives, so no polling delay is
            # needed between iterations.
            record = queue.get()
        except (EOFError, IOError, OSError):
            # The underlying pipe is closed; nothing more can arrive.
            break
        if record is None:
            break
        try:
            logging.getLogger(record.name).handle(record)
        except Exception:
            # Keep draining: one bad record must not silence all later logs.
            sys.stderr.write("Problem handling a record in log listener:\n")
            traceback.print_exc(file=sys.stderr)
 | |
| 
 | |
| 
 | |
class Client(oslo_messaging.RPCClient):
    """RPC client bound to one topic that remembers every reply it gets."""

    def __init__(self, transport, topic):
        target = oslo_messaging.Target(topic=topic)
        super(Client, self).__init__(transport=transport, target=target)
        # Replies returned by call_a(), in call order.
        self.replies = []

    def call_a(self):
        """Invoke the remote ``call_a`` method, record and return its reply."""
        LOG.warning("call_a - client side")
        reply = self.call({}, 'call_a')
        LOG.warning("after call_a - client side")
        self.replies.append(reply)
        return reply
 | |
| 
 | |
| 
 | |
class ReplyServerEndpoint(object):
    """RPC endpoint exposing a single ``call_a`` method."""

    def call_a(self, *args, **kwargs):
        """Log that the request arrived and answer with ``"OK"``."""
        LOG.warning("call_a - Server endpoint reached!")
        reply = "OK"
        return reply
 | |
| 
 | |
| 
 | |
class Server(object):
    """Run an oslo.messaging RPC server in a child process.

    The child routes its logging to *log_queue*, serves
    :class:`ReplyServerEndpoint` on ``self.topic`` and runs until
    :meth:`shutdown` / :meth:`cleanup` is called from the parent.
    """

    def __init__(self, conf, log_queue, transport_url, name, topic=None):
        self.conf = conf
        self.log_queue = log_queue
        self.transport_url = transport_url
        self.name = name
        # A random topic keeps unrelated servers from sharing a queue.
        self.topic = topic or str(uuid.uuid4())
        # Shared flag the child sets once its RPC server has started.
        self.ready = multiprocessing.Value('b', False)
        # Shared event the parent sets to ask the child to stop.
        self._stop = multiprocessing.Event()
        # The child process; None until start() is called.
        self.process = None

    def start(self):
        """Spawn the child process that runs the RPC server."""
        self.process = multiprocessing.Process(target=self._run_server,
                                               name=self.name,
                                               args=(self.conf,
                                                     self.transport_url,
                                                     self.log_queue,
                                                     self.ready))
        self.process.start()
        LOG.debug("Server process started: pid: %d", self.process.pid)

    def _run_server(self, conf, url, log_queue, ready):
        """Child-process entry point: serve RPC until the stop event is set."""
        server_configurer(log_queue)
        LOG.debug("Starting RPC server")

        transport = oslo_messaging.get_transport(conf, url=url)
        target = oslo_messaging.Target(topic=self.topic, server=self.name)
        self.rpc_server = oslo_messaging.get_rpc_server(
            transport=transport, target=target,
            endpoints=[ReplyServerEndpoint()],
            executor='eventlet')
        self.rpc_server.start()
        ready.value = True
        LOG.debug("RPC server being started")
        try:
            while not self._stop.is_set():
                LOG.debug("Waiting for the stop signal ...")
                # Wakes up immediately once the parent sets the event.
                self._stop.wait(1)
        finally:
            # stop() must be followed by wait() to let in-flight requests
            # finish and release driver resources; then drop the transport.
            self.rpc_server.stop()
            self.rpc_server.wait()
            transport.cleanup()
        LOG.debug("Leaving process T:%s Pid:%d", str(target), os.getpid())

    def cleanup(self, timeout=10):
        """Stop the server and reap its child process.

        :param timeout: seconds to wait for a graceful exit before the child
            is terminated (and again after terminating it).
        """
        LOG.debug("Stopping server")
        self.shutdown()
        if self.process is None:
            return
        self.process.join(timeout)
        if self.process.is_alive():
            LOG.warning("Server %s did not stop within %s s, terminating",
                        self.name, timeout)
            self.process.terminate()
            self.process.join(timeout)

    def shutdown(self):
        """Ask the child process to stop; does not wait for it."""
        self._stop.set()

    def restart(self, time_for_restart=1):
        """Placeholder; not implemented."""
        pass

    def hang(self):
        """Placeholder; not implemented."""
        pass

    def crash(self):
        """Placeholder; not implemented."""
        pass

    def ping(self):
        """Placeholder; not implemented."""
        pass
 | |
| 
 | |
| 
 | |
class MutliprocTestCase(utils.SkipIfNoTransportURL):
    """ZeroMQ functional test base that runs RPC servers in child processes.

    Server processes send their log records through ``self.log_queue`` to a
    listener thread in the test process (see ``listener_thread``).
    """

    # Seconds tearDown waits for the log listener thread to finish.
    LISTENER_JOIN_TIMEOUT = 10

    def setUp(self):
        super(MutliprocTestCase, self).setUp(conf=cfg.ConfigOpts())

        if not self.url.startswith("zmq:"):
            self.skipTest("ZeroMQ specific skipped ...")

        self.transport = oslo_messaging.get_transport(self.conf, url=self.url)

        LOG.debug("Start log queue")

        self.log_queue = multiprocessing.Queue()
        self.log_listener = threading.Thread(target=listener_thread,
                                             args=(self.log_queue,
                                                   listener_configurer,
                                                   self.conf))
        # Daemon so a listener that never gets its stop sentinel cannot keep
        # the interpreter alive at exit.
        self.log_listener.daemon = True
        self.log_listener.start()
        self.spawned = []

        self.conf.prog = "test_prog"
        self.conf.project = "test_project"

    def tearDown(self):
        try:
            for process in self.spawned:
                process.cleanup()
            # None is the listener's stop sentinel.
            self.log_queue.put(None)
            self.log_listener.join(self.LISTENER_JOIN_TIMEOUT)
            self.transport.cleanup()
        finally:
            super(MutliprocTestCase, self).tearDown()

    def get_client(self, topic):
        """Return a :class:`Client` on this test's transport for *topic*."""
        return Client(self.transport, topic)

    def spawn_server(self, name, wait_for_server=False, topic=None):
        """Start one server process and register it for cleanup.

        :param wait_for_server: block until the server reports readiness;
            the test fails if the process exits before that.
        """
        srv = Server(self.conf, self.log_queue, self.url, name, topic)
        LOG.debug("[SPAWN] %s (starting)...", srv.name)
        srv.start()
        # Register immediately so tearDown stops it even if waiting fails.
        self.spawned.append(srv)
        if wait_for_server:
            while not srv.ready.value:
                if not srv.process.is_alive() and not srv.ready.value:
                    self.fail("Server %s exited before becoming ready "
                              "(exit code %s)" % (srv.name,
                                                  srv.process.exitcode))
                LOG.debug("[SPAWN] %s (waiting for server ready)...",
                          srv.name)
                time.sleep(1)
        LOG.debug("[SPAWN] Server %s:%d started.", srv.name, srv.process.pid)
        return srv

    def spawn_servers(self, number, wait_for_server=False, random_topic=True):
        """Start *number* servers and return them in creation order.

        With ``random_topic`` all servers share one freshly generated topic;
        otherwise each server gets its own random topic.
        """
        common_topic = str(uuid.uuid4()) if random_topic else None
        names = ["server_%i_%s" % (i, str(uuid.uuid4())[:8])
                 for i in range(number)]
        servers = []
        for name in names:
            # spawn_server() already records the server in self.spawned.
            servers.append(
                self.spawn_server(name, wait_for_server, common_topic))
        return servers
 | 
