# Copyright 2014 Mirantis Inc
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

"""Thread-safe connection pool for python-memcached."""

import collections
import contextlib
import itertools
import threading
import time

try:
    import eventlet
except ImportError:
    eventlet = None
import memcache
from oslo_log import log
from six.moves import queue
from six.moves import zip

from oslo_cache._i18n import _
from oslo_cache import exception


LOG = log.getLogger(__name__)


class _MemcacheClient(memcache.Client):
    """Thread global memcache client

    As client is inherited from threading.local we have to restore object
    methods overloaded by threading.local so we can reuse clients in
    different threads
    """
    # Attribute access is routed straight to ``object`` so that one client
    # instance exposes the same state regardless of the calling thread.
    __delattr__ = object.__delattr__
    __getattribute__ = object.__getattribute__
    __setattr__ = object.__setattr__

    # Hack for lp 1812935
    if eventlet and eventlet.patcher.is_monkey_patched('thread'):
        # NOTE(bnemec): I'm not entirely sure why this works in a
        # monkey-patched environment and not with vanilla stdlib, but it does.
        def __new__(cls, *args, **kwargs):
            return object.__new__(cls)
    else:
        __new__ = object.__new__

    def __del__(self):
        # Deliberately a no-op finalizer.
        pass


# An idle pool entry: the connection plus the absolute time (``ttl``) after
# which it is considered expired and may be reaped.
_PoolItem = collections.namedtuple('_PoolItem', ['ttl', 'connection'])


class ConnectionPool(queue.Queue):
    """Base connection pool class

    This class implements the basic connection pool logic as an abstract
    base class. Subclasses provide ``_create_connection`` and
    ``_destroy_connection``.
    """

    def __init__(self, maxsize, unused_timeout, conn_get_timeout=None):
        """Initialize the connection pool.

        :param maxsize: maximum number of client connections for the pool
        :type maxsize: int
        :param unused_timeout: idle time to live for unused clients (in
                               seconds). A client connection object that has
                               been sitting idle in the pool for longer than
                               this is reaped, so that resources are released
                               as utilization goes down.
        :type unused_timeout: int
        :param conn_get_timeout: maximum time in seconds to wait for a
                                 connection. If set to `None` timeout is
                                 indefinite.
        :type conn_get_timeout: int
        """
        # The base initializer is invoked explicitly instead of via super()
        # because Queue in stdlib is an old-style class.
        queue.Queue.__init__(self, maxsize)
        self._unused_timeout = unused_timeout
        self._connection_get_timeout = conn_get_timeout
        # Number of connections currently handed out to callers.
        self._acquired = 0

    def _create_connection(self):
        """Returns a connection instance.

        Called whenever the pool needs one more connection than it has idle.

        :returns: a new connection instance
        """
        raise NotImplementedError

    def _destroy_connection(self, conn):
        """Destroy and cleanup a connection instance.

        Called when the pool wants to get rid of an existing connection;
        subclasses free resources and clean up after themselves here.

        :param conn: the connection object to destroy
        """
        raise NotImplementedError

    def _do_log(self, level, msg, *args, **kwargs):
        """Log ``msg`` at ``level``, prefixed with pool id and thread id."""
        if not LOG.isEnabledFor(level):
            return
        identity = (id(self), threading.current_thread().ident)
        LOG.log(level,
                'Memcached pool %s, thread %s: ' + msg,
                *(identity + args),
                **kwargs)

    def _debug_logger(self, msg, *args, **kwargs):
        """Emit a DEBUG level pool message."""
        self._do_log(log.DEBUG, msg, *args, **kwargs)

    def _trace_logger(self, msg, *args, **kwargs):
        """Emit a TRACE level pool message."""
        self._do_log(log.TRACE, msg, *args, **kwargs)

    @contextlib.contextmanager
    def acquire(self):
        """Context manager that lends a connection out of the pool.

        Expired idle connections are reaped first. The connection is put
        back on exit, or destroyed when the queue reports it is full.

        :raises oslo_cache.exception.QueueEmpty: if no connection could be
            obtained within the configured timeout
        """
        self._trace_logger('Acquiring connection')
        self._drop_expired_connections()
        try:
            conn = self.get(timeout=self._connection_get_timeout)
        except queue.Empty:
            details = {'id': id(self),
                       'seconds': self._connection_get_timeout}
            raise exception.QueueEmpty(
                _('Unable to get a connection from pool id %(id)s after '
                  '%(seconds)s seconds.') % details)

        conn_id = id(conn)
        self._trace_logger('Acquired connection %s', conn_id)
        try:
            yield conn
        finally:
            self._trace_logger('Releasing connection %s', conn_id)
            try:
                # Direct base-class call rather than super(): Queue in
                # stdlib is an old-style class.
                queue.Queue.put(self, conn, block=False)
            except queue.Full:
                self._trace_logger('Reaping exceeding connection %s',
                                   conn_id)
                self._destroy_connection(conn)

    def _qsize(self):
        """Return how many connections may still be handed out."""
        if not self.maxsize:
            # A value indicating there is always a free connection
            # if maxsize is None or 0
            return 1
        return self.maxsize - self._acquired

    # NOTE(dstanek): stdlib and eventlet Queue implementations
    # have different names for the qsize method. This ensures
    # that we override both of them.
    if not hasattr(queue.Queue, '_qsize'):
        qsize = _qsize

    def _get(self):
        """Take the most recently returned idle connection, or make one."""
        try:
            item = self.queue.pop()
        except IndexError:
            # Nothing idle: build a brand new connection.
            conn = self._create_connection()
        else:
            conn = item.connection
        self._acquired += 1
        return conn

    def _drop_expired_connections(self):
        """Drop all expired connections from the left end of the queue."""
        now = time.time()
        try:
            while self.queue[0].ttl < now:
                expired = self.queue.popleft().connection
                self._trace_logger('Reaping connection %s', id(expired))
                self._destroy_connection(expired)
        except IndexError:
            # NOTE(amakarov): This is an expected exception, so there is no
            # need to react. We have to handle the exception instead of
            # checking the queue length, because IndexError can result from
            # a race condition as well as from mere queue depletion.
            pass

    def _put(self, conn):
        """Store ``conn`` as idle, stamped with its expiry time."""
        expires_at = time.time() + self._unused_timeout
        self.queue.append(_PoolItem(ttl=expires_at, connection=conn))
        self._acquired -= 1


class MemcacheClientPool(ConnectionPool):
    """Connection pool of memcache clients sharing dead-host knowledge."""

    def __init__(self, urls, arguments, **kwargs):
        """Initialize the pool.

        :param urls: memcached server URLs handed to every client
        :param arguments: extra keyword arguments for each client
        :param kwargs: forwarded to ``ConnectionPool.__init__``
        """
        # Explicit base-class call rather than super(): Queue in stdlib is
        # an old-style class.
        ConnectionPool.__init__(self, **kwargs)
        self.urls = urls
        self._arguments = arguments
        # NOTE(morganfainberg): The host objects expect an int for the
        # deaduntil value. Initialize this at 0 for each host with 0 indicating
        # the host is not dead.
        self._hosts_deaduntil = [0] * len(urls)

    def _create_connection(self):
        """Build a new client with ``flush_on_reconnect`` enabled."""
        # NOTE(morgan): Explicitly set flush_on_reconnect for pooled
        # connections. This should ensure that stale data is never consumed
        # from a server that pops in/out due to a network partition
        # or disconnect.
        #
        # See the help from python-memcached:
        #
        # param flush_on_reconnect: optional flag which prevents a
        #        scenario that can cause stale data to be read: If there's more
        #        than one memcached server and the connection to one is
        #        interrupted, keys that mapped to that server will get
        #        reassigned to another. If the first server comes back, those
        #        keys will map to it again. If it still has its data, get()s
        #        can read stale data that was overwritten on another
        #        server. This flag is off by default for backwards
        #        compatibility.
        #
        # The normal non-pooled clients connect explicitly on each use and
        # do not need the explicit flush_on_reconnect
        return _MemcacheClient(self.urls, flush_on_reconnect=True,
                               **self._arguments)

    def _destroy_connection(self, conn):
        """Disconnect the client from all of its servers."""
        conn.disconnect_all()

    def _get(self):
        """Fetch a client and push the pool's dead-host marks onto it."""
        # Explicit base-class call rather than super(): Queue in stdlib is
        # an old-style class.
        conn = ConnectionPool._get(self)
        try:
            # Propagate host state known to us to this client's list
            now = time.time()
            pairs = zip(self._hosts_deaduntil, conn.servers)
            for pool_deaduntil, host in pairs:
                pool_says_dead = pool_deaduntil > now
                if pool_says_dead and host.deaduntil <= now:
                    host.mark_dead('propagating death mark from the pool')
                host.deaduntil = pool_deaduntil
        except Exception:
            # We need to be sure that connection doesn't leak from the pool.
            # This code runs before we enter context manager's try-finally
            # block, so we need to explicitly release it here.
            # (Base-class call instead of super() for the same old-style
            # class reason as above.)
            ConnectionPool._put(self, conn)
            raise
        return conn

    def _put(self, conn):
        """Record dead hosts seen by ``conn`` and return it to the pool."""
        try:
            # If this client found that one of the hosts is dead, mark it as
            # such in our internal list
            now = time.time()
            for idx, host in zip(itertools.count(), conn.servers):
                # Do nothing if we already know this host is dead
                if self._hosts_deaduntil[idx] > now:
                    continue
                if host.deaduntil > now:
                    self._hosts_deaduntil[idx] = host.deaduntil
                    self._debug_logger(
                        'Marked host %s dead until %s',
                        self.urls[idx], host.deaduntil)
                else:
                    self._hosts_deaduntil[idx] = 0
            # If all hosts are dead we should forget that they're dead. This
            # way we won't get completely shut off until dead_retry seconds
            # pass, but will be checking servers as frequent as we can (over
            # way smaller socket_timeout)
            if all(mark > now for mark in self._hosts_deaduntil):
                self._debug_logger('All hosts are dead. Marking them as live.')
                self._hosts_deaduntil[:] = [0] * len(self._hosts_deaduntil)
        finally:
            # Always hand the connection back. Base-class call instead of
            # super() because Queue in stdlib is an old-style class.
            ConnectionPool._put(self, conn)