Merge "Bind a new socket per-worker"

This commit is contained in:
Zuul
2020-09-03 01:59:13 +00:00
committed by Gerrit Code Review
2 changed files with 232 additions and 328 deletions

View File

@@ -636,7 +636,7 @@ class SwiftHttpProxiedProtocol(SwiftHttpProtocol):
return environ return environ
def run_server(conf, logger, sock, global_conf=None): def run_server(conf, logger, sock, global_conf=None, ready_callback=None):
# Ensure TZ environment variable exists to avoid stat('/etc/localtime') on # Ensure TZ environment variable exists to avoid stat('/etc/localtime') on
# some platforms. This locks in reported times to UTC. # some platforms. This locks in reported times to UTC.
os.environ['TZ'] = 'UTC+0' os.environ['TZ'] = 'UTC+0'
@@ -677,6 +677,8 @@ def run_server(conf, logger, sock, global_conf=None):
# header; "Etag" just won't do). # header; "Etag" just won't do).
'capitalize_response_headers': False, 'capitalize_response_headers': False,
} }
if ready_callback:
ready_callback()
try: try:
wsgi.server(sock, app, wsgi_logger, **server_kwargs) wsgi.server(sock, app, wsgi_logger, **server_kwargs)
except socket.error as err: except socket.error as err:
@@ -689,6 +691,15 @@ class StrategyBase(object):
""" """
Some operations common to all strategy classes. Some operations common to all strategy classes.
""" """
def __init__(self, conf, logger):
self.conf = conf
self.logger = logger
self.signaled_ready = False
# Each strategy is welcome to track data however it likes, but all
# socket refs should be somewhere in this dict. This allows forked-off
# children to easily drop refs to sibling sockets in post_fork_hook().
self.tracking_data = {}
def post_fork_hook(self): def post_fork_hook(self):
""" """
@@ -696,7 +707,10 @@ class StrategyBase(object):
wsgi server, to perform any initialization such as drop privileges. wsgi server, to perform any initialization such as drop privileges.
""" """
if not self.signaled_ready:
capture_stdio(self.logger)
drop_privileges(self.conf.get('user', 'swift')) drop_privileges(self.conf.get('user', 'swift'))
del self.tracking_data # children don't need to track siblings
def shutdown_sockets(self): def shutdown_sockets(self):
""" """
@@ -721,12 +735,40 @@ class StrategyBase(object):
# on socket objects is provided to toggle it. # on socket objects is provided to toggle it.
sock.set_inheritable(False) sock.set_inheritable(False)
def signal_ready(self):
"""
Signal that the server is up and accepting connections.
"""
if self.signaled_ready:
return # Already did it
# Redirect errors to logger and close stdio. swift-init (for example)
# uses this to know that the service is ready to accept connections.
capture_stdio(self.logger)
# If necessary, signal an old copy of us that it's okay to shutdown
# its listen sockets now because ours are up and ready to receive
# connections. This is used for seamless reloading using SIGUSR1.
reexec_signal_fd = os.getenv(NOTIFY_FD_ENV_KEY)
if reexec_signal_fd:
reexec_signal_fd = int(reexec_signal_fd)
os.write(reexec_signal_fd, str(os.getpid()).encode('utf8'))
os.close(reexec_signal_fd)
# Finally, signal systemd (if appropriate) that process started
# properly.
systemd_notify(logger=self.logger)
self.signaled_ready = True
class WorkersStrategy(StrategyBase): class WorkersStrategy(StrategyBase):
""" """
WSGI server management strategy object for a single bind port and listen WSGI server management strategy object for a single bind port and listen
socket shared by a configured number of forked-off workers. socket shared by a configured number of forked-off workers.
Tracking data is a map of ``pid -> socket``.
Used in :py:func:`run_wsgi`. Used in :py:func:`run_wsgi`.
:param dict conf: Server configuration dictionary. :param dict conf: Server configuration dictionary.
@@ -735,10 +777,7 @@ class WorkersStrategy(StrategyBase):
""" """
def __init__(self, conf, logger): def __init__(self, conf, logger):
self.conf = conf super(WorkersStrategy, self).__init__(conf, logger)
self.logger = logger
self.sock = None
self.children = []
self.worker_count = config_auto_int_value(conf.get('workers'), self.worker_count = config_auto_int_value(conf.get('workers'),
CPU_COUNT) CPU_COUNT)
@@ -753,18 +792,6 @@ class WorkersStrategy(StrategyBase):
return 0.5 return 0.5
def do_bind_ports(self):
"""
Bind the one listen socket for this strategy.
"""
try:
self.sock = get_socket(self.conf)
except ConfigFilePortError:
msg = 'bind_port wasn\'t properly set in the config file. ' \
'It must be explicitly set to a valid port number.'
return msg
def no_fork_sock(self): def no_fork_sock(self):
""" """
Return a server listen socket if the server should run in the Return a server listen socket if the server should run in the
@@ -773,7 +800,7 @@ class WorkersStrategy(StrategyBase):
# Useful for profiling [no forks]. # Useful for profiling [no forks].
if self.worker_count == 0: if self.worker_count == 0:
return self.sock return get_socket(self.conf)
def new_worker_socks(self): def new_worker_socks(self):
""" """
@@ -785,8 +812,8 @@ class WorkersStrategy(StrategyBase):
where it will be ignored. where it will be ignored.
""" """
while len(self.children) < self.worker_count: while len(self.tracking_data) < self.worker_count:
yield self.sock, None yield get_socket(self.conf), None
def log_sock_exit(self, sock, _unused): def log_sock_exit(self, sock, _unused):
""" """
@@ -810,7 +837,7 @@ class WorkersStrategy(StrategyBase):
self.logger.notice('Started child %s from parent %s', self.logger.notice('Started child %s from parent %s',
pid, os.getpid()) pid, os.getpid())
self.children.append(pid) self.tracking_data[pid] = sock
def register_worker_exit(self, pid): def register_worker_exit(self, pid):
""" """
@@ -823,139 +850,22 @@ class WorkersStrategy(StrategyBase):
:param int pid: The PID of the worker that exited. :param int pid: The PID of the worker that exited.
""" """
if pid in self.children: sock = self.tracking_data.pop(pid, None)
if sock is None:
self.logger.info('Ignoring wait() result from unknown PID %s', pid)
else:
self.logger.error('Removing dead child %s from parent %s', self.logger.error('Removing dead child %s from parent %s',
pid, os.getpid()) pid, os.getpid())
self.children.remove(pid) greenio.shutdown_safe(sock)
else: sock.close()
self.logger.info('Ignoring wait() result from unknown PID %s', pid)
def iter_sockets(self): def iter_sockets(self):
""" """
Yields all known listen sockets. Yields all known listen sockets.
""" """
if self.sock: for sock in self.tracking_data.values():
yield self.sock yield sock
class PortPidState(object):
"""
A helper class for :py:class:`ServersPerPortStrategy` to track listen
sockets and PIDs for each port.
:param int servers_per_port: The configured number of servers per port.
:param logger: The server's :py:class:`~swift.common.utils.LogAdaptor`
"""
def __init__(self, servers_per_port, logger):
self.servers_per_port = servers_per_port
self.logger = logger
self.sock_data_by_port = {}
def sock_for_port(self, port):
"""
:param int port: The port whose socket is desired.
:returns: The bound listen socket for the given port.
"""
return self.sock_data_by_port[port]['sock']
def port_for_sock(self, sock):
"""
:param socket sock: A tracked bound listen socket
:returns: The port the socket is bound to.
"""
for port, sock_data in self.sock_data_by_port.items():
if sock_data['sock'] == sock:
return port
def _pid_to_port_and_index(self, pid):
for port, sock_data in self.sock_data_by_port.items():
for server_idx, a_pid in enumerate(sock_data['pids']):
if pid == a_pid:
return port, server_idx
def port_index_pairs(self):
"""
Returns current (port, server index) pairs.
:returns: A set of (port, server_idx) tuples for currently-tracked
ports, sockets, and PIDs.
"""
current_port_index_pairs = set()
for port, pid_state in self.sock_data_by_port.items():
current_port_index_pairs |= set(
(port, i)
for i, pid in enumerate(pid_state['pids'])
if pid is not None)
return current_port_index_pairs
def track_port(self, port, sock):
"""
Start tracking servers for the given port and listen socket.
:param int port: The port to start tracking
:param socket sock: The bound listen socket for the port.
"""
self.sock_data_by_port[port] = {
'sock': sock,
'pids': [None] * self.servers_per_port,
}
def not_tracking(self, port):
"""
Return True if the specified port is not being tracked.
:param int port: A port to check.
"""
return port not in self.sock_data_by_port
def all_socks(self):
"""
Yield all current listen sockets.
"""
for orphan_data in self.sock_data_by_port.values():
yield orphan_data['sock']
def forget_port(self, port):
"""
Idempotently forget a port, closing the listen socket at most once.
"""
orphan_data = self.sock_data_by_port.pop(port, None)
if orphan_data:
greenio.shutdown_safe(orphan_data['sock'])
orphan_data['sock'].close()
self.logger.notice('Closing unnecessary sock for port %d', port)
def add_pid(self, port, index, pid):
self.sock_data_by_port[port]['pids'][index] = pid
def forget_pid(self, pid):
"""
Idempotently forget a PID. It's okay if the PID is no longer in our
data structure (it could have been removed by the "orphan port" removal
in :py:meth:`new_worker_socks`).
:param int pid: The PID which exited.
"""
port_server_idx = self._pid_to_port_and_index(pid)
if port_server_idx is None:
# This method can lose a race with the "orphan port" removal, when
# a ring reload no longer contains a port. So it's okay if we were
# unable to find a (port, server_idx) pair.
return
dead_port, server_idx = port_server_idx
self.logger.error('Removing dead child %d (PID: %s) for port %s',
server_idx, pid, dead_port)
self.sock_data_by_port[dead_port]['pids'][server_idx] = None
class ServersPerPortStrategy(StrategyBase): class ServersPerPortStrategy(StrategyBase):
@@ -965,6 +875,8 @@ class ServersPerPortStrategy(StrategyBase):
`servers_per_port` integer config setting determines how many workers are `servers_per_port` integer config setting determines how many workers are
run per port. run per port.
Tracking data is a map like ``port -> [(pid, socket), ...]``.
Used in :py:func:`run_wsgi`. Used in :py:func:`run_wsgi`.
:param dict conf: Server configuration dictionary. :param dict conf: Server configuration dictionary.
@@ -974,12 +886,10 @@ class ServersPerPortStrategy(StrategyBase):
""" """
def __init__(self, conf, logger, servers_per_port): def __init__(self, conf, logger, servers_per_port):
self.conf = conf super(ServersPerPortStrategy, self).__init__(conf, logger)
self.logger = logger
self.servers_per_port = servers_per_port self.servers_per_port = servers_per_port
self.swift_dir = conf.get('swift_dir', '/etc/swift') self.swift_dir = conf.get('swift_dir', '/etc/swift')
self.ring_check_interval = int(conf.get('ring_check_interval', 15)) self.ring_check_interval = int(conf.get('ring_check_interval', 15))
self.port_pid_state = PortPidState(servers_per_port, logger)
bind_ip = conf.get('bind_ip', '0.0.0.0') bind_ip = conf.get('bind_ip', '0.0.0.0')
self.cache = BindPortsCache(self.swift_dir, bind_ip) self.cache = BindPortsCache(self.swift_dir, bind_ip)
@@ -990,8 +900,7 @@ class ServersPerPortStrategy(StrategyBase):
def _bind_port(self, port): def _bind_port(self, port):
new_conf = self.conf.copy() new_conf = self.conf.copy()
new_conf['bind_port'] = port new_conf['bind_port'] = port
sock = get_socket(new_conf) return get_socket(new_conf)
self.port_pid_state.track_port(port, sock)
def loop_timeout(self): def loop_timeout(self):
""" """
@@ -1003,15 +912,6 @@ class ServersPerPortStrategy(StrategyBase):
return self.ring_check_interval return self.ring_check_interval
def do_bind_ports(self):
"""
Bind one listen socket per unique local storage policy ring port.
"""
self._reload_bind_ports()
for port in self.bind_ports:
self._bind_port(port)
def no_fork_sock(self): def no_fork_sock(self):
""" """
This strategy does not support running in the foreground. This strategy does not support running in the foreground.
@@ -1021,8 +921,8 @@ class ServersPerPortStrategy(StrategyBase):
def new_worker_socks(self): def new_worker_socks(self):
""" """
Yield a sequence of (socket, server_idx) tuples for each server which Yield a sequence of (socket, (port, server_idx)) tuples for each server
should be forked-off and started. which should be forked-off and started.
Any sockets for "orphaned" ports no longer in any ring will be closed Any sockets for "orphaned" ports no longer in any ring will be closed
(causing their associated workers to gracefully exit) after all new (causing their associated workers to gracefully exit) after all new
@@ -1033,11 +933,15 @@ class ServersPerPortStrategy(StrategyBase):
""" """
self._reload_bind_ports() self._reload_bind_ports()
desired_port_index_pairs = set( desired_port_index_pairs = {
(p, i) for p in self.bind_ports (p, i) for p in self.bind_ports
for i in range(self.servers_per_port)) for i in range(self.servers_per_port)}
current_port_index_pairs = self.port_pid_state.port_index_pairs() current_port_index_pairs = {
(p, i)
for p, port_data in self.tracking_data.items()
for i, (pid, sock) in enumerate(port_data)
if pid is not None}
if desired_port_index_pairs != current_port_index_pairs: if desired_port_index_pairs != current_port_index_pairs:
# Orphan ports are ports which had object-server processes running, # Orphan ports are ports which had object-server processes running,
@@ -1046,36 +950,44 @@ class ServersPerPortStrategy(StrategyBase):
orphan_port_index_pairs = current_port_index_pairs - \ orphan_port_index_pairs = current_port_index_pairs - \
desired_port_index_pairs desired_port_index_pairs
# Fork off worker(s) for every port who's supposed to have # Fork off worker(s) for every port that's supposed to have
# worker(s) but doesn't # worker(s) but doesn't
missing_port_index_pairs = desired_port_index_pairs - \ missing_port_index_pairs = desired_port_index_pairs - \
current_port_index_pairs current_port_index_pairs
for port, server_idx in sorted(missing_port_index_pairs): for port, server_idx in sorted(missing_port_index_pairs):
if self.port_pid_state.not_tracking(port): try:
try: sock = self._bind_port(port)
self._bind_port(port) except Exception as e:
except Exception as e: self.logger.critical('Unable to bind to port %d: %s',
self.logger.critical('Unable to bind to port %d: %s', port, e)
port, e) continue
continue yield sock, (port, server_idx)
yield self.port_pid_state.sock_for_port(port), server_idx
for orphan_pair in orphan_port_index_pairs: for port, idx in orphan_port_index_pairs:
# For any port in orphan_port_index_pairs, it is guaranteed # For any port in orphan_port_index_pairs, it is guaranteed
# that there should be no listen socket for that port, so we # that there should be no listen socket for that port, so we
# can close and forget them. # can close and forget them.
self.port_pid_state.forget_port(orphan_pair[0]) pid, sock = self.tracking_data[port][idx]
greenio.shutdown_safe(sock)
sock.close()
self.logger.notice(
'Closing unnecessary sock for port %d (child pid %d)',
port, pid)
self.tracking_data[port][idx] = (None, None)
if all(sock is None
for _pid, sock in self.tracking_data[port]):
del self.tracking_data[port]
def log_sock_exit(self, sock, server_idx): def log_sock_exit(self, sock, data):
""" """
Log a server's exit. Log a server's exit.
""" """
port = self.port_pid_state.port_for_sock(sock) port, server_idx = data
self.logger.notice('Child %d (PID %d, port %d) exiting normally', self.logger.notice('Child %d (PID %d, port %d) exiting normally',
server_idx, os.getpid(), port) server_idx, os.getpid(), port)
def register_worker_start(self, sock, server_idx, pid): def register_worker_start(self, sock, data, pid):
""" """
Called when a new worker is started. Called when a new worker is started.
@@ -1085,10 +997,12 @@ class ServersPerPortStrategy(StrategyBase):
:param int pid: The new worker process' PID :param int pid: The new worker process' PID
""" """
port = self.port_pid_state.port_for_sock(sock) port, server_idx = data
self.logger.notice('Started child %d (PID %d) for port %d', self.logger.notice('Started child %d (PID %d) for port %d',
server_idx, pid, port) server_idx, pid, port)
self.port_pid_state.add_pid(port, server_idx, pid) if port not in self.tracking_data:
self.tracking_data[port] = [(None, None)] * self.servers_per_port
self.tracking_data[port][server_idx] = (pid, sock)
def register_worker_exit(self, pid): def register_worker_exit(self, pid):
""" """
@@ -1097,15 +1011,22 @@ class ServersPerPortStrategy(StrategyBase):
:param int pid: The PID of the worker that exited. :param int pid: The PID of the worker that exited.
""" """
self.port_pid_state.forget_pid(pid) for port_data in self.tracking_data.values():
for idx, (child_pid, sock) in enumerate(port_data):
if child_pid == pid:
port_data[idx] = (None, None)
greenio.shutdown_safe(sock)
sock.close()
return
def iter_sockets(self): def iter_sockets(self):
""" """
Yields all known listen sockets. Yields all known listen sockets.
""" """
for sock in self.port_pid_state.all_socks(): for port_data in self.tracking_data.values():
yield sock for _pid, sock in port_data:
yield sock
def run_wsgi(conf_path, app_section, *args, **kwargs): def run_wsgi(conf_path, app_section, *args, **kwargs):
@@ -1140,6 +1061,15 @@ def run_wsgi(conf_path, app_section, *args, **kwargs):
conf, logger, servers_per_port=servers_per_port) conf, logger, servers_per_port=servers_per_port)
else: else:
strategy = WorkersStrategy(conf, logger) strategy = WorkersStrategy(conf, logger)
try:
# Quick sanity check
int(conf['bind_port'])
except (ValueError, KeyError, TypeError):
error_msg = 'bind_port wasn\'t properly set in the config file. ' \
'It must be explicitly set to a valid port number.'
logger.error(error_msg)
print(error_msg)
return 1
# patch event before loadapp # patch event before loadapp
utils.eventlet_monkey_patch() utils.eventlet_monkey_patch()
@@ -1154,35 +1084,14 @@ def run_wsgi(conf_path, app_section, *args, **kwargs):
utils.FALLOCATE_RESERVE, utils.FALLOCATE_IS_PERCENT = \ utils.FALLOCATE_RESERVE, utils.FALLOCATE_IS_PERCENT = \
utils.config_fallocate_value(conf.get('fallocate_reserve', '1%')) utils.config_fallocate_value(conf.get('fallocate_reserve', '1%'))
# Start listening on bind_addr/port
error_msg = strategy.do_bind_ports()
if error_msg:
logger.error(error_msg)
print(error_msg)
return 1
# Do some daemonization process hygene before we fork any children or run a # Do some daemonization process hygene before we fork any children or run a
# server without forking. # server without forking.
clean_up_daemon_hygiene() clean_up_daemon_hygiene()
# Redirect errors to logger and close stdio. Do this *after* binding ports;
# we use this to signal that the service is ready to accept connections.
capture_stdio(logger)
# If necessary, signal an old copy of us that it's okay to shutdown its
# listen sockets now because ours are up and ready to receive connections.
reexec_signal_fd = os.getenv(NOTIFY_FD_ENV_KEY)
if reexec_signal_fd:
reexec_signal_fd = int(reexec_signal_fd)
os.write(reexec_signal_fd, str(os.getpid()).encode('utf8'))
os.close(reexec_signal_fd)
# Finally, signal systemd (if appropriate) that process started properly.
systemd_notify(logger=logger)
no_fork_sock = strategy.no_fork_sock() no_fork_sock = strategy.no_fork_sock()
if no_fork_sock: if no_fork_sock:
run_server(conf, logger, no_fork_sock, global_conf=global_conf) run_server(conf, logger, no_fork_sock, global_conf=global_conf,
ready_callback=strategy.signal_ready)
return 0 return 0
def stop_with_signal(signum, *args): def stop_with_signal(signum, *args):
@@ -1198,17 +1107,38 @@ def run_wsgi(conf_path, app_section, *args, **kwargs):
while running_context[0]: while running_context[0]:
for sock, sock_info in strategy.new_worker_socks(): for sock, sock_info in strategy.new_worker_socks():
read_fd, write_fd = os.pipe()
pid = os.fork() pid = os.fork()
if pid == 0: if pid == 0:
os.close(read_fd)
signal.signal(signal.SIGHUP, signal.SIG_DFL) signal.signal(signal.SIGHUP, signal.SIG_DFL)
signal.signal(signal.SIGTERM, signal.SIG_DFL) signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGUSR1, signal.SIG_DFL) signal.signal(signal.SIGUSR1, signal.SIG_DFL)
strategy.post_fork_hook() strategy.post_fork_hook()
run_server(conf, logger, sock)
def notify():
os.write(write_fd, b'ready')
os.close(write_fd)
run_server(conf, logger, sock, ready_callback=notify)
strategy.log_sock_exit(sock, sock_info) strategy.log_sock_exit(sock, sock_info)
return 0 return 0
else: else:
strategy.register_worker_start(sock, sock_info, pid) os.close(write_fd)
worker_status = os.read(read_fd, 30)
os.close(read_fd)
# TODO: delay this status checking until after we've tried
# to start all workers. But, we currently use the register
# event to know when we've got enough workers :-/
if worker_status == b'ready':
strategy.register_worker_start(sock, sock_info, pid)
else:
raise Exception(
'worker did not start normally: %r' % worker_status)
# TODO: signal_ready() as soon as we have at least one new worker for
# each port, instead of waiting for all of them
strategy.signal_ready()
# The strategy may need to pay attention to something in addition to # The strategy may need to pay attention to something in addition to
# child process exits (like new ports showing up in a ring). # child process exits (like new ports showing up in a ring).

View File

@@ -778,7 +778,7 @@ class TestWSGI(unittest.TestCase):
def _initrp(conf_file, app_section, *args, **kwargs): def _initrp(conf_file, app_section, *args, **kwargs):
return ( return (
{'__file__': 'test', 'workers': 0}, {'__file__': 'test', 'workers': 0, 'bind_port': 12345},
'logger', 'logger',
'log_name') 'log_name')
@@ -788,7 +788,8 @@ class TestWSGI(unittest.TestCase):
def _global_conf_callback(preloaded_app_conf, global_conf): def _global_conf_callback(preloaded_app_conf, global_conf):
calls['_global_conf_callback'] += 1 calls['_global_conf_callback'] += 1
self.assertEqual( self.assertEqual(
preloaded_app_conf, {'__file__': 'test', 'workers': 0}) preloaded_app_conf,
{'__file__': 'test', 'workers': 0, 'bind_port': 12345})
self.assertEqual(global_conf, {'log_name': 'log_name'}) self.assertEqual(global_conf, {'log_name': 'log_name'})
global_conf['test1'] = to_inject global_conf['test1'] = to_inject
@@ -827,7 +828,7 @@ class TestWSGI(unittest.TestCase):
def _initrp(conf_file, app_section, *args, **kwargs): def _initrp(conf_file, app_section, *args, **kwargs):
calls['_initrp'] += 1 calls['_initrp'] += 1
return ( return (
{'__file__': 'test', 'workers': 0}, {'__file__': 'test', 'workers': 0, 'bind_port': 12345},
'logger', 'logger',
'log_name') 'log_name')
@@ -862,11 +863,17 @@ class TestWSGI(unittest.TestCase):
mock_run_server): mock_run_server):
# Make sure the right strategy gets used in a number of different # Make sure the right strategy gets used in a number of different
# config cases. # config cases.
mock_per_port().do_bind_ports.return_value = 'stop early'
mock_workers().do_bind_ports.return_value = 'stop early' class StopAtCreatingSockets(Exception):
'''Dummy exception to make sure we don't actually bind ports'''
mock_per_port().no_fork_sock.return_value = None
mock_per_port().new_worker_socks.side_effect = StopAtCreatingSockets
mock_workers().no_fork_sock.return_value = None
mock_workers().new_worker_socks.side_effect = StopAtCreatingSockets
logger = FakeLogger() logger = FakeLogger()
stub__initrp = [ stub__initrp = [
{'__file__': 'test', 'workers': 2}, # conf {'__file__': 'test', 'workers': 2, 'bind_port': 12345}, # conf
logger, logger,
'log_name', 'log_name',
] ]
@@ -878,14 +885,13 @@ class TestWSGI(unittest.TestCase):
mock_per_port.reset_mock() mock_per_port.reset_mock()
mock_workers.reset_mock() mock_workers.reset_mock()
logger._clear() logger._clear()
self.assertEqual(1, wsgi.run_wsgi('conf_file', server_type)) with self.assertRaises(StopAtCreatingSockets):
self.assertEqual([ wsgi.run_wsgi('conf_file', server_type)
'stop early',
], logger.get_lines_for_level('error'))
self.assertEqual([], mock_per_port.mock_calls) self.assertEqual([], mock_per_port.mock_calls)
self.assertEqual([ self.assertEqual([
mock.call(stub__initrp[0], logger), mock.call(stub__initrp[0], logger),
mock.call().do_bind_ports(), mock.call().no_fork_sock(),
mock.call().new_worker_socks(),
], mock_workers.mock_calls) ], mock_workers.mock_calls)
stub__initrp[0]['servers_per_port'] = 3 stub__initrp[0]['servers_per_port'] = 3
@@ -893,26 +899,24 @@ class TestWSGI(unittest.TestCase):
mock_per_port.reset_mock() mock_per_port.reset_mock()
mock_workers.reset_mock() mock_workers.reset_mock()
logger._clear() logger._clear()
self.assertEqual(1, wsgi.run_wsgi('conf_file', server_type)) with self.assertRaises(StopAtCreatingSockets):
self.assertEqual([ wsgi.run_wsgi('conf_file', server_type)
'stop early',
], logger.get_lines_for_level('error'))
self.assertEqual([], mock_per_port.mock_calls) self.assertEqual([], mock_per_port.mock_calls)
self.assertEqual([ self.assertEqual([
mock.call(stub__initrp[0], logger), mock.call(stub__initrp[0], logger),
mock.call().do_bind_ports(), mock.call().no_fork_sock(),
mock.call().new_worker_socks(),
], mock_workers.mock_calls) ], mock_workers.mock_calls)
mock_per_port.reset_mock() mock_per_port.reset_mock()
mock_workers.reset_mock() mock_workers.reset_mock()
logger._clear() logger._clear()
self.assertEqual(1, wsgi.run_wsgi('conf_file', 'object-server')) with self.assertRaises(StopAtCreatingSockets):
self.assertEqual([ wsgi.run_wsgi('conf_file', 'object-server')
'stop early',
], logger.get_lines_for_level('error'))
self.assertEqual([ self.assertEqual([
mock.call(stub__initrp[0], logger, servers_per_port=3), mock.call(stub__initrp[0], logger, servers_per_port=3),
mock.call().do_bind_ports(), mock.call().no_fork_sock(),
mock.call().new_worker_socks(),
], mock_per_port.mock_calls) ], mock_per_port.mock_calls)
self.assertEqual([], mock_workers.mock_calls) self.assertEqual([], mock_workers.mock_calls)
@@ -1331,12 +1335,16 @@ class TestProxyProtocol(ProtocolTest):
class CommonTestMixin(object): class CommonTestMixin(object):
def test_post_fork_hook(self): @mock.patch('swift.common.wsgi.capture_stdio')
def test_post_fork_hook(self, mock_capture):
self.strategy.post_fork_hook() self.strategy.post_fork_hook()
self.assertEqual([ self.assertEqual([
mock.call('bob'), mock.call('bob'),
], self.mock_drop_privileges.mock_calls) ], self.mock_drop_privileges.mock_calls)
self.assertEqual([
mock.call(self.logger),
], mock_capture.mock_calls)
class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin): class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin):
@@ -1350,9 +1358,9 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin):
'bind_ip': '2.3.4.5', 'bind_ip': '2.3.4.5',
} }
self.servers_per_port = 3 self.servers_per_port = 3
self.s1, self.s2 = mock.MagicMock(), mock.MagicMock() self.sockets = [mock.MagicMock() for _ in range(6)]
patcher = mock.patch('swift.common.wsgi.get_socket', patcher = mock.patch('swift.common.wsgi.get_socket',
side_effect=[self.s1, self.s2]) side_effect=self.sockets)
self.mock_get_socket = patcher.start() self.mock_get_socket = patcher.start()
self.addCleanup(patcher.stop) self.addCleanup(patcher.stop)
patcher = mock.patch('swift.common.wsgi.drop_privileges') patcher = mock.patch('swift.common.wsgi.drop_privileges')
@@ -1391,39 +1399,10 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin):
self.assertEqual(15, self.strategy.loop_timeout()) self.assertEqual(15, self.strategy.loop_timeout())
def test_bind_ports(self):
self.strategy.do_bind_ports()
self.assertEqual(set((6006, 6007)), self.strategy.bind_ports)
self.assertEqual([
mock.call({'workers': 100, # ignored
'user': 'bob',
'swift_dir': '/jim/cricket',
'ring_check_interval': '76',
'bind_ip': '2.3.4.5',
'bind_port': 6006}),
mock.call({'workers': 100, # ignored
'user': 'bob',
'swift_dir': '/jim/cricket',
'ring_check_interval': '76',
'bind_ip': '2.3.4.5',
'bind_port': 6007}),
], self.mock_get_socket.mock_calls)
self.assertEqual(
6006, self.strategy.port_pid_state.port_for_sock(self.s1))
self.assertEqual(
6007, self.strategy.port_pid_state.port_for_sock(self.s2))
# strategy binding no longer does clean_up_deemon_hygene() actions, the
# user of the strategy does.
self.assertEqual([], self.mock_setsid.mock_calls)
self.assertEqual([], self.mock_chdir.mock_calls)
self.assertEqual([], self.mock_umask.mock_calls)
def test_no_fork_sock(self): def test_no_fork_sock(self):
self.assertIsNone(self.strategy.no_fork_sock()) self.assertIsNone(self.strategy.no_fork_sock())
def test_new_worker_socks(self): def test_new_worker_socks(self):
self.strategy.do_bind_ports()
self.all_bind_ports_for_node.reset_mock() self.all_bind_ports_for_node.reset_mock()
pid = 88 pid = 88
@@ -1434,8 +1413,12 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin):
pid += 1 pid += 1
self.assertEqual([ self.assertEqual([
(self.s1, 0), (self.s1, 1), (self.s1, 2), (self.sockets[0], (6006, 0)),
(self.s2, 0), (self.s2, 1), (self.s2, 2), (self.sockets[1], (6006, 1)),
(self.sockets[2], (6006, 2)),
(self.sockets[3], (6007, 0)),
(self.sockets[4], (6007, 1)),
(self.sockets[5], (6007, 2)),
], got_si) ], got_si)
self.assertEqual([ self.assertEqual([
'Started child %d (PID %d) for port %d' % (0, 88, 6006), 'Started child %d (PID %d) for port %d' % (0, 88, 6006),
@@ -1454,8 +1437,8 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin):
# Get rid of servers for ports which disappear from the ring # Get rid of servers for ports which disappear from the ring
self.ports = (6007,) self.ports = (6007,)
self.all_bind_ports_for_node.return_value = set(self.ports) self.all_bind_ports_for_node.return_value = set(self.ports)
self.s1.reset_mock() for s in self.sockets:
self.s2.reset_mock() s.reset_mock()
with mock.patch('swift.common.wsgi.greenio') as mock_greenio: with mock.patch('swift.common.wsgi.greenio') as mock_greenio:
self.assertEqual([], list(self.strategy.new_worker_socks())) self.assertEqual([], list(self.strategy.new_worker_socks()))
@@ -1464,23 +1447,28 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin):
mock.call(), # ring_check_interval has passed... mock.call(), # ring_check_interval has passed...
], self.all_bind_ports_for_node.mock_calls) ], self.all_bind_ports_for_node.mock_calls)
self.assertEqual([ self.assertEqual([
mock.call.shutdown_safe(self.s1), [mock.call.close()]
], mock_greenio.mock_calls) for _ in range(3)
], [s.mock_calls for s in self.sockets[:3]])
self.assertEqual({
('shutdown_safe', (self.sockets[0],)),
('shutdown_safe', (self.sockets[1],)),
('shutdown_safe', (self.sockets[2],)),
}, {call[:2] for call in mock_greenio.mock_calls})
self.assertEqual([ self.assertEqual([
mock.call.close(), [] for _ in range(3)
], self.s1.mock_calls) ], [s.mock_calls for s in self.sockets[3:]]) # not closed
self.assertEqual([], self.s2.mock_calls) # not closed self.assertEqual({
self.assertEqual([ 'Closing unnecessary sock for port %d (child pid %d)' % (6006, p)
'Closing unnecessary sock for port %d' % 6006, for p in range(88, 91)
], self.logger.get_lines_for_level('notice')) }, set(self.logger.get_lines_for_level('notice')))
self.logger._clear() self.logger._clear()
# Create new socket & workers for new ports that appear in ring # Create new socket & workers for new ports that appear in ring
self.ports = (6007, 6009) self.ports = (6007, 6009)
self.all_bind_ports_for_node.return_value = set(self.ports) self.all_bind_ports_for_node.return_value = set(self.ports)
self.s1.reset_mock() for s in self.sockets:
self.s2.reset_mock() s.reset_mock()
s3 = mock.MagicMock()
self.mock_get_socket.side_effect = Exception('ack') self.mock_get_socket.side_effect = Exception('ack')
# But first make sure we handle failure to bind to the requested port! # But first make sure we handle failure to bind to the requested port!
@@ -1499,7 +1487,8 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin):
self.logger._clear() self.logger._clear()
# Will keep trying, so let it succeed again # Will keep trying, so let it succeed again
self.mock_get_socket.side_effect = [s3] new_sockets = self.mock_get_socket.side_effect = [
mock.MagicMock() for _ in range(3)]
got_si = [] got_si = []
for s, i in self.strategy.new_worker_socks(): for s, i in self.strategy.new_worker_socks():
@@ -1508,7 +1497,7 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin):
pid += 1 pid += 1
self.assertEqual([ self.assertEqual([
(s3, 0), (s3, 1), (s3, 2), (s, (6009, i)) for i, s in enumerate(new_sockets)
], got_si) ], got_si)
self.assertEqual([ self.assertEqual([
'Started child %d (PID %d) for port %d' % (0, 94, 6009), 'Started child %d (PID %d) for port %d' % (0, 94, 6009),
@@ -1524,6 +1513,11 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin):
# Restart a guy who died on us # Restart a guy who died on us
self.strategy.register_worker_exit(95) # server_idx == 1 self.strategy.register_worker_exit(95) # server_idx == 1
# TODO: check that the socket got cleaned up
new_socket = mock.MagicMock()
self.mock_get_socket.side_effect = [new_socket]
got_si = [] got_si = []
for s, i in self.strategy.new_worker_socks(): for s, i in self.strategy.new_worker_socks():
got_si.append((s, i)) got_si.append((s, i))
@@ -1531,7 +1525,7 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin):
pid += 1 pid += 1
self.assertEqual([ self.assertEqual([
(s3, 1), (new_socket, (6009, 1)),
], got_si) ], got_si)
self.assertEqual([ self.assertEqual([
'Started child %d (PID %d) for port %d' % (1, 97, 6009), 'Started child %d (PID %d) for port %d' % (1, 97, 6009),
@@ -1539,7 +1533,7 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin):
self.logger._clear() self.logger._clear()
# Check log_sock_exit # Check log_sock_exit
self.strategy.log_sock_exit(self.s2, 2) self.strategy.log_sock_exit(self.sockets[5], (6007, 2))
self.assertEqual([ self.assertEqual([
'Child %d (PID %d, port %d) exiting normally' % ( 'Child %d (PID %d, port %d) exiting normally' % (
2, os.getpid(), 6007), 2, os.getpid(), 6007),
@@ -1551,21 +1545,22 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin):
self.assertIsNone(self.strategy.register_worker_exit(89)) self.assertIsNone(self.strategy.register_worker_exit(89))
def test_shutdown_sockets(self): def test_shutdown_sockets(self):
self.strategy.do_bind_ports() pid = 88
for s, i in self.strategy.new_worker_socks():
self.strategy.register_worker_start(s, i, pid)
pid += 1
with mock.patch('swift.common.wsgi.greenio') as mock_greenio: with mock.patch('swift.common.wsgi.greenio') as mock_greenio:
self.strategy.shutdown_sockets() self.strategy.shutdown_sockets()
self.assertEqual([ self.assertEqual([
mock.call.shutdown_safe(self.s1), mock.call.shutdown_safe(s)
mock.call.shutdown_safe(self.s2), for s in self.sockets
], mock_greenio.mock_calls) ], mock_greenio.mock_calls)
self.assertEqual([ self.assertEqual([
mock.call.close(), [mock.call.close()]
], self.s1.mock_calls) for _ in range(3)
self.assertEqual([ ], [s.mock_calls for s in self.sockets[:3]])
mock.call.close(),
], self.s2.mock_calls)
class TestWorkersStrategy(unittest.TestCase, CommonTestMixin): class TestWorkersStrategy(unittest.TestCase, CommonTestMixin):
@@ -1576,8 +1571,9 @@ class TestWorkersStrategy(unittest.TestCase, CommonTestMixin):
'user': 'bob', 'user': 'bob',
} }
self.strategy = wsgi.WorkersStrategy(self.conf, self.logger) self.strategy = wsgi.WorkersStrategy(self.conf, self.logger)
self.mock_socket = mock.Mock()
patcher = mock.patch('swift.common.wsgi.get_socket', patcher = mock.patch('swift.common.wsgi.get_socket',
return_value='abc') return_value=self.mock_socket)
self.mock_get_socket = patcher.start() self.mock_get_socket = patcher.start()
self.addCleanup(patcher.stop) self.addCleanup(patcher.stop)
patcher = mock.patch('swift.common.wsgi.drop_privileges') patcher = mock.patch('swift.common.wsgi.drop_privileges')
@@ -1593,41 +1589,19 @@ class TestWorkersStrategy(unittest.TestCase, CommonTestMixin):
# gets checked). # gets checked).
self.assertEqual(0.5, self.strategy.loop_timeout()) self.assertEqual(0.5, self.strategy.loop_timeout())
def test_binding(self):
self.assertIsNone(self.strategy.do_bind_ports())
self.assertEqual('abc', self.strategy.sock)
self.assertEqual([
mock.call(self.conf),
], self.mock_get_socket.mock_calls)
# strategy binding no longer drops privileges nor does
# clean_up_deemon_hygene() actions.
self.assertEqual([], self.mock_drop_privileges.mock_calls)
self.assertEqual([], self.mock_clean_up_daemon_hygene.mock_calls)
self.mock_get_socket.side_effect = wsgi.ConfigFilePortError()
self.assertEqual(
'bind_port wasn\'t properly set in the config file. '
'It must be explicitly set to a valid port number.',
self.strategy.do_bind_ports())
def test_no_fork_sock(self): def test_no_fork_sock(self):
self.strategy.do_bind_ports()
self.assertIsNone(self.strategy.no_fork_sock()) self.assertIsNone(self.strategy.no_fork_sock())
self.conf['workers'] = 0 self.conf['workers'] = 0
self.strategy = wsgi.WorkersStrategy(self.conf, self.logger) self.strategy = wsgi.WorkersStrategy(self.conf, self.logger)
self.strategy.do_bind_ports()
self.assertEqual('abc', self.strategy.no_fork_sock()) self.assertIs(self.mock_socket, self.strategy.no_fork_sock())
def test_new_worker_socks(self): def test_new_worker_socks(self):
self.strategy.do_bind_ports()
pid = 88 pid = 88
sock_count = 0 sock_count = 0
for s, i in self.strategy.new_worker_socks(): for s, i in self.strategy.new_worker_socks():
self.assertEqual('abc', s) self.assertEqual(self.mock_socket, s)
self.assertIsNone(i) # unused for this strategy self.assertIsNone(i) # unused for this strategy
self.strategy.register_worker_start(s, 'unused', pid) self.strategy.register_worker_start(s, 'unused', pid)
pid += 1 pid += 1
@@ -1650,7 +1624,7 @@ class TestWorkersStrategy(unittest.TestCase, CommonTestMixin):
], self.logger.get_lines_for_level('error')) ], self.logger.get_lines_for_level('error'))
for s, i in self.strategy.new_worker_socks(): for s, i in self.strategy.new_worker_socks():
self.assertEqual('abc', s) self.assertEqual(self.mock_socket, s)
self.assertIsNone(i) # unused for this strategy self.assertIsNone(i) # unused for this strategy
self.strategy.register_worker_start(s, 'unused', pid) self.strategy.register_worker_start(s, 'unused', pid)
pid += 1 pid += 1
@@ -1664,23 +1638,23 @@ class TestWorkersStrategy(unittest.TestCase, CommonTestMixin):
], self.logger.get_lines_for_level('notice')) ], self.logger.get_lines_for_level('notice'))
def test_shutdown_sockets(self): def test_shutdown_sockets(self):
self.mock_get_socket.return_value = mock.MagicMock() self.mock_get_socket.side_effect = sockets = [
self.strategy.do_bind_ports() mock.MagicMock(), mock.MagicMock()]
pid = 88
for s, i in self.strategy.new_worker_socks():
self.strategy.register_worker_start(s, 'unused', pid)
pid += 1
with mock.patch('swift.common.wsgi.greenio') as mock_greenio: with mock.patch('swift.common.wsgi.greenio') as mock_greenio:
self.strategy.shutdown_sockets() self.strategy.shutdown_sockets()
self.assertEqual([ self.assertEqual([
mock.call.shutdown_safe(self.mock_get_socket.return_value), mock.call.shutdown_safe(s)
for s in sockets
], mock_greenio.mock_calls) ], mock_greenio.mock_calls)
if six.PY2: self.assertEqual([
self.assertEqual([ [mock.call.close()] for _ in range(2)
mock.call.__nonzero__(), ], [s.mock_calls for s in sockets])
mock.call.close(),
], self.mock_get_socket.return_value.mock_calls)
else:
self.assertEqual([
mock.call.__bool__(),
mock.call.close(),
], self.mock_get_socket.return_value.mock_calls)
def test_log_sock_exit(self): def test_log_sock_exit(self):
self.strategy.log_sock_exit('blahblah', 'blahblah') self.strategy.log_sock_exit('blahblah', 'blahblah')