568 lines
20 KiB
Python
568 lines
20 KiB
Python
"""Zookeeper Protocol Connection Handler"""
|
|
import errno
|
|
import logging
|
|
import random
|
|
import select
|
|
import socket
|
|
import time
|
|
from contextlib import contextmanager
|
|
|
|
from kazoo.exceptions import (
|
|
AuthFailedError,
|
|
ConnectionDropped,
|
|
EXCEPTIONS,
|
|
SessionExpiredError,
|
|
NoNodeError
|
|
)
|
|
from kazoo.protocol.serialization import (
|
|
Auth,
|
|
Close,
|
|
Connect,
|
|
Exists,
|
|
GetChildren,
|
|
Ping,
|
|
ReplyHeader,
|
|
Transaction,
|
|
Watch,
|
|
int_struct
|
|
)
|
|
from kazoo.protocol.states import (
|
|
Callback,
|
|
KeeperState,
|
|
WatchedEvent,
|
|
EVENT_TYPE_MAP,
|
|
)
|
|
|
|
# Module-level logger shared by every handler in this file.
log = logging.getLogger(__name__)


# Watch event type codes compared against ``watch.type`` when a watch
# notification is read off the connection.
CREATED_EVENT = 1
DELETED_EVENT = 2
CHANGED_EVENT = 3
CHILD_EVENT = 4

# Reserved transaction ids. The reader dispatches on these values in the
# reply header instead of matching the reply to a pending request.
WATCH_XID = -1
PING_XID = -2
AUTH_XID = -4
|
|
|
|
|
|
@contextmanager
def socket_error_handling():
    """Translate low-level socket/select failures into ConnectionDropped.

    Wrap any socket operation in ``with socket_error_handling():``. A
    ``socket.error`` or ``select.error`` raised inside the block is
    re-raised as :class:`ConnectionDropped`; every other exception passes
    through untouched.

    :raises ConnectionDropped: when the wrapped code raises
        ``socket.error`` or ``select.error``. The message carries the
        symbolic errno name (e.g. ``ECONNREFUSED``) when one is known,
        otherwise the text of the original exception.
    """
    try:
        yield
    except (socket.error, select.error) as e:
        # Use e.args rather than indexing the exception itself, and fall
        # back to the exception text when the first argument is not a
        # known errno (e.g. a timeout carries a plain message).
        code = e.args[0] if e.args else None
        reason = errno.errorcode.get(code, e)
        # Interpolate here: exception constructors do not apply
        # %-formatting to their arguments.
        raise ConnectionDropped("socket connection error: %s" % (reason,))
|
|
|
|
|
|
class RWPinger(object):
    """A Read/Write Server Pinger Iterable

    This object is initialized with the hosts iterable and the socket
    creation function. Each time `next` is called on its iterator it
    yields either False, or a ``(host, port)`` tuple if it found a r/w
    capable Zookeeper node.

    A delay (starting at half a second plus random jitter, doubled after
    every full pass over the hosts) is enforced before each pass. The
    delay is tracked internally and the iterator yields False if it is
    called too soon.

    """
    def __init__(self, hosts, socket_func):
        """Store the hosts to probe and the socket factory.

        :param hosts: iterable of ``(host, port)`` pairs.
        :param socket_func: zero-argument callable returning a new
            socket-like object.
        """
        self.hosts = hosts
        self.socket = socket_func
        # Timestamp of the most recent probe; None until first iterated.
        self.last_attempt = None

    def __iter__(self):
        """Yield False or a r/w ``(host, port)`` tuple, indefinitely."""
        if not self.last_attempt:
            self.last_attempt = time.time()
        delay = 0.5
        while True:
            jitter = random.randint(0, 100) / 100.0
            while time.time() < self.last_attempt + delay + jitter:
                # Skip rw ping checks if its too soon
                yield False
            for host, port in self.hosts:
                log.debug("Pinging server for r/w: %s:%s", host, port)
                self.last_attempt = time.time()
                yield self._ping(host, port)

                # Add some jitter between host pings
                while time.time() < self.last_attempt + jitter:
                    yield False
            delay *= 2

    def _ping(self, host, port):
        """Probe one server; return ``(host, port)`` if r/w, else False."""
        sock = self.socket()
        try:
            with socket_error_handling():
                sock.connect((host, port))
                # Bytes literals: identical to str on Python 2 and the
                # only form sockets accept/return on Python 3.
                sock.sendall(b"isro")
                result = sock.recv(8192)
        except ConnectionDropped:
            return False
        finally:
            # Always release the probe socket, including on failure.
            sock.close()
        if result == b"rw":
            return (host, port)
        return False
|
|
|
|
|
|
class RWServerAvailable(Exception):
    """Raised when a read/write server becomes available.

    Used as a control-flow signal: it is raised while connected in
    read-only mode once a r/w capable server has been found.
    """
|
|
|
|
|
|
class ConnectionHandler(object):
    """Zookeeper connection handler

    Owns the socket to a single Zookeeper server and runs two loops via
    ``client.handler.spawn``: a writer (:meth:`writer`) that connects,
    drains the client's request queue and sends pings, and a reader
    (:meth:`reader`) that parses replies and dispatches them.
    """
    def __init__(self, client, retry_sleeper, log_debug=False):
        """Bind the handler to a client.

        :param client: owning client; supplies ``handler``, the request
            queues, session data and the state lock used below.
        :param retry_sleeper: object providing ``copy()``; the copy must
            offer ``reset()`` and ``increment()``.
        :param log_debug: when true, emit verbose debug logging.
        """
        self.client = client
        self.handler = client.handler
        self.retry_sleeper = retry_sleeper

        # Our event objects
        self.reader_started = client.handler.event_object()
        self.reader_done = client.handler.event_object()
        self.writer_stopped = client.handler.event_object()
        # No writer is running yet, so it counts as stopped.
        self.writer_stopped.set()

        self.log_debug = log_debug
        self._socket = None
        self._xid = None
        self._rw_server = None
        self._ro_mode = False

    def start(self):
        """Start the connection up"""
        self.handler.spawn(self.writer)

    def stop(self, timeout=None):
        """Ensure the writer has stopped, wait to see if it does

        :param timeout: maximum time to wait, passed to the event's
            ``wait``; None waits indefinitely.
        :returns: True if the writer has stopped, False otherwise.
        """
        self.writer_stopped.wait(timeout)
        return self.writer_stopped.is_set()

    def _server_pinger(self):
        """Returns a server pinger iterable, that will ping the next
        server in the list, and apply a back-off between attempts"""
        return RWPinger(self.client.hosts, self.handler.socket)

    def _read_header(self, timeout):
        """Read one length-prefixed reply and parse its header.

        :returns: ``(header, payload, offset)`` where ``offset`` is the
            position in ``payload`` just past the reply header.
        """
        b = self._read(4, timeout)
        length = int_struct.unpack(b)[0]
        b = self._read(length, timeout)
        header, offset = ReplyHeader.deserialize(b, 0)
        return header, b, offset

    def _read(self, length, timeout):
        """Read exactly ``length`` bytes from the socket.

        :raises ConnectionDropped: on socket errors, when the peer
            closes the connection, or when the socket does not become
            readable within ``timeout``.
        """
        msgparts = []
        remaining = length
        with socket_error_handling():
            while remaining > 0:
                ready = self.handler.select(
                    [self._socket], [], [], timeout)[0]
                if not ready:
                    # Presumably handler.select mirrors select.select and
                    # returns empty lists on time-out; without this check
                    # the indexing below raised a bare IndexError.
                    raise ConnectionDropped('socket time-out')
                chunk = ready[0].recv(remaining)
                if not chunk:
                    raise ConnectionDropped('socket connection broken')
                msgparts.append(chunk)
                remaining -= len(chunk)
            return b"".join(msgparts)

    def _invoke(self, timeout, request, xid=None):
        """A special writer used during connection establishment
        only

        Sends ``request`` and synchronously reads its reply.

        :returns: with a truthy ``xid``, the reply's zxid (or None when
            it is not positive); otherwise ``(obj, None)`` when the
            request has a ``deserialize`` attribute, else None.
        :raises RuntimeError: if the reply xid differs from ``xid``.
        :raises Exception: the ``EXCEPTIONS`` entry for a reply error.
        """
        b = bytearray()
        if xid:
            b.extend(int_struct.pack(xid))
        if request.type:
            b.extend(int_struct.pack(request.type))
        b.extend(request.serialize())

        if self.log_debug:
            log.debug("Sending request: %s", request)
        self._write(int_struct.pack(len(b)) + b, timeout)

        zxid = None
        if xid:
            header, buffer, offset = self._read_header(timeout)
            if header.xid != xid:
                raise RuntimeError(
                    'xids do not match, expected %r received %r'
                    % (xid, header.xid))
            if header.zxid > 0:
                zxid = header.zxid
            if header.err:
                callback_exception = EXCEPTIONS[header.err]()
                log.debug('Received error %r', callback_exception)
                raise callback_exception
            return zxid

        # No xid: the reply has no header, only a length-prefixed body.
        msg = self._read(4, timeout)
        length = int_struct.unpack(msg)[0]
        msg = self._read(length, timeout)

        if hasattr(request, 'deserialize'):
            obj, _ = request.deserialize(msg, 0)
            log.debug('Read response %s', obj)
            return obj, zxid

        return zxid

    def _submit(self, request, timeout, xid=None):
        """Submit a request object with a timeout value and optional
        xid

        The xid is packed unconditionally, so despite the default every
        caller is expected to pass one.
        """
        b = bytearray()
        b.extend(int_struct.pack(xid))
        if request.type:
            b.extend(int_struct.pack(request.type))
        b += request.serialize()
        self._write(int_struct.pack(len(b)) + b, timeout)

    def _write(self, msg, timeout):
        """Write a raw msg to the socket

        :raises ConnectionDropped: on socket errors, when nothing could
            be sent, or when the socket does not become writable within
            ``timeout``.
        """
        sent = 0
        msg_length = len(msg)
        with socket_error_handling():
            while sent < msg_length:
                ready = self.handler.select(
                    [], [self._socket], [], timeout)[1]
                if not ready:
                    # See _read: treat a select time-out as a dropped
                    # connection instead of failing with IndexError.
                    raise ConnectionDropped('socket time-out')
                msg_slice = buffer(msg, sent)
                bytes_sent = ready[0].send(msg_slice)
                if not bytes_sent:
                    raise ConnectionDropped('socket connection broken')
                sent += bytes_sent

    def _read_watch_event(self, buffer, offset):
        """Decode a watch notification and dispatch its callbacks.

        Watchers registered for the event's path are removed from the
        client (they fire once) and handed to the handler's callback
        dispatcher. Nothing is dispatched when the client has been
        stopped or the event type is unknown.
        """
        client = self.client
        watch, offset = Watch.deserialize(buffer, offset)
        path = watch.path

        if self.log_debug:
            log.debug('Received EVENT: %s', watch)

        watchers = []
        with client._state_lock:
            # Ignore watches if we've been stopped
            if client._stopped.is_set():
                return

            if watch.type in (CREATED_EVENT, CHANGED_EVENT):
                watchers.extend(client._data_watchers.pop(path, []))
            elif watch.type == DELETED_EVENT:
                watchers.extend(client._data_watchers.pop(path, []))
                watchers.extend(client._child_watchers.pop(path, []))
            elif watch.type == CHILD_EVENT:
                watchers.extend(client._child_watchers.pop(path, []))
            else:
                log.warning('Received unknown event %r', watch.type)
                return

            # Strip the chroot if needed
            path = client.unchroot(path)
            ev = WatchedEvent(EVENT_TYPE_MAP[watch.type], client._state,
                              path)

        # Dump the watchers to the watch thread
        for watcher in watchers:
            client.handler.dispatch_callback(
                Callback('watch', watcher, (ev,)))

    def _read_response(self, header, buffer, offset):
        """Match a reply to the oldest pending request and resolve it.

        :returns: True when the reply answers a ``Close`` request (the
            socket is closed and the reader should stop), else None.
        :raises RuntimeError: if the reply xid does not match the xid
            of the pending request.
        """
        client = self.client
        request, async_object, xid = client._pending.get()
        if header.zxid and header.zxid > 0:
            client.last_zxid = header.zxid
        if header.xid != xid:
            raise RuntimeError('xids do not match, expected %r '
                               'received %r' % (xid, header.xid))

        # Determine if its an exists request and a no node error
        exists_error = (header.err == NoNodeError.code and
                        request.type == Exists.type)

        # Set the exception if its not an exists error
        if header.err and not exists_error:
            callback_exception = EXCEPTIONS[header.err]()
            if self.log_debug:
                log.debug('Received error %r', callback_exception)
            if async_object:
                async_object.set_exception(callback_exception)
        elif request and async_object:
            if exists_error:
                # It's a NoNodeError, which is fine for an exists
                # request
                async_object.set(None)
            else:
                try:
                    response = request.deserialize(buffer, offset)
                except Exception as exc:
                    if self.log_debug:
                        log.debug("Exception raised during deserialization"
                                  " of request: %s", request)
                    log.exception(exc)
                    async_object.set_exception(exc)
                    return
                log.debug('Received response: %r', response)

                # We special case a Transaction as we have to unchroot
                # things
                if request.type == Transaction.type:
                    response = Transaction.unchroot(client, response)

                async_object.set(response)

            # Determine if watchers should be registered
            watcher = getattr(request, 'watcher', None)
            with client._state_lock:
                if not client._stopped.is_set() and watcher:
                    if isinstance(request, GetChildren):
                        client._child_watchers[request.path].append(watcher)
                    else:
                        client._data_watchers[request.path].append(watcher)

        if isinstance(request, Close):
            if self.log_debug:
                log.debug('Read close response')
            self._socket.close()
            return True

    def reader(self, read_timeout):
        """Main reader function to read off the ZK connection

        Loops reading replies and dispatching on the header xid until
        the connection drops, an auth failure is reported, an unexpected
        error occurs, or a close response is read. ``reader_done`` is
        set on exit.
        """
        self.reader_started.set()
        client = self.client
        if self.log_debug:
            log.debug("Reader started")

        while True:
            try:
                header, buffer, offset = self._read_header(read_timeout)
                if header.xid == PING_XID:
                    if self.log_debug:
                        log.debug('Received PING')
                elif header.xid == AUTH_XID:
                    if self.log_debug:
                        log.debug('Received AUTH')

                    if header.err:
                        # We go ahead and fail out the connection, mainly
                        # because thats what Zookeeper client docs think
                        # is appropriate

                        # XXX TODO: Should we fail out? Or handle auth
                        # failure differently here since the session id
                        # is actually valid!
                        with client._state_lock:
                            client._session_callback(KeeperState.AUTH_FAILED)
                        self.reader_done.set()
                        break
                elif header.xid == WATCH_XID:
                    self._read_watch_event(buffer, offset)
                else:
                    if self.log_debug:
                        log.debug('Reading for header %r', header)

                    if self._read_response(header, buffer, offset):
                        # _read_response returns True if the response
                        # indicated we should cease
                        break
            except ConnectionDropped as e:
                if self.log_debug:
                    log.debug('Connection dropped for reader: %s', e)
                break
            except Exception as e:
                log.exception(e)
                break

        self.reader_done.set()

        if self.log_debug:
            log.debug('Reader stopped')

    def writer(self):
        """Main writer function that writes to the ZK connection and
        handles other state management

        Repeats :meth:`_connect_loop` until the client is stopped or the
        loop reports (by returning False) that retrying should end.
        """
        if self.log_debug:
            log.debug('Writer started')

        self.writer_stopped.clear()
        try:
            retry = self.retry_sleeper.copy()
            while not self.client._stopped.is_set():
                # If the connect_loop returns False, stop retrying
                if self._connect_loop(retry) is False:
                    break

                # Still going, increment our retry then go through the
                # list of hosts again
                if not self.client._stopped.is_set():
                    retry.increment()
        finally:
            # Always signal, even when _connect_loop re-raises, so that
            # stop() cannot block forever on a dead writer.
            self.writer_stopped.set()
            if self.log_debug:
                log.debug('Writer stopped')

    def _connect_loop(self, retry):
        """Try each host once, serving requests on the first that works.

        :returns: False when the session was closed on request or
            authentication failed (the writer must stop); None after
            every host has been tried without a clean close.
        """
        client = self.client
        writer_done = False
        for host, port in client.hosts:
            self._socket = self.handler.socket()

            # Were we given a r/w server? If so, use that instead
            if self._rw_server:
                host, port = self._rw_server
                self._rw_server = None
                # Log after switching so the message names the r/w
                # server rather than the host it replaces.
                if self.log_debug:
                    log.debug("Found r/w server to use, %s:%s", host, port)

            if client._state != KeeperState.CONNECTING:
                with client._state_lock:
                    client._session_callback(KeeperState.CONNECTING)

            try:
                read_timeout, connect_timeout = self._connect(host, port)

                # Now that connection is good, reset the retries
                retry.reset()

                # Reset the reader events and spin it up
                self.reader_started.clear()
                self.reader_done.clear()
                self.handler.spawn(self.reader, read_timeout)
                self.reader_started.wait()
                self._xid = 0

                while not writer_done:
                    writer_done = self._send_request(read_timeout,
                                                     connect_timeout)

                if self.log_debug:
                    log.debug('Waiting for reader to read close response')
                self.reader_done.wait()
                if self.log_debug:
                    log.info('Closing connection to %s:%s', host, port)

                if writer_done:
                    with client._state_lock:
                        client._session_callback(KeeperState.CLOSED)
                    return False
            except ConnectionDropped:
                log.warning('Connection dropped')
                if client._state != KeeperState.CONNECTING:
                    with client._state_lock:
                        client._session_callback(KeeperState.CONNECTING)
            except AuthFailedError:
                log.warning('AUTH_FAILED closing')
                with client._state_lock:
                    client._session_callback(KeeperState.AUTH_FAILED)
                return False
            except SessionExpiredError:
                log.warning('Session has expired')
                with client._state_lock:
                    client._session_callback(KeeperState.EXPIRED_SESSION)
            except RWServerAvailable:
                log.warning('Found a RW server, dropping connection')
                with client._state_lock:
                    client._session_callback(KeeperState.CONNECTING)
            except Exception as e:
                log.exception(e)
                raise
            finally:
                if not writer_done:
                    # The read thread will close the socket since there
                    # could be a number of pending requests whose response
                    # still needs to be read from the socket.
                    self._socket.close()

    def _connect(self, host, port):
        """Open the socket, negotiate the session and send stored auth.

        :returns: ``(read_timeout, connect_timeout)`` derived from the
            negotiated session timeout.
        :raises SessionExpiredError: if the server answers with a
            non-positive session timeout.
        :raises ConnectionDropped: on socket failure.
        """
        client = self.client
        log.info('Connecting to %s:%s', host, port)

        if self.log_debug:
            log.debug(' Using session_id: %r session_passwd: 0x%s',
                      client._session_id,
                      client._session_passwd.encode('hex'))

        with socket_error_handling():
            self._socket.connect((host, port))

        self._socket.setblocking(0)

        connect = Connect(0, client.last_zxid, client._session_timeout,
                          client._session_id or 0, client._session_passwd,
                          client.read_only)

        connect_result, zxid = self._invoke(client._session_timeout, connect)

        if connect_result.time_out <= 0:
            raise SessionExpiredError("Session has expired")

        if zxid:
            client.last_zxid = zxid

        # Load return values
        client._session_id = connect_result.session_id
        negotiated_session_timeout = connect_result.time_out
        # NOTE(review): _send_request divides read_timeout by 2000.0,
        # which suggests these values are milliseconds, yet they are also
        # handed unchanged to select() as timeouts -- confirm the units.
        connect_timeout = negotiated_session_timeout / len(client.hosts)
        read_timeout = negotiated_session_timeout * 2.0 / 3.0
        client._session_passwd = connect_result.passwd

        if self.log_debug:
            log.debug('Session created, session_id: %r session_passwd: 0x%s\n'
                      ' negotiated session timeout: %s\n'
                      ' connect timeout: %s\n'
                      ' read timeout: %s', client._session_id,
                      client._session_passwd.encode('hex'),
                      negotiated_session_timeout, connect_timeout,
                      read_timeout)

        if connect_result.read_only:
            client._session_callback(KeeperState.CONNECTED_RO)
            # While read-only, keep probing for a r/w server.
            self._ro_mode = iter(self._server_pinger())
        else:
            client._session_callback(KeeperState.CONNECTED)
            self._ro_mode = None

        for scheme, auth in client.auth_data:
            ap = Auth(0, scheme, auth)
            zxid = self._invoke(connect_timeout, ap, xid=AUTH_XID)
            if zxid:
                client.last_zxid = zxid
        return read_timeout, connect_timeout

    def _send_request(self, read_timeout, connect_timeout):
        """Send the next queued request, or a ping if the queue is idle.

        :returns: True when the writer should finish (a ``Close`` was
            sent or an unexpected error was logged), otherwise None.
        :raises ConnectionDropped: if the socket fails while writing, so
            the connect loop can move on to another server.
        :raises RWServerAvailable: if, while read-only, the pinger found
            a r/w server.
        """
        client = self.client
        ret = None
        try:
            timeout = read_timeout / 2000.0 - random.randint(0, 40) / 100.0
            request, async_object = client._queue.peek(True, timeout)

            # Special case for auth packets
            if request.type == Auth.type:
                with client._state_lock:
                    self._submit(request, connect_timeout, AUTH_XID)
                    client._queue.get()
                return

            self._xid += 1

            if self.log_debug:
                log.debug('xid: %r', self._xid)

            with client._state_lock:
                self._submit(request, connect_timeout, self._xid)

                if isinstance(request, Close):
                    if self.log_debug:
                        log.debug('Received close req, closing')
                    ret = True

                # Only dequeue once the request is on the wire.
                client._queue.get()
                client._pending.put((request, async_object, self._xid))
        except self.handler.empty:
            if self.log_debug:
                log.debug('Queue timeout. Sending PING')
            self._submit(Ping, connect_timeout, PING_XID)

            # Determine if we need to check for a r/w server
            if self._ro_mode:
                result = next(self._ro_mode)
                if result:
                    self._rw_server = result
                    raise RWServerAvailable()
        except ConnectionDropped:
            # A dropped connection must reach _connect_loop, which
            # reconnects; swallowing it below would be reported as a
            # clean close and stop the writer for good.
            raise
        except Exception as e:
            log.exception(e)
            ret = True
        return ret
|