Memcached client TLS support
This patch specifies a set of configuration options required to build
a TLS context, which is used to wrap the client connection socket.
Closes-Bug: #1906846
Change-Id: I03a92168b90508956f367fbb60b7712f95b97f60
(cherry picked from commit 6930bc24b2
)
This commit is contained in:
parent
8e744ba24a
commit
be74f21411
@ -26,3 +26,25 @@
|
||||
# tries = 3
|
||||
# Timeout for read and writes
|
||||
# io_timeout = 2.0
|
||||
#
|
||||
# (Optional) Global toggle for TLS usage when comunicating with
|
||||
# the caching servers.
|
||||
# tls_enabled = false
|
||||
#
|
||||
# (Optional) Path to a file of concatenated CA certificates in PEM
|
||||
# format necessary to establish the caching server's authenticity.
|
||||
# If tls_enabled is False, this option is ignored.
|
||||
# tls_cafile =
|
||||
#
|
||||
# (Optional) Path to a single file in PEM format containing the
|
||||
# client's certificate as well as any number of CA certificates
|
||||
# needed to establish the certificate's authenticity. This file
|
||||
# is only required when client side authentication is necessary.
|
||||
# If tls_enabled is False, this option is ignored.
|
||||
# tls_certfile =
|
||||
#
|
||||
# (Optional) Path to a single file containing the client's private
|
||||
# key in. Otherwhise the private key will be taken from the file
|
||||
# specified in tls_certfile. If tls_enabled is False, this option
|
||||
# is ignored.
|
||||
# tls_keyfile =
|
||||
|
@ -693,6 +693,10 @@ use = egg:swift#memcache
|
||||
# Sets the maximum number of connections to each memcached server per worker
|
||||
# memcache_max_connections = 2
|
||||
#
|
||||
# (Optional) Global toggle for TLS usage when comunicating with
|
||||
# the caching servers.
|
||||
# tls_enabled =
|
||||
#
|
||||
# More options documented in memcache.conf-sample
|
||||
|
||||
[filter:ratelimit]
|
||||
|
@ -128,11 +128,12 @@ class MemcacheConnPool(Pool):
|
||||
:func:`swift.common.utils.parse_socket_string` for details.
|
||||
"""
|
||||
|
||||
def __init__(self, server, size, connect_timeout):
|
||||
def __init__(self, server, size, connect_timeout, tls_context=None):
|
||||
Pool.__init__(self, max_size=size)
|
||||
self.host, self.port = utils.parse_socket_string(
|
||||
server, DEFAULT_MEMCACHED_PORT)
|
||||
self._connect_timeout = connect_timeout
|
||||
self._tls_context = tls_context
|
||||
|
||||
def create(self):
|
||||
addrs = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC,
|
||||
@ -142,6 +143,9 @@ class MemcacheConnPool(Pool):
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
with Timeout(self._connect_timeout):
|
||||
sock.connect(sockaddr)
|
||||
if self._tls_context:
|
||||
sock = self._tls_context.wrap_socket(sock,
|
||||
server_hostname=self.host)
|
||||
return (sock.makefile('rwb'), sock)
|
||||
|
||||
def get(self):
|
||||
@ -160,7 +164,7 @@ class MemcacheRing(object):
|
||||
def __init__(self, servers, connect_timeout=CONN_TIMEOUT,
|
||||
io_timeout=IO_TIMEOUT, pool_timeout=POOL_TIMEOUT,
|
||||
tries=TRY_COUNT, allow_pickle=False, allow_unpickle=False,
|
||||
max_conns=2, logger=None):
|
||||
max_conns=2, tls_context=None, logger=None):
|
||||
self._ring = {}
|
||||
self._errors = dict(((serv, []) for serv in servers))
|
||||
self._error_limited = dict(((serv, 0) for serv in servers))
|
||||
@ -169,10 +173,10 @@ class MemcacheRing(object):
|
||||
self._ring[md5hash('%s-%s' % (server, i))] = server
|
||||
self._tries = tries if tries <= len(servers) else len(servers)
|
||||
self._sorted = sorted(self._ring)
|
||||
self._client_cache = dict(((server,
|
||||
MemcacheConnPool(server, max_conns,
|
||||
connect_timeout))
|
||||
for server in servers))
|
||||
self._client_cache = dict((
|
||||
(server, MemcacheConnPool(server, max_conns, connect_timeout,
|
||||
tls_context=tls_context))
|
||||
for server in servers))
|
||||
self._connect_timeout = connect_timeout
|
||||
self._io_timeout = io_timeout
|
||||
self._pool_timeout = pool_timeout
|
||||
|
@ -15,11 +15,12 @@
|
||||
|
||||
import os
|
||||
|
||||
from eventlet.green import ssl
|
||||
from six.moves.configparser import ConfigParser, NoSectionError, NoOptionError
|
||||
|
||||
from swift.common.memcached import (MemcacheRing, CONN_TIMEOUT, POOL_TIMEOUT,
|
||||
IO_TIMEOUT, TRY_COUNT)
|
||||
from swift.common.utils import get_logger
|
||||
from swift.common.utils import get_logger, config_true_value
|
||||
|
||||
|
||||
class MemcacheMiddleware(object):
|
||||
@ -86,6 +87,17 @@ class MemcacheMiddleware(object):
|
||||
'pool_timeout', POOL_TIMEOUT))
|
||||
tries = int(memcache_options.get('tries', TRY_COUNT))
|
||||
io_timeout = float(memcache_options.get('io_timeout', IO_TIMEOUT))
|
||||
if config_true_value(memcache_options.get('tls_enabled', 'false')):
|
||||
tls_cafile = memcache_options.get('tls_cafile')
|
||||
tls_certfile = memcache_options.get('tls_certfile')
|
||||
tls_keyfile = memcache_options.get('tls_keyfile')
|
||||
self.tls_context = ssl.create_default_context(
|
||||
cafile=tls_cafile)
|
||||
if tls_certfile:
|
||||
self.tls_context.load_cert_chain(tls_certfile,
|
||||
tls_keyfile)
|
||||
else:
|
||||
self.tls_context = None
|
||||
|
||||
if not self.memcache_servers:
|
||||
self.memcache_servers = '127.0.0.1:11211'
|
||||
@ -105,6 +117,7 @@ class MemcacheMiddleware(object):
|
||||
allow_pickle=(serialization_format == 0),
|
||||
allow_unpickle=(serialization_format <= 1),
|
||||
max_conns=max_conns,
|
||||
tls_context=self.tls_context,
|
||||
logger=self.logger)
|
||||
|
||||
def __call__(self, env, start_response):
|
||||
|
@ -17,6 +17,7 @@ import os
|
||||
from textwrap import dedent
|
||||
import unittest
|
||||
|
||||
from eventlet.green import ssl
|
||||
import mock
|
||||
from six.moves.configparser import NoSectionError, NoOptionError
|
||||
|
||||
@ -160,6 +161,22 @@ class TestCacheMiddleware(unittest.TestCase):
|
||||
self.assertEqual(
|
||||
app.memcache._client_cache['6.7.8.9:10'].max_size, 5)
|
||||
|
||||
def test_conf_inline_tls(self):
|
||||
fake_context = mock.Mock()
|
||||
with mock.patch.object(ssl, 'create_default_context',
|
||||
return_value=fake_context):
|
||||
with mock.patch.object(memcache, 'ConfigParser',
|
||||
get_config_parser()):
|
||||
memcache.MemcacheMiddleware(
|
||||
FakeApp(),
|
||||
{'tls_enabled': 'true',
|
||||
'tls_cafile': 'cafile',
|
||||
'tls_certfile': 'certfile',
|
||||
'tls_keyfile': 'keyfile'})
|
||||
ssl.create_default_context.assert_called_with(cafile='cafile')
|
||||
fake_context.load_cert_chain.assert_called_with('certfile',
|
||||
'keyfile')
|
||||
|
||||
def test_conf_extra_no_section(self):
|
||||
with mock.patch.object(memcache, 'ConfigParser',
|
||||
get_config_parser(section='foobar')):
|
||||
@ -323,6 +340,7 @@ class TestCacheMiddleware(unittest.TestCase):
|
||||
pool_timeout = 0.5
|
||||
tries = 4
|
||||
io_timeout = 1.0
|
||||
tls_enabled = true
|
||||
"""
|
||||
config_path = os.path.join(tempdir, 'test.conf')
|
||||
with open(config_path, 'w') as f:
|
||||
@ -336,6 +354,9 @@ class TestCacheMiddleware(unittest.TestCase):
|
||||
# tries is limited to server count
|
||||
self.assertEqual(memcache_ring._tries, 4)
|
||||
self.assertEqual(memcache_ring._io_timeout, 1.0)
|
||||
self.assertIsInstance(
|
||||
list(memcache_ring._client_cache.values())[0]._tls_context,
|
||||
ssl.SSLContext)
|
||||
|
||||
@with_tempdir
|
||||
def test_real_memcache_config(self, tempdir):
|
||||
|
@ -194,6 +194,20 @@ class TestMemcached(unittest.TestCase):
|
||||
client = memcached.MemcacheRing([server_socket], logger=self.logger)
|
||||
self.assertIs(client.logger, self.logger)
|
||||
|
||||
def test_tls_context_kwarg(self):
|
||||
with patch('swift.common.memcached.socket.socket'):
|
||||
server = '%s:%s' % ('[::1]', 11211)
|
||||
client = memcached.MemcacheRing([server])
|
||||
self.assertIsNone(client._client_cache[server]._tls_context)
|
||||
|
||||
context = mock.Mock()
|
||||
client = memcached.MemcacheRing([server], tls_context=context)
|
||||
self.assertIs(client._client_cache[server]._tls_context, context)
|
||||
|
||||
key = uuid4().hex.encode('ascii')
|
||||
list(client._get_conns(key))
|
||||
context.wrap_socket.assert_called_once()
|
||||
|
||||
def test_get_conns(self):
|
||||
sock1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock1.bind(('127.0.0.1', 0))
|
||||
|
Loading…
Reference in New Issue
Block a user