Merge "Memcached client TLS support"
This commit is contained in:
		@@ -34,3 +34,25 @@
 | 
			
		||||
#
 | 
			
		||||
# How many errors can accumulate before a server is temporarily ignored.
 | 
			
		||||
# error_suppression_limit = 10
 | 
			
		||||
#
 | 
			
		||||
# (Optional) Global toggle for TLS usage when comunicating with
 | 
			
		||||
# the caching servers.
 | 
			
		||||
# tls_enabled = false
 | 
			
		||||
#
 | 
			
		||||
# (Optional) Path to a file of concatenated CA certificates in PEM
 | 
			
		||||
# format necessary to establish the caching server's authenticity.
 | 
			
		||||
# If tls_enabled is False, this option is ignored.
 | 
			
		||||
# tls_cafile =
 | 
			
		||||
#
 | 
			
		||||
# (Optional) Path to a single file in PEM format containing the
 | 
			
		||||
# client's certificate as well as any number of CA certificates
 | 
			
		||||
# needed to establish the certificate's authenticity. This file
 | 
			
		||||
# is only required when client side authentication is necessary.
 | 
			
		||||
# If tls_enabled is False, this option is ignored.
 | 
			
		||||
# tls_certfile =
 | 
			
		||||
#
 | 
			
		||||
# (Optional) Path to a single file containing the client's private
 | 
			
		||||
# key in. Otherwhise the private key will be taken from the file
 | 
			
		||||
# specified in tls_certfile. If tls_enabled is False, this option
 | 
			
		||||
# is ignored.
 | 
			
		||||
# tls_keyfile =
 | 
			
		||||
 
 | 
			
		||||
@@ -712,6 +712,10 @@ use = egg:swift#memcache
 | 
			
		||||
# How many errors can accumulate before a server is temporarily ignored.
 | 
			
		||||
# error_suppression_limit = 10
 | 
			
		||||
#
 | 
			
		||||
# (Optional) Global toggle for TLS usage when comunicating with
 | 
			
		||||
# the caching servers.
 | 
			
		||||
# tls_enabled =
 | 
			
		||||
#
 | 
			
		||||
# More options documented in memcache.conf-sample
 | 
			
		||||
 | 
			
		||||
[filter:ratelimit]
 | 
			
		||||
 
 | 
			
		||||
@@ -127,11 +127,12 @@ class MemcacheConnPool(Pool):
 | 
			
		||||
    :func:`swift.common.utils.parse_socket_string` for details.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, server, size, connect_timeout):
 | 
			
		||||
    def __init__(self, server, size, connect_timeout, tls_context=None):
 | 
			
		||||
        Pool.__init__(self, max_size=size)
 | 
			
		||||
        self.host, self.port = utils.parse_socket_string(
 | 
			
		||||
            server, DEFAULT_MEMCACHED_PORT)
 | 
			
		||||
        self._connect_timeout = connect_timeout
 | 
			
		||||
        self._tls_context = tls_context
 | 
			
		||||
 | 
			
		||||
    def create(self):
 | 
			
		||||
        addrs = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC,
 | 
			
		||||
@@ -141,6 +142,9 @@ class MemcacheConnPool(Pool):
 | 
			
		||||
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 | 
			
		||||
        with Timeout(self._connect_timeout):
 | 
			
		||||
            sock.connect(sockaddr)
 | 
			
		||||
        if self._tls_context:
 | 
			
		||||
            sock = self._tls_context.wrap_socket(sock,
 | 
			
		||||
                                                 server_hostname=self.host)
 | 
			
		||||
        return (sock.makefile('rwb'), sock)
 | 
			
		||||
 | 
			
		||||
    def get(self):
 | 
			
		||||
@@ -159,7 +163,7 @@ class MemcacheRing(object):
 | 
			
		||||
    def __init__(self, servers, connect_timeout=CONN_TIMEOUT,
 | 
			
		||||
                 io_timeout=IO_TIMEOUT, pool_timeout=POOL_TIMEOUT,
 | 
			
		||||
                 tries=TRY_COUNT, allow_pickle=False, allow_unpickle=False,
 | 
			
		||||
                 max_conns=2, logger=None,
 | 
			
		||||
                 max_conns=2, tls_context=None, logger=None,
 | 
			
		||||
                 error_limit_count=ERROR_LIMIT_COUNT,
 | 
			
		||||
                 error_limit_time=ERROR_LIMIT_TIME,
 | 
			
		||||
                 error_limit_duration=ERROR_LIMIT_DURATION):
 | 
			
		||||
@@ -174,10 +178,10 @@ class MemcacheRing(object):
 | 
			
		||||
                self._ring[md5hash('%s-%s' % (server, i))] = server
 | 
			
		||||
        self._tries = tries if tries <= len(servers) else len(servers)
 | 
			
		||||
        self._sorted = sorted(self._ring)
 | 
			
		||||
        self._client_cache = dict(((server,
 | 
			
		||||
                                    MemcacheConnPool(server, max_conns,
 | 
			
		||||
                                                     connect_timeout))
 | 
			
		||||
                                  for server in servers))
 | 
			
		||||
        self._client_cache = dict((
 | 
			
		||||
            (server, MemcacheConnPool(server, max_conns, connect_timeout,
 | 
			
		||||
                                      tls_context=tls_context))
 | 
			
		||||
            for server in servers))
 | 
			
		||||
        self._connect_timeout = connect_timeout
 | 
			
		||||
        self._io_timeout = io_timeout
 | 
			
		||||
        self._pool_timeout = pool_timeout
 | 
			
		||||
 
 | 
			
		||||
@@ -15,12 +15,13 @@
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
from eventlet.green import ssl
 | 
			
		||||
from six.moves.configparser import ConfigParser, NoSectionError, NoOptionError
 | 
			
		||||
 | 
			
		||||
from swift.common.memcached import (
 | 
			
		||||
    MemcacheRing, CONN_TIMEOUT, POOL_TIMEOUT, IO_TIMEOUT, TRY_COUNT,
 | 
			
		||||
    ERROR_LIMIT_COUNT, ERROR_LIMIT_TIME)
 | 
			
		||||
from swift.common.utils import get_logger
 | 
			
		||||
from swift.common.utils import get_logger, config_true_value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MemcacheMiddleware(object):
 | 
			
		||||
@@ -87,6 +88,17 @@ class MemcacheMiddleware(object):
 | 
			
		||||
            'pool_timeout', POOL_TIMEOUT))
 | 
			
		||||
        tries = int(memcache_options.get('tries', TRY_COUNT))
 | 
			
		||||
        io_timeout = float(memcache_options.get('io_timeout', IO_TIMEOUT))
 | 
			
		||||
        if config_true_value(memcache_options.get('tls_enabled', 'false')):
 | 
			
		||||
            tls_cafile = memcache_options.get('tls_cafile')
 | 
			
		||||
            tls_certfile = memcache_options.get('tls_certfile')
 | 
			
		||||
            tls_keyfile = memcache_options.get('tls_keyfile')
 | 
			
		||||
            self.tls_context = ssl.create_default_context(
 | 
			
		||||
                cafile=tls_cafile)
 | 
			
		||||
            if tls_certfile:
 | 
			
		||||
                self.tls_context.load_cert_chain(tls_certfile,
 | 
			
		||||
                                                 tls_keyfile)
 | 
			
		||||
        else:
 | 
			
		||||
            self.tls_context = None
 | 
			
		||||
        error_suppression_interval = float(memcache_options.get(
 | 
			
		||||
            'error_suppression_interval', ERROR_LIMIT_TIME))
 | 
			
		||||
        error_suppression_limit = float(memcache_options.get(
 | 
			
		||||
@@ -110,6 +122,7 @@ class MemcacheMiddleware(object):
 | 
			
		||||
            allow_pickle=(serialization_format == 0),
 | 
			
		||||
            allow_unpickle=(serialization_format <= 1),
 | 
			
		||||
            max_conns=max_conns,
 | 
			
		||||
            tls_context=self.tls_context,
 | 
			
		||||
            logger=self.logger,
 | 
			
		||||
            error_limit_count=error_suppression_limit,
 | 
			
		||||
            error_limit_time=error_suppression_interval,
 | 
			
		||||
 
 | 
			
		||||
@@ -17,6 +17,7 @@ import os
 | 
			
		||||
from textwrap import dedent
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
from eventlet.green import ssl
 | 
			
		||||
import mock
 | 
			
		||||
from six.moves.configparser import NoSectionError, NoOptionError
 | 
			
		||||
 | 
			
		||||
@@ -170,6 +171,22 @@ class TestCacheMiddleware(unittest.TestCase):
 | 
			
		||||
        self.assertEqual(app.memcache._error_limit_time, 2.5)
 | 
			
		||||
        self.assertEqual(app.memcache._error_limit_duration, 2.5)
 | 
			
		||||
 | 
			
		||||
    def test_conf_inline_tls(self):
 | 
			
		||||
        fake_context = mock.Mock()
 | 
			
		||||
        with mock.patch.object(ssl, 'create_default_context',
 | 
			
		||||
                               return_value=fake_context):
 | 
			
		||||
            with mock.patch.object(memcache, 'ConfigParser',
 | 
			
		||||
                                   get_config_parser()):
 | 
			
		||||
                memcache.MemcacheMiddleware(
 | 
			
		||||
                    FakeApp(),
 | 
			
		||||
                    {'tls_enabled': 'true',
 | 
			
		||||
                     'tls_cafile': 'cafile',
 | 
			
		||||
                     'tls_certfile': 'certfile',
 | 
			
		||||
                     'tls_keyfile': 'keyfile'})
 | 
			
		||||
            ssl.create_default_context.assert_called_with(cafile='cafile')
 | 
			
		||||
            fake_context.load_cert_chain.assert_called_with('certfile',
 | 
			
		||||
                                                            'keyfile')
 | 
			
		||||
 | 
			
		||||
    def test_conf_extra_no_section(self):
 | 
			
		||||
        with mock.patch.object(memcache, 'ConfigParser',
 | 
			
		||||
                               get_config_parser(section='foobar')):
 | 
			
		||||
@@ -333,6 +350,7 @@ class TestCacheMiddleware(unittest.TestCase):
 | 
			
		||||
        pool_timeout = 0.5
 | 
			
		||||
        tries = 4
 | 
			
		||||
        io_timeout = 1.0
 | 
			
		||||
        tls_enabled = true
 | 
			
		||||
        """
 | 
			
		||||
        config_path = os.path.join(tempdir, 'test.conf')
 | 
			
		||||
        with open(config_path, 'w') as f:
 | 
			
		||||
@@ -349,6 +367,9 @@ class TestCacheMiddleware(unittest.TestCase):
 | 
			
		||||
        self.assertEqual(memcache_ring._error_limit_count, 10)
 | 
			
		||||
        self.assertEqual(memcache_ring._error_limit_time, 60)
 | 
			
		||||
        self.assertEqual(memcache_ring._error_limit_duration, 60)
 | 
			
		||||
        self.assertIsInstance(
 | 
			
		||||
            list(memcache_ring._client_cache.values())[0]._tls_context,
 | 
			
		||||
            ssl.SSLContext)
 | 
			
		||||
 | 
			
		||||
    @with_tempdir
 | 
			
		||||
    def test_real_memcache_config(self, tempdir):
 | 
			
		||||
 
 | 
			
		||||
@@ -198,6 +198,20 @@ class TestMemcached(unittest.TestCase):
 | 
			
		||||
        client = memcached.MemcacheRing([server_socket], logger=self.logger)
 | 
			
		||||
        self.assertIs(client.logger, self.logger)
 | 
			
		||||
 | 
			
		||||
    def test_tls_context_kwarg(self):
 | 
			
		||||
        with patch('swift.common.memcached.socket.socket'):
 | 
			
		||||
            server = '%s:%s' % ('[::1]', 11211)
 | 
			
		||||
            client = memcached.MemcacheRing([server])
 | 
			
		||||
            self.assertIsNone(client._client_cache[server]._tls_context)
 | 
			
		||||
 | 
			
		||||
            context = mock.Mock()
 | 
			
		||||
            client = memcached.MemcacheRing([server], tls_context=context)
 | 
			
		||||
            self.assertIs(client._client_cache[server]._tls_context, context)
 | 
			
		||||
 | 
			
		||||
            key = uuid4().hex.encode('ascii')
 | 
			
		||||
            list(client._get_conns(key))
 | 
			
		||||
            context.wrap_socket.assert_called_once()
 | 
			
		||||
 | 
			
		||||
    def test_get_conns(self):
 | 
			
		||||
        sock1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 | 
			
		||||
        sock1.bind(('127.0.0.1', 0))
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user