wsgi: websocket: Reformat code + tests (PEP-8)

This commit is contained in:
Jakub Stasiak
2014-07-13 20:43:33 +01:00
committed by Sergey Shepelev
parent 6afd8bdee2
commit 99f4f18c33
5 changed files with 278 additions and 265 deletions

View File

@@ -11,7 +11,7 @@ from socket import error as SocketError
try: try:
from hashlib import md5, sha1 from hashlib import md5, sha1
except ImportError: #pragma NO COVER except ImportError: # pragma NO COVER
from md5 import md5 from md5 import md5
from sha import sha as sha1 from sha import sha as sha1
@@ -72,6 +72,7 @@ class WebSocketWSGI(object):
function. Note that the server will log the websocket request at function. Note that the server will log the websocket request at
the time of closure. the time of closure.
""" """
def __init__(self, handler): def __init__(self, handler):
self.handler = handler self.handler = handler
self.protocol_version = None self.protocol_version = None
@@ -185,7 +186,7 @@ class WebSocketWSGI(object):
environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL', 'default'), environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL', 'default'),
location, location,
response)) response))
else: #pragma NO COVER else: # pragma NO COVER
raise ValueError("Unknown WebSocket protocol version.") raise ValueError("Unknown WebSocket protocol version.")
sock.sendall(handshake_reply) sock.sendall(handshake_reply)
return WebSocket(sock, environ, self.protocol_version) return WebSocket(sock, environ, self.protocol_version)
@@ -215,7 +216,7 @@ class WebSocketWSGI(object):
negotiated_protocol = p negotiated_protocol = p
break break
#extensions = environ.get('HTTP_SEC_WEBSOCKET_EXTENSIONS', None) #extensions = environ.get('HTTP_SEC_WEBSOCKET_EXTENSIONS', None)
#if extensions: # if extensions:
# extensions = [i.strip() for i in extensions.split(',')] # extensions = [i.strip() for i in extensions.split(',')]
key = environ['HTTP_SEC_WEBSOCKET_KEY'] key = environ['HTTP_SEC_WEBSOCKET_KEY']
@@ -245,6 +246,7 @@ class WebSocketWSGI(object):
spaces += 1 spaces += 1
return int(out) / spaces return int(out) / spaces
class WebSocket(object): class WebSocket(object):
"""A websocket object that handles the details of """A websocket object that handles the details of
serialization/deserialization to the socket. serialization/deserialization to the socket.
@@ -264,6 +266,7 @@ class WebSocket(object):
The full WSGI environment for this request. The full WSGI environment for this request.
""" """
def __init__(self, sock, environ, version=76): def __init__(self, sock, environ, version=76):
""" """
:param socket: The eventlet socket :param socket: The eventlet socket
@@ -310,10 +313,10 @@ class WebSocket(object):
if frame_type == 0: if frame_type == 0:
# Normal message. # Normal message.
end_idx = buf.find("\xFF") end_idx = buf.find("\xFF")
if end_idx == -1: #pragma NO COVER if end_idx == -1: # pragma NO COVER
break break
msgs.append(buf[1:end_idx].decode('utf-8', 'replace')) msgs.append(buf[1:end_idx].decode('utf-8', 'replace'))
buf = buf[end_idx+1:] buf = buf[end_idx + 1:]
elif frame_type == 255: elif frame_type == 255:
# Closing handshake. # Closing handshake.
assert ord(buf[1]) == 0, "Unexpected closing handshake: %r" % buf assert ord(buf[1]) == 0, "Unexpected closing handshake: %r" % buf
@@ -367,7 +370,7 @@ class WebSocket(object):
except SocketError: except SocketError:
# Sometimes, like when the remote side cuts off the connection, # Sometimes, like when the remote side cuts off the connection,
# we don't care about this. # we don't care about this.
if not ignore_send_errors: #pragma NO COVER if not ignore_send_errors: # pragma NO COVER
raise raise
self.websocket_closed = True self.websocket_closed = True

View File

@@ -21,7 +21,7 @@ MAX_HEADER_LINE = 8192
MAX_TOTAL_HEADER_SIZE = 65536 MAX_TOTAL_HEADER_SIZE = 65536
MINIMUM_CHUNK_SIZE = 4096 MINIMUM_CHUNK_SIZE = 4096
# %(client_port)s is also available # %(client_port)s is also available
DEFAULT_LOG_FORMAT= ('%(client_ip)s - - [%(date_time)s] "%(request_line)s"' DEFAULT_LOG_FORMAT = ('%(client_ip)s - - [%(date_time)s] "%(request_line)s"'
' %(status_code)s %(body_length)s %(wall_seconds).6f') ' %(status_code)s %(body_length)s %(wall_seconds).6f')
__all__ = ['server', 'format_date_time'] __all__ = ['server', 'format_date_time']
@@ -85,7 +85,7 @@ class Input(object):
def _do_read(self, reader, length=None): def _do_read(self, reader, length=None):
if self.wfile is not None: if self.wfile is not None:
## 100 Continue # 100 Continue
self.wfile.write(self.wfile_line) self.wfile.write(self.wfile_line)
self.wfile = None self.wfile = None
self.wfile_line = None self.wfile_line = None
@@ -105,7 +105,7 @@ class Input(object):
def _chunked_read(self, rfile, length=None, use_readline=False): def _chunked_read(self, rfile, length=None, use_readline=False):
if self.wfile is not None: if self.wfile is not None:
## 100 Continue # 100 Continue
self.wfile.write(self.wfile_line) self.wfile.write(self.wfile_line)
self.wfile = None self.wfile = None
self.wfile_line = None self.wfile_line = None
@@ -221,7 +221,7 @@ class HttpProtocol(BaseHTTPServer.BaseHTTPRequestHandler):
self.wfile = socket._fileobject(conn, "wb", self.wbufsize) self.wfile = socket._fileobject(conn, "wb", self.wbufsize)
else: else:
# it's a SSLObject, or a martian # it's a SSLObject, or a martian
raise NotImplementedError("wsgi.py doesn't support sockets "\ raise NotImplementedError("wsgi.py doesn't support sockets "
"of type %s" % type(conn)) "of type %s" % type(conn))
def handle_one_request(self): def handle_one_request(self):
@@ -325,7 +325,7 @@ class HttpProtocol(BaseHTTPServer.BaseHTTPRequestHandler):
client_conn = self.headers.get('Connection', '').lower() client_conn = self.headers.get('Connection', '').lower()
send_keep_alive = False send_keep_alive = False
if self.close_connection == 0 and \ if self.close_connection == 0 and \
self.server.keepalive and (client_conn == 'keep-alive' or \ self.server.keepalive and (client_conn == 'keep-alive' or
(self.request_version == 'HTTP/1.1' and (self.request_version == 'HTTP/1.1' and
not client_conn == 'close')): not client_conn == 'close')):
# only send keep-alives back to clients that sent them, # only send keep-alives back to clients that sent them,
@@ -351,7 +351,7 @@ class HttpProtocol(BaseHTTPServer.BaseHTTPRequestHandler):
# end of header writing # end of header writing
if use_chunked[0]: if use_chunked[0]:
## Write the chunked encoding # Write the chunked encoding
towrite.append("%x\r\n%s\r\n" % (len(data), data)) towrite.append("%x\r\n%s\r\n" % (len(data), data))
else: else:
towrite.append(data) towrite.append(data)
@@ -359,7 +359,8 @@ class HttpProtocol(BaseHTTPServer.BaseHTTPRequestHandler):
_writelines(towrite) _writelines(towrite)
length[0] = length[0] + sum(map(len, towrite)) length[0] = length[0] + sum(map(len, towrite))
except UnicodeEncodeError: except UnicodeEncodeError:
self.server.log_message("Encountered non-ascii unicode while attempting to write wsgi response: %r" % [x for x in towrite if isinstance(x, six.text_type)]) self.server.log_message(
"Encountered non-ascii unicode while attempting to write wsgi response: %r" % [x for x in towrite if isinstance(x, six.text_type)])
self.server.log_message(traceback.format_exc()) self.server.log_message(traceback.format_exc())
_writelines( _writelines(
["HTTP/1.1 500 Internal Server Error\r\n", ["HTTP/1.1 500 Internal Server Error\r\n",
@@ -441,9 +442,9 @@ class HttpProtocol(BaseHTTPServer.BaseHTTPRequestHandler):
if hasattr(result, 'close'): if hasattr(result, 'close'):
result.close() result.close()
if (self.environ['eventlet.input'].chunked_input or if (self.environ['eventlet.input'].chunked_input or
self.environ['eventlet.input'].position \ self.environ['eventlet.input'].position
< self.environ['eventlet.input'].content_length): < self.environ['eventlet.input'].content_length):
## Read and discard body if there was no pending 100-continue # Read and discard body if there was no pending 100-continue
if not self.environ['eventlet.input'].wfile: if not self.environ['eventlet.input'].wfile:
# NOTE: MINIMUM_CHUNK_SIZE is used here for purpose different than chunking. # NOTE: MINIMUM_CHUNK_SIZE is used here for purpose different than chunking.
# We use it only cause it's at hand and has reasonable value in terms of # We use it only cause it's at hand and has reasonable value in terms of
@@ -740,8 +741,8 @@ def server(sock, site,
try: try:
pool.spawn_n(serv.process_request, client_socket) pool.spawn_n(serv.process_request, client_socket)
except AttributeError: except AttributeError:
warnings.warn("wsgi's pool should be an instance of " \ warnings.warn("wsgi's pool should be an instance of "
"eventlet.greenpool.GreenPool, is %s. Please convert your"\ "eventlet.greenpool.GreenPool, is %s. Please convert your"
" call site to use GreenPool instead" % type(pool), " call site to use GreenPool instead" % type(pool),
DeprecationWarning, stacklevel=2) DeprecationWarning, stacklevel=2)
pool.execute_async(serv.process_request, client_socket) pool.execute_async(serv.process_request, client_socket)

View File

@@ -81,7 +81,6 @@ class TestWebSocket(_TestBase):
self.assertEqual(resp.getheader('connection'), 'close') self.assertEqual(resp.getheader('connection'), 'close')
self.assertEqual(resp.read(), '') self.assertEqual(resp.read(), '')
def test_correct_upgrade_request_13(self): def test_correct_upgrade_request_13(self):
for http_connection in ['Upgrade', 'UpGrAdE', 'keep-alive, Upgrade']: for http_connection in ['Upgrade', 'UpGrAdE', 'keep-alive, Upgrade']:
connect = [ connect = [
@@ -97,7 +96,7 @@ class TestWebSocket(_TestBase):
sock.sendall('\r\n'.join(connect) + '\r\n\r\n') sock.sendall('\r\n'.join(connect) + '\r\n\r\n')
result = sock.recv(1024) result = sock.recv(1024)
## The server responds the correct Websocket handshake # The server responds the correct Websocket handshake
print('Connection string: %r' % http_connection) print('Connection string: %r' % http_connection)
self.assertEqual(result, '\r\n'.join([ self.assertEqual(result, '\r\n'.join([
'HTTP/1.1 101 Switching Protocols', 'HTTP/1.1 101 Switching Protocols',
@@ -134,6 +133,7 @@ class TestWebSocket(_TestBase):
error_detected = [False] error_detected = [False]
done_with_request = event.Event() done_with_request = event.Event()
site = self.site site = self.site
def error_detector(environ, start_response): def error_detector(environ, start_response):
try: try:
try: try:
@@ -165,6 +165,7 @@ class TestWebSocket(_TestBase):
error_detected = [False] error_detected = [False]
done_with_request = event.Event() done_with_request = event.Event()
site = self.site site = self.site
def error_detector(environ, start_response): def error_detector(environ, start_response):
try: try:
try: try:
@@ -197,6 +198,7 @@ class TestWebSocket(_TestBase):
error_detected = [False] error_detected = [False]
done_with_request = event.Event() done_with_request = event.Event()
site = self.site site = self.site
def error_detector(environ, start_response): def error_detector(environ, start_response):
try: try:
try: try:

View File

@@ -1,13 +1,13 @@
import socket
import errno import errno
import socket
import eventlet import eventlet
from eventlet.green import urllib2
from eventlet.green import httplib
from eventlet.websocket import WebSocket, WebSocketWSGI
from eventlet import wsgi
from eventlet import event from eventlet import event
from eventlet import greenio from eventlet import greenio
from eventlet import wsgi
from eventlet.green import httplib
from eventlet.green import urllib2
from eventlet.websocket import WebSocket, WebSocketWSGI
from tests import mock, LimitedTestCase, certificate_file, private_key_file from tests import mock, LimitedTestCase, certificate_file, private_key_file
from tests import skip_if_no_ssl from tests import skip_if_no_ssl
@@ -34,6 +34,7 @@ def handle(ws):
wsapp = WebSocketWSGI(handle) wsapp = WebSocketWSGI(handle)
class TestWebSocket(_TestBase): class TestWebSocket(_TestBase):
TEST_TIMEOUT = 5 TEST_TIMEOUT = 5
@@ -114,7 +115,7 @@ class TestWebSocket(_TestBase):
sock.sendall('\r\n'.join(connect) + '\r\n\r\n') sock.sendall('\r\n'.join(connect) + '\r\n\r\n')
result = sock.recv(1024) result = sock.recv(1024)
## The server responds the correct Websocket handshake # The server responds the correct Websocket handshake
self.assertEqual(result, self.assertEqual(result,
'\r\n'.join(['HTTP/1.1 101 Web Socket Protocol Handshake', '\r\n'.join(['HTTP/1.1 101 Web Socket Protocol Handshake',
'Upgrade: WebSocket', 'Upgrade: WebSocket',
@@ -138,7 +139,7 @@ class TestWebSocket(_TestBase):
sock.sendall('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U') sock.sendall('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U')
result = sock.recv(1024) result = sock.recv(1024)
## The server responds the correct Websocket handshake # The server responds the correct Websocket handshake
self.assertEqual(result, self.assertEqual(result,
'\r\n'.join(['HTTP/1.1 101 WebSocket Protocol Handshake', '\r\n'.join(['HTTP/1.1 101 WebSocket Protocol Handshake',
'Upgrade: WebSocket', 'Upgrade: WebSocket',
@@ -147,7 +148,6 @@ class TestWebSocket(_TestBase):
'Sec-WebSocket-Protocol: ws', 'Sec-WebSocket-Protocol: ws',
'Sec-WebSocket-Location: ws://localhost:%s/echo\r\n\r\n8jKS\'y:G*Co,Wxa-' % self.port])) 'Sec-WebSocket-Location: ws://localhost:%s/echo\r\n\r\n8jKS\'y:G*Co,Wxa-' % self.port]))
def test_query_string(self): def test_query_string(self):
# verify that the query string comes out the other side unscathed # verify that the query string comes out the other side unscathed
connect = [ connect = [
@@ -190,14 +190,14 @@ class TestWebSocket(_TestBase):
sock.sendall('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U') sock.sendall('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U')
result = sock.recv(1024) result = sock.recv(1024)
self.assertEqual(result, self.assertEqual(result, '\r\n'.join([
'\r\n'.join(['HTTP/1.1 101 WebSocket Protocol Handshake', 'HTTP/1.1 101 WebSocket Protocol Handshake',
'Upgrade: WebSocket', 'Upgrade: WebSocket',
'Connection: Upgrade', 'Connection: Upgrade',
'Sec-WebSocket-Origin: http://localhost:%s' % self.port, 'Sec-WebSocket-Origin: http://localhost:%s' % self.port,
'Sec-WebSocket-Protocol: ws', 'Sec-WebSocket-Protocol: ws',
'Sec-WebSocket-Location: ws://localhost:%s/echo?\r\n\r\n8jKS\'y:G*Co,Wxa-' % self.port])) 'Sec-WebSocket-Location: ws://localhost:%s/echo?\r\n\r\n8jKS\'y:G*Co,Wxa-' % self.port,
]))
def test_sending_messages_to_websocket_75(self): def test_sending_messages_to_websocket_75(self):
connect = [ connect = [
@@ -305,6 +305,7 @@ class TestWebSocket(_TestBase):
error_detected = [False] error_detected = [False]
done_with_request = event.Event() done_with_request = event.Event()
site = self.site site = self.site
def error_detector(environ, start_response): def error_detector(environ, start_response):
try: try:
try: try:
@@ -336,6 +337,7 @@ class TestWebSocket(_TestBase):
error_detected = [False] error_detected = [False]
done_with_request = event.Event() done_with_request = event.Event()
site = self.site site = self.site
def error_detector(environ, start_response): def error_detector(environ, start_response):
try: try:
try: try:
@@ -369,6 +371,7 @@ class TestWebSocket(_TestBase):
error_detected = [False] error_detected = [False]
done_with_request = event.Event() done_with_request = event.Event()
site = self.site site = self.site
def error_detector(environ, start_response): def error_detector(environ, start_response):
try: try:
try: try:
@@ -402,6 +405,7 @@ class TestWebSocket(_TestBase):
error_detected = [False] error_detected = [False]
done_with_request = event.Event() done_with_request = event.Event()
site = self.site site = self.site
def error_detector(environ, start_response): def error_detector(environ, start_response):
try: try:
try: try:
@@ -455,6 +459,7 @@ class TestWebSocket(_TestBase):
error_detected = [False] error_detected = [False]
done_with_request = event.Event() done_with_request = event.Event()
site = self.site site = self.site
def error_detector(environ, start_response): def error_detector(environ, start_response):
try: try:
try: try:
@@ -485,6 +490,7 @@ class TestWebSocket(_TestBase):
error_detected = [False] error_detected = [False]
done_with_request = event.Event() done_with_request = event.Event()
site = self.site site = self.site
def error_detector(environ, start_response): def error_detector(environ, start_response):
try: try:
try: try:
@@ -558,7 +564,6 @@ class TestWebSocketSSL(_TestBase):
eventlet.sleep(0.01) eventlet.sleep(0.01)
class TestWebSocketObject(LimitedTestCase): class TestWebSocketObject(LimitedTestCase):
def setUp(self): def setUp(self):
@@ -580,7 +585,6 @@ class TestWebSocketObject(LimitedTestCase):
self.assertEqual(ws._buf, '') self.assertEqual(ws._buf, '')
self.assertEqual(len(ws._msgs), 0) self.assertEqual(len(ws._msgs), 0)
def test_send_to_ws(self): def test_send_to_ws(self):
ws = self.test_ws ws = self.test_ws
ws.send(u'hello') ws.send(u'hello')

View File

@@ -262,9 +262,9 @@ class TestHttpd(_TestBase):
fd.flush() fd.flush()
result = fd.read() result = fd.read()
fd.close() fd.close()
## The server responds with the maximum version it supports # The server responds with the maximum version it supports
assert result.startswith('HTTP'), result assert result.startswith('HTTP'), result
assert result.endswith('hello world') assert result.endswith('hello world'), result
def test_002_keepalive(self): def test_002_keepalive(self):
sock = eventlet.connect( sock = eventlet.connect(
@@ -445,7 +445,8 @@ class TestHttpd(_TestBase):
sock = eventlet.connect(('localhost', self.port)) sock = eventlet.connect(('localhost', self.port))
sock = eventlet.wrap_ssl(sock) sock = eventlet.wrap_ssl(sock)
sock.write(b'POST /foo HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nContent-length:3\r\n\r\nabc') sock.write(
b'POST /foo HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nContent-length:3\r\n\r\nabc')
result = sock.read(8192) result = sock.read(8192)
self.assertEqual(result[-3:], 'abc') self.assertEqual(result[-3:], 'abc')
@@ -750,7 +751,8 @@ class TestHttpd(_TestBase):
result = read_http(sock) result = read_http(sock)
self.assertEqual(result.status, 'HTTP/1.1 417 Expectation Failed') self.assertEqual(result.status, 'HTTP/1.1 417 Expectation Failed')
self.assertEqual(result.body, 'failure') self.assertEqual(result.body, 'failure')
fd.write(b'PUT / HTTP/1.1\r\nHost: localhost\r\nContent-length: 7\r\nExpect: 100-continue\r\n\r\ntesting') fd.write(
b'PUT / HTTP/1.1\r\nHost: localhost\r\nContent-length: 7\r\nExpect: 100-continue\r\n\r\ntesting')
fd.flush() fd.flush()
header_lines = [] header_lines = []
while True: while True:
@@ -1448,7 +1450,8 @@ class TestChunkedInput(_TestBase):
def test_chunked_readline(self): def test_chunked_readline(self):
body = self.body() body = self.body()
req = "POST /lines HTTP/1.1\r\nContent-Length: %s\r\ntransfer-encoding: Chunked\r\n\r\n%s" % (len(body), body) req = "POST /lines HTTP/1.1\r\nContent-Length: %s\r\ntransfer-encoding: Chunked\r\n\r\n%s" % (
len(body), body)
fd = self.connect() fd = self.connect()
fd.sendall(req.encode()) fd.sendall(req.encode())