From 32e24181db5e621db4705e1c9c1838d4e3a99f45 Mon Sep 17 00:00:00 2001 From: KenjiTakahashi Date: Mon, 20 Apr 2015 00:41:42 +0200 Subject: [PATCH] Fix #175: Use builtin check_hostname when available --- websocket/_http.py | 15 ++++++++------- websocket/_ssl_compat.py | 13 +++++++++---- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/websocket/_http.py b/websocket/_http.py index 3440223..9a2f7db 100644 --- a/websocket/_http.py +++ b/websocket/_http.py @@ -122,11 +122,12 @@ def _can_use_sni(): return (six.PY2 and sys.version_info[1] >= 7 and sys.version_info[2] >= 9) or (six.PY3 and sys.version_info[2] >= 2) -def _wrap_sni_socket(sock, sslopt, hostname): - context = ssl.create_default_context(cafile=sslopt.get('ca_certs', None)) - context.options = sslopt.get('ssl_version', context.options) - context.check_hostname = sslopt.get('check_hostname', True) +def _wrap_sni_socket(sock, sslopt, hostname, check_hostname): + context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_SSLv23)) + context.load_verify_locations(cafile=sslopt.get('ca_certs', None)) context.verify_mode = sslopt['cert_reqs'] + if HAVE_CONTEXT_CHECK_HOSTNAME: + context.check_hostname = check_hostname if 'ciphers' in sslopt: context.set_ciphers(sslopt['ciphers']) @@ -145,15 +146,15 @@ def _ssl_socket(sock, user_sslopt, hostname): if os.path.isfile(certPath): sslopt['ca_certs'] = certPath sslopt.update(user_sslopt) - check_hostname = sslopt.get('check_hostname', True) + check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop('check_hostname', True) if _can_use_sni(): - sock = _wrap_sni_socket(sock, sslopt, hostname) + sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname) else: sslopt.pop('check_hostname', True) sock = ssl.wrap_socket(sock, **sslopt) - if (sslopt["cert_reqs"] != ssl.CERT_NONE and check_hostname): + if not HAVE_CONTEXT_CHECK_HOSTNAME and check_hostname: match_hostname(sock.getpeercert(), hostname) return sock diff --git a/websocket/_ssl_compat.py b/websocket/_ssl_compat.py index 8f2b6bd..d41ca79 100644 --- a/websocket/_ssl_compat.py +++ b/websocket/_ssl_compat.py @@ -25,11 +25,16 @@ __all__ = ["HAVE_SSL", "ssl", "SSLError"] try: import ssl from ssl import SSLError - if hasattr(ssl, "match_hostname"): - from ssl import match_hostname + if hasattr(ssl, 'SSLContext') and hasattr(ssl.SSLContext, 'check_hostname'): + HAVE_CONTEXT_CHECK_HOSTNAME = True else: - from backports.ssl_match_hostname import match_hostname - __all__.append("match_hostname") + HAVE_CONTEXT_CHECK_HOSTNAME = False + if hasattr(ssl, "match_hostname"): + from ssl import match_hostname + else: + from backports.ssl_match_hostname import match_hostname + __all__.append("match_hostname") + __all__.append("HAVE_CONTEXT_CHECK_HOSTNAME") HAVE_SSL = True except ImportError: