Python 3 compat; Improve WSGI, WS, threading and tests

This includes:

* patching more tests to pass
* removing few unit tests which I think are redundant
* repeating SSL socket reads in a loop to read all data (I suspect this is
  related to the fact that writelines is used in the server code there and
  Python 3 writelines calls write/send repeatedly while on Python 2 it
  calls it once; on one hand there's no guarantee that single recv/read
  will return all data sent by the server, on the other hand it's quite
  suspicious that the number of required reads seems to be connected to
  the number of sends on the other side of the connection)
* working through Python 2/Python 3 threading and thread differences; the
  lock code I used is the simplest way I could make the tests pass but
  will likely need to be modified in order to match the original

This commit includes 6bcb1dc and closes GH #153
This commit is contained in:
Jakub Stasiak
2014-11-11 22:38:30 +00:00
parent 3870f6b055
commit 23beb7d43e
7 changed files with 129 additions and 137 deletions

View File

@@ -1,6 +1,6 @@
"""Implements the standard thread module, using greenthreads.""" """Implements the standard thread module, using greenthreads."""
from eventlet.support.six.moves import _thread as __thread from eventlet.support.six.moves import _thread as __thread
from eventlet.support import greenlets as greenlet from eventlet.support import greenlets as greenlet, six
from eventlet import greenthread from eventlet import greenthread
from eventlet.semaphore import Semaphore as LockType from eventlet.semaphore import Semaphore as LockType
@@ -13,6 +13,15 @@ error = __thread.error
__threadcount = 0 __threadcount = 0
if six.PY3:
def _set_sentinel():
# TODO this is a dummy code, reimplementing this may be needed:
# https://hg.python.org/cpython/file/b5e9bc4352e1/Modules/_threadmodule.c#l1203
return allocate_lock()
TIMEOUT_MAX = __thread.TIMEOUT_MAX
def _count(): def _count():
return __threadcount return __threadcount

View File

@@ -2,12 +2,17 @@
from eventlet import patcher from eventlet import patcher
from eventlet.green import thread from eventlet.green import thread
from eventlet.green import time from eventlet.green import time
from eventlet.support import greenlets as greenlet from eventlet.support import greenlets as greenlet, six
__patched__ = ['_start_new_thread', '_allocate_lock', '_get_ident', '_sleep', __patched__ = ['_start_new_thread', '_allocate_lock',
'local', 'stack_size', 'Lock', 'currentThread', '_sleep', 'local', 'stack_size', 'Lock', 'currentThread',
'current_thread', '_after_fork', '_shutdown'] 'current_thread', '_after_fork', '_shutdown']
if six.PY2:
__patched__ += ['_get_ident']
else:
__patched__ += ['get_ident', '_set_sentinel']
__orig_threading = patcher.original('threading') __orig_threading = patcher.original('threading')
__threadlocal = __orig_threading.local() __threadlocal = __orig_threading.local()
@@ -15,7 +20,7 @@ __threadlocal = __orig_threading.local()
patcher.inject( patcher.inject(
'threading', 'threading',
globals(), globals(),
('thread', thread), ('thread' if six.PY2 else '_thread', thread),
('time', time)) ('time', time))
del patcher del patcher

View File

@@ -168,25 +168,24 @@ class WebSocketWSGI(object):
if qs is not None: if qs is not None:
location += '?' + qs location += '?' + qs
if self.protocol_version == 75: if self.protocol_version == 75:
handshake_reply = ("HTTP/1.1 101 Web Socket Protocol Handshake\r\n" handshake_reply = (
"Upgrade: WebSocket\r\n" b"HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
"Connection: Upgrade\r\n" b"Upgrade: WebSocket\r\n"
"WebSocket-Origin: %s\r\n" b"Connection: Upgrade\r\n"
"WebSocket-Location: %s\r\n\r\n" % ( b"WebSocket-Origin: " + environ.get('HTTP_ORIGIN') + b"\r\n"
environ.get('HTTP_ORIGIN'), b"WebSocket-Location: " + six.b(location) + b"\r\n\r\n"
location)) )
elif self.protocol_version == 76: elif self.protocol_version == 76:
handshake_reply = ("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" handshake_reply = (
"Upgrade: WebSocket\r\n" b"HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
"Connection: Upgrade\r\n" b"Upgrade: WebSocket\r\n"
"Sec-WebSocket-Origin: %s\r\n" b"Connection: Upgrade\r\n"
"Sec-WebSocket-Protocol: %s\r\n" b"Sec-WebSocket-Origin: " + six.b(environ.get('HTTP_ORIGIN')) + b"\r\n"
"Sec-WebSocket-Location: %s\r\n" b"Sec-WebSocket-Protocol: " +
"\r\n%s" % ( six.b(environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL', 'default')) + b"\r\n"
environ.get('HTTP_ORIGIN'), b"Sec-WebSocket-Location: " + six.b(location) + b"\r\n"
environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL', 'default'), b"\r\n" + response
location, )
response))
else: # pragma NO COVER else: # pragma NO COVER
raise ValueError("Unknown WebSocket protocol version.") raise ValueError("Unknown WebSocket protocol version.")
sock.sendall(handshake_reply) sock.sendall(handshake_reply)
@@ -244,7 +243,7 @@ class WebSocketWSGI(object):
out += char out += char
elif char == " ": elif char == " ":
spaces += 1 spaces += 1
return int(out) / spaces return int(out) // spaces
class WebSocket(object): class WebSocket(object):
@@ -281,7 +280,7 @@ class WebSocket(object):
self.environ = environ self.environ = environ
self.version = version self.version = version
self.websocket_closed = False self.websocket_closed = False
self._buf = "" self._buf = b""
self._msgs = collections.deque() self._msgs = collections.deque()
self._sendlock = semaphore.Semaphore() self._sendlock = semaphore.Semaphore()
@@ -294,8 +293,8 @@ class WebSocket(object):
if isinstance(message, six.text_type): if isinstance(message, six.text_type):
message = message.encode('utf-8') message = message.encode('utf-8')
elif not isinstance(message, six.binary_type): elif not isinstance(message, six.binary_type):
message = b'%s' % (message,) message = six.b(str(message))
packed = b"\x00%s\xFF" % message packed = b"\x00" + message + b"\xFF"
return packed return packed
def _parse_messages(self): def _parse_messages(self):
@@ -309,17 +308,17 @@ class WebSocket(object):
end_idx = 0 end_idx = 0
buf = self._buf buf = self._buf
while buf: while buf:
frame_type = ord(buf[0]) frame_type = six.indexbytes(buf, 0)
if frame_type == 0: if frame_type == 0:
# Normal message. # Normal message.
end_idx = buf.find("\xFF") end_idx = buf.find(b"\xFF")
if end_idx == -1: # pragma NO COVER if end_idx == -1: # pragma NO COVER
break break
msgs.append(buf[1:end_idx].decode('utf-8', 'replace')) msgs.append(buf[1:end_idx].decode('utf-8', 'replace'))
buf = buf[end_idx + 1:] buf = buf[end_idx + 1:]
elif frame_type == 255: elif frame_type == 255:
# Closing handshake. # Closing handshake.
assert ord(buf[1]) == 0, "Unexpected closing handshake: %r" % buf assert six.indexbytes(buf, 1) == 0, "Unexpected closing handshake: %r" % buf
self.websocket_closed = True self.websocket_closed = True
break break
else: else:
@@ -355,7 +354,7 @@ class WebSocket(object):
return None return None
# no parsed messages, must mean buf needs more data # no parsed messages, must mean buf needs more data
delta = self.socket.recv(8096) delta = self.socket.recv(8096)
if delta == '': if delta == b'':
return None return None
self._buf += delta self._buf += delta
msgs = self._parse_messages() msgs = self._parse_messages()

View File

@@ -69,11 +69,13 @@ class Input(object):
def __init__(self, def __init__(self,
rfile, rfile,
content_length, content_length,
sock,
wfile=None, wfile=None,
wfile_line=None, wfile_line=None,
chunked_input=False): chunked_input=False):
self.rfile = rfile self.rfile = rfile
self._sock = sock
if content_length is not None: if content_length is not None:
content_length = int(content_length) content_length = int(content_length)
self.content_length = content_length self.content_length = content_length
@@ -199,7 +201,7 @@ class Input(object):
return iter(self.read, b'') return iter(self.read, b'')
def get_socket(self): def get_socket(self):
return self.rfile._sock return self._sock
def set_hundred_continue_response_headers(self, headers, def set_hundred_continue_response_headers(self, headers,
capitalize_response_headers=True): capitalize_response_headers=True):
@@ -570,7 +572,7 @@ class HttpProtocol(BaseHTTPServer.BaseHTTPRequestHandler):
wfile_line = None wfile_line = None
chunked = env.get('HTTP_TRANSFER_ENCODING', '').lower() == 'chunked' chunked = env.get('HTTP_TRANSFER_ENCODING', '').lower() == 'chunked'
env['wsgi.input'] = env['eventlet.input'] = Input( env['wsgi.input'] = env['eventlet.input'] = Input(
self.rfile, length, wfile=wfile, wfile_line=wfile_line, self.rfile, length, self.connection, wfile=wfile, wfile_line=wfile_line,
chunked_input=chunked) chunked_input=chunked)
env['eventlet.posthooks'] = [] env['eventlet.posthooks'] = []

View File

@@ -273,7 +273,7 @@ class TestGreenSocket(LimitedTestCase):
# by closing the socket prior to using the made file # by closing the socket prior to using the made file
try: try:
conn, addr = listener.accept() conn, addr = listener.accept()
fd = conn.makefile('w') fd = conn.makefile('wb')
conn.close() conn.close()
fd.write(b'hello\n') fd.write(b'hello\n')
fd.close() fd.close()
@@ -287,7 +287,7 @@ class TestGreenSocket(LimitedTestCase):
# by closing the made file and then sending a character # by closing the made file and then sending a character
try: try:
conn, addr = listener.accept() conn, addr = listener.accept()
fd = conn.makefile('w') fd = conn.makefile('wb')
fd.write(b'hello') fd.write(b'hello')
fd.close() fd.close()
conn.send(b'\n') conn.send(b'\n')
@@ -300,7 +300,7 @@ class TestGreenSocket(LimitedTestCase):
def did_it_work(server): def did_it_work(server):
client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client.connect(('127.0.0.1', server.getsockname()[1])) client.connect(('127.0.0.1', server.getsockname()[1]))
fd = client.makefile() fd = client.makefile('rb')
client.close() client.close()
assert fd.readline() == b'hello\n' assert fd.readline() == b'hello\n'
assert fd.read() == b'' assert fd.read() == b''
@@ -329,7 +329,7 @@ class TestGreenSocket(LimitedTestCase):
# closing the file object should close everything # closing the file object should close everything
try: try:
conn, addr = listener.accept() conn, addr = listener.accept()
conn = conn.makefile('w') conn = conn.makefile('wb')
conn.write(b'hello\n') conn.write(b'hello\n')
conn.close() conn.close()
gc.collect() gc.collect()
@@ -344,7 +344,7 @@ class TestGreenSocket(LimitedTestCase):
killer = eventlet.spawn(accept_once, server) killer = eventlet.spawn(accept_once, server)
client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client.connect(('127.0.0.1', server.getsockname()[1])) client.connect(('127.0.0.1', server.getsockname()[1]))
fd = client.makefile() fd = client.makefile('rb')
client.close() client.close()
assert fd.read() == b'hello\n' assert fd.read() == b'hello\n'
assert fd.read() == b'' assert fd.read() == b''
@@ -603,7 +603,7 @@ class TestGreenSocket(LimitedTestCase):
def test_sockopt_interface(self): def test_sockopt_interface(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
assert sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 0 assert sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 0
assert sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) == '\000' assert sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) == b'\000'
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
def test_socketpair_select(self): def test_socketpair_select(self):

View File

@@ -4,12 +4,11 @@ import socket
import eventlet import eventlet
from eventlet import event from eventlet import event
from eventlet import greenio from eventlet import greenio
from eventlet import wsgi
from eventlet.green import httplib from eventlet.green import httplib
from eventlet.green import urllib2 from eventlet.support import six
from eventlet.websocket import WebSocket, WebSocketWSGI from eventlet.websocket import WebSocketWSGI
from tests import mock, LimitedTestCase, certificate_file, private_key_file from tests import certificate_file, private_key_file
from tests import skip_if_no_ssl from tests import skip_if_no_ssl
from tests.wsgi_test import _TestBase from tests.wsgi_test import _TestBase
@@ -42,13 +41,10 @@ class TestWebSocket(_TestBase):
self.site = wsapp self.site = wsapp
def test_incorrect_headers(self): def test_incorrect_headers(self):
def raiser(): http = httplib.HTTPConnection('localhost', self.port)
try: http.request("GET", "/echo")
urllib2.urlopen("http://localhost:%s/echo" % self.port) response = http.getresponse()
except urllib2.HTTPError as e: assert response.status == 400
self.assertEqual(e.code, 400)
raise
self.assertRaises(urllib2.HTTPError, raiser)
def test_incomplete_headers_75(self): def test_incomplete_headers_75(self):
headers = dict(kv.split(': ') for kv in [ headers = dict(kv.split(': ') for kv in [
@@ -113,7 +109,7 @@ class TestWebSocket(_TestBase):
sock = eventlet.connect( sock = eventlet.connect(
('localhost', self.port)) ('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n') sock.sendall(six.b('\r\n'.join(connect) + '\r\n\r\n'))
result = sock.recv(1024) result = sock.recv(1024)
# The server responds the correct Websocket handshake # The server responds the correct Websocket handshake
self.assertEqual(result, '\r\n'.join([ self.assertEqual(result, '\r\n'.join([
@@ -138,7 +134,7 @@ class TestWebSocket(_TestBase):
sock = eventlet.connect( sock = eventlet.connect(
('localhost', self.port)) ('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U') sock.sendall(six.b('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U'))
result = sock.recv(1024) result = sock.recv(1024)
# The server responds the correct Websocket handshake # The server responds the correct Websocket handshake
self.assertEqual(result, '\r\n'.join([ self.assertEqual(result, '\r\n'.join([
@@ -165,7 +161,7 @@ class TestWebSocket(_TestBase):
sock = eventlet.connect( sock = eventlet.connect(
('localhost', self.port)) ('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U') sock.sendall(six.b('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U'))
result = sock.recv(1024) result = sock.recv(1024)
self.assertEqual(result, '\r\n'.join([ self.assertEqual(result, '\r\n'.join([
'HTTP/1.1 101 WebSocket Protocol Handshake', 'HTTP/1.1 101 WebSocket Protocol Handshake',
@@ -191,7 +187,7 @@ class TestWebSocket(_TestBase):
sock = eventlet.connect( sock = eventlet.connect(
('localhost', self.port)) ('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U') sock.sendall(six.b('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U'))
result = sock.recv(1024) result = sock.recv(1024)
self.assertEqual(result, '\r\n'.join([ self.assertEqual(result, '\r\n'.join([
'HTTP/1.1 101 WebSocket Protocol Handshake', 'HTTP/1.1 101 WebSocket Protocol Handshake',
@@ -214,16 +210,16 @@ class TestWebSocket(_TestBase):
sock = eventlet.connect( sock = eventlet.connect(
('localhost', self.port)) ('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n') sock.sendall(six.b('\r\n'.join(connect) + '\r\n\r\n'))
sock.recv(1024) sock.recv(1024)
sock.sendall('\x00hello\xFF') sock.sendall(b'\x00hello\xFF')
result = sock.recv(1024) result = sock.recv(1024)
self.assertEqual(result, '\x00hello\xff') self.assertEqual(result, '\x00hello\xff')
sock.sendall('\x00start') sock.sendall(b'\x00start')
eventlet.sleep(0.001) eventlet.sleep(0.001)
sock.sendall(' end\xff') sock.sendall(b' end\xff')
result = sock.recv(1024) result = sock.recv(1024)
self.assertEqual(result, '\x00start end\xff') self.assertEqual(result, b'\x00start end\xff')
sock.shutdown(socket.SHUT_RDWR) sock.shutdown(socket.SHUT_RDWR)
sock.close() sock.close()
eventlet.sleep(0.01) eventlet.sleep(0.01)
@@ -242,16 +238,16 @@ class TestWebSocket(_TestBase):
sock = eventlet.connect( sock = eventlet.connect(
('localhost', self.port)) ('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U') sock.sendall(six.b('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U'))
sock.recv(1024) sock.recv(1024)
sock.sendall('\x00hello\xFF') sock.sendall(b'\x00hello\xFF')
result = sock.recv(1024) result = sock.recv(1024)
self.assertEqual(result, '\x00hello\xff') self.assertEqual(result, b'\x00hello\xff')
sock.sendall('\x00start') sock.sendall(b'\x00start')
eventlet.sleep(0.001) eventlet.sleep(0.001)
sock.sendall(' end\xff') sock.sendall(b' end\xff')
result = sock.recv(1024) result = sock.recv(1024)
self.assertEqual(result, '\x00start end\xff') self.assertEqual(result, b'\x00start end\xff')
sock.shutdown(socket.SHUT_RDWR) sock.shutdown(socket.SHUT_RDWR)
sock.close() sock.close()
eventlet.sleep(0.01) eventlet.sleep(0.01)
@@ -268,16 +264,16 @@ class TestWebSocket(_TestBase):
sock = eventlet.connect( sock = eventlet.connect(
('localhost', self.port)) ('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n') sock.sendall(six.b('\r\n'.join(connect) + '\r\n\r\n'))
resp = sock.recv(1024) resp = sock.recv(1024)
headers, result = resp.split('\r\n\r\n') headers, result = resp.split(b'\r\n\r\n')
msgs = [result.strip('\x00\xff')] msgs = [result.strip(b'\x00\xff')]
cnt = 10 cnt = 10
while cnt: while cnt:
msgs.append(sock.recv(20).strip('\x00\xff')) msgs.append(sock.recv(20).strip(b'\x00\xff'))
cnt -= 1 cnt -= 1
# Last item in msgs is an empty string # Last item in msgs is an empty string
self.assertEqual(msgs[:-1], ['msg %d' % i for i in range(10)]) self.assertEqual(msgs[:-1], [six.b('msg %d' % i) for i in range(10)])
def test_getting_messages_from_websocket_76(self): def test_getting_messages_from_websocket_76(self):
connect = [ connect = [
@@ -293,16 +289,16 @@ class TestWebSocket(_TestBase):
sock = eventlet.connect( sock = eventlet.connect(
('localhost', self.port)) ('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U') sock.sendall(six.b('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U'))
resp = sock.recv(1024) resp = sock.recv(1024)
headers, result = resp.split('\r\n\r\n') headers, result = resp.split(b'\r\n\r\n')
msgs = [result[16:].strip('\x00\xff')] msgs = [result[16:].strip(b'\x00\xff')]
cnt = 10 cnt = 10
while cnt: while cnt:
msgs.append(sock.recv(20).strip('\x00\xff')) msgs.append(sock.recv(20).strip(b'\x00\xff'))
cnt -= 1 cnt -= 1
# Last item in msgs is an empty string # Last item in msgs is an empty string
self.assertEqual(msgs[:-1], ['msg %d' % i for i in range(10)]) self.assertEqual(msgs[:-1], [six.b('msg %d' % i) for i in range(10)])
def test_breaking_the_connection_75(self): def test_breaking_the_connection_75(self):
error_detected = [False] error_detected = [False]
@@ -330,7 +326,7 @@ class TestWebSocket(_TestBase):
] ]
sock = eventlet.connect( sock = eventlet.connect(
('localhost', self.port)) ('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n') sock.sendall(six.b('\r\n'.join(connect) + '\r\n\r\n'))
sock.recv(1024) # get the headers sock.recv(1024) # get the headers
sock.close() # close while the app is running sock.close() # close while the app is running
done_with_request.wait() done_with_request.wait()
@@ -364,7 +360,7 @@ class TestWebSocket(_TestBase):
] ]
sock = eventlet.connect( sock = eventlet.connect(
('localhost', self.port)) ('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U') sock.sendall(six.b('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U'))
sock.recv(1024) # get the headers sock.recv(1024) # get the headers
sock.close() # close while the app is running sock.close() # close while the app is running
done_with_request.wait() done_with_request.wait()
@@ -398,9 +394,9 @@ class TestWebSocket(_TestBase):
] ]
sock = eventlet.connect( sock = eventlet.connect(
('localhost', self.port)) ('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U') sock.sendall(six.b('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U'))
sock.recv(1024) # get the headers sock.recv(1024) # get the headers
sock.sendall('\xff\x00') # "Close the connection" packet. sock.sendall(b'\xff\x00') # "Close the connection" packet.
done_with_request.wait() done_with_request.wait()
assert not error_detected[0] assert not error_detected[0]
@@ -432,9 +428,9 @@ class TestWebSocket(_TestBase):
] ]
sock = eventlet.connect( sock = eventlet.connect(
('localhost', self.port)) ('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U') sock.sendall(six.b('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U'))
sock.recv(1024) # get the headers sock.recv(1024) # get the headers
sock.sendall('\xef\x00') # Weird packet. sock.sendall(b'\xef\x00') # Weird packet.
done_with_request.wait() done_with_request.wait()
assert error_detected[0] assert error_detected[0]
@@ -452,11 +448,11 @@ class TestWebSocket(_TestBase):
sock = eventlet.connect( sock = eventlet.connect(
('localhost', self.port)) ('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U') sock.sendall(six.b('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U'))
resp = sock.recv(1024) resp = sock.recv(1024)
headers, result = resp.split('\r\n\r\n') headers, result = resp.split(b'\r\n\r\n')
# The remote server should have immediately closed the connection. # The remote server should have immediately closed the connection.
self.assertEqual(result[16:], '\xff\x00') self.assertEqual(result[16:], b'\xff\x00')
def test_app_socket_errors_75(self): def test_app_socket_errors_75(self):
error_detected = [False] error_detected = [False]
@@ -484,7 +480,7 @@ class TestWebSocket(_TestBase):
] ]
sock = eventlet.connect( sock = eventlet.connect(
('localhost', self.port)) ('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n') sock.sendall(six.b('\r\n'.join(connect) + '\r\n\r\n'))
sock.recv(1024) sock.recv(1024)
done_with_request.wait() done_with_request.wait()
assert error_detected[0] assert error_detected[0]
@@ -517,7 +513,7 @@ class TestWebSocket(_TestBase):
] ]
sock = eventlet.connect( sock = eventlet.connect(
('localhost', self.port)) ('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U') sock.sendall(six.b('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U'))
sock.recv(1024) sock.recv(1024)
done_with_request.wait() done_with_request.wait()
assert error_detected[0] assert error_detected[0]
@@ -547,55 +543,25 @@ class TestWebSocketSSL(_TestBase):
sock = eventlet.wrap_ssl(eventlet.connect( sock = eventlet.wrap_ssl(eventlet.connect(
('localhost', self.port))) ('localhost', self.port)))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U') sock.sendall(six.b('\r\n'.join(connect) + '\r\n\r\n^n:ds[4U'))
first_resp = sock.recv(1024) first_resp = b''
while b'\r\n\r\n' not in first_resp:
first_resp += sock.recv()
print('resp now:')
print(first_resp)
# make sure it sets the wss: protocol on the location header # make sure it sets the wss: protocol on the location header
loc_line = [x for x in first_resp.split("\r\n") loc_line = [x for x in first_resp.split(b"\r\n")
if x.lower().startswith('sec-websocket-location')][0] if x.lower().startswith(b'sec-websocket-location')][0]
self.assert_("wss://localhost" in loc_line, self.assert_(b"wss://localhost" in loc_line,
"Expecting wss protocol in location: %s" % loc_line) "Expecting wss protocol in location: %s" % loc_line)
sock.sendall('\x00hello\xFF') sock.sendall(b'\x00hello\xFF')
result = sock.recv(1024) result = sock.recv(1024)
self.assertEqual(result, '\x00hello\xff') self.assertEqual(result, b'\x00hello\xff')
sock.sendall('\x00start') sock.sendall(b'\x00start')
eventlet.sleep(0.001) eventlet.sleep(0.001)
sock.sendall(' end\xff') sock.sendall(b' end\xff')
result = sock.recv(1024) result = sock.recv(1024)
self.assertEqual(result, '\x00start end\xff') self.assertEqual(result, b'\x00start end\xff')
greenio.shutdown_safe(sock) greenio.shutdown_safe(sock)
sock.close() sock.close()
eventlet.sleep(0.01) eventlet.sleep(0.01)
class TestWebSocketObject(LimitedTestCase):
def setUp(self):
self.mock_socket = s = mock.Mock()
self.environ = env = dict(HTTP_ORIGIN='http://localhost', HTTP_WEBSOCKET_PROTOCOL='ws',
PATH_INFO='test')
self.test_ws = WebSocket(s, env)
super(TestWebSocketObject, self).setUp()
def test_recieve(self):
ws = self.test_ws
ws.socket.recv.return_value = '\x00hello\xFF'
self.assertEqual(ws.wait(), 'hello')
self.assertEqual(ws._buf, '')
self.assertEqual(len(ws._msgs), 0)
ws.socket.recv.return_value = ''
self.assertEqual(ws.wait(), None)
self.assertEqual(ws._buf, '')
self.assertEqual(len(ws._msgs), 0)
def test_send_to_ws(self):
ws = self.test_ws
ws.send(u'hello')
assert ws.socket.sendall.called_with("\x00hello\xFF")
ws.send(10)
assert ws.socket.sendall.called_with("\x0010\xFF")
def test_close_ws(self):
ws = self.test_ws
ws.close()
assert ws.socket.shutdown.called_with(True)

View File

@@ -145,6 +145,17 @@ hello world
""" """
def recvall(socket_):
result = b''
while True:
chunk = socket_.recv()
result += chunk
if chunk == b'':
break
return result
class ConnectionClosed(Exception): class ConnectionClosed(Exception):
pass pass
@@ -449,8 +460,8 @@ class TestHttpd(_TestBase):
sock.write( sock.write(
b'POST /foo HTTP/1.1\r\nHost: localhost\r\n' b'POST /foo HTTP/1.1\r\nHost: localhost\r\n'
b'Connection: close\r\nContent-length:3\r\n\r\nabc') b'Connection: close\r\nContent-length:3\r\n\r\nabc')
result = sock.read(8192) result = recvall(sock)
self.assertEqual(result[-3:], b'abc') assert result.endswith(b'abc')
@tests.skip_if_no_ssl @tests.skip_if_no_ssl
def test_013_empty_return(self): def test_013_empty_return(self):
@@ -469,8 +480,8 @@ class TestHttpd(_TestBase):
sock = eventlet.connect(('localhost', server_sock.getsockname()[1])) sock = eventlet.connect(('localhost', server_sock.getsockname()[1]))
sock = eventlet.wrap_ssl(sock) sock = eventlet.wrap_ssl(sock)
sock.write(b'GET /foo HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n') sock.write(b'GET /foo HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n')
result = sock.read(8192) result = recvall(sock)
self.assertEqual(result[-4:], b'\r\n\r\n') assert result[-4:] == b'\r\n\r\n'
def test_014_chunked_post(self): def test_014_chunked_post(self):
self.site.application = chunked_post self.site.application = chunked_post
@@ -1118,9 +1129,9 @@ class TestHttpd(_TestBase):
try: try:
client = ssl.wrap_socket(eventlet.connect(('localhost', port))) client = ssl.wrap_socket(eventlet.connect(('localhost', port)))
client.write(b'GET / HTTP/1.0\r\nHost: localhost\r\n\r\n') client.write(b'GET / HTTP/1.0\r\nHost: localhost\r\n\r\n')
result = client.read() result = recvall(client)
assert result.startswith('HTTP'), result assert result.startswith(b'HTTP'), result
assert result.endswith('hello world') assert result.endswith(b'hello world')
except ImportError: except ImportError:
pass # TODO(openssl): should test with OpenSSL pass # TODO(openssl): should test with OpenSSL
greenthread.kill(g) greenthread.kill(g)