Merge pull request #1 from methane/remove-makefile

Don't use `socket.makefile()`.
This commit is contained in:
Marcel Rodrigues
2013-08-29 06:13:15 -07:00

View File

@@ -223,7 +223,7 @@ class MysqlPacket(object):
def __recv_packet(self): def __recv_packet(self):
"""Parse the packet header and read entire packet payload into buffer.""" """Parse the packet header and read entire packet payload into buffer."""
packet_header = self.connection.rfile.read(4) packet_header = self.connection._read_bytes(4)
if len(packet_header) < 4: if len(packet_header) < 4:
raise OperationalError(2013, "Lost connection to MySQL server during query") raise OperationalError(2013, "Lost connection to MySQL server during query")
@@ -234,7 +234,7 @@ class MysqlPacket(object):
bin_length = packet_length_bin + int2byte(0) # pad little-endian number bin_length = packet_length_bin + int2byte(0) # pad little-endian number
bytes_to_read = struct.unpack('<I', bin_length)[0] bytes_to_read = struct.unpack('<I', bin_length)[0]
recv_data = self.connection.rfile.read(bytes_to_read) recv_data = self.connection._read_bytes(bytes_to_read)
if len(recv_data) < bytes_to_read: if len(recv_data) < bytes_to_read:
raise OperationalError(2013, "Lost connection to MySQL server during query") raise OperationalError(2013, "Lost connection to MySQL server during query")
if DEBUG: dump_packet(recv_data) if DEBUG: dump_packet(recv_data)
@@ -614,13 +614,9 @@ class Connection(object):
if self.socket is None: if self.socket is None:
raise Error("Already closed") raise Error("Already closed")
send_data = struct.pack('<i',1) + int2byte(COM_QUIT) send_data = struct.pack('<i',1) + int2byte(COM_QUIT)
self.wfile.write(send_data) self._write_bytes(send_data)
self.wfile.close()
self.rfile.close()
self.socket.close() self.socket.close()
self.socket = None self.socket = None
self.rfile = None
self.wfile = None
def autocommit(self, value): def autocommit(self, value):
self.autocommit_mode = value self.autocommit_mode = value
@@ -752,8 +748,6 @@ class Connection(object):
if self.no_delay: if self.no_delay:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.socket = sock self.socket = sock
self.rfile = self.socket.makefile("rb")
self.wfile = self.socket.makefile("wb")
self._get_server_information() self._get_server_information()
self._request_authentication() self._request_authentication()
@@ -769,6 +763,23 @@ class Connection(object):
packet.check_error() packet.check_error()
return packet return packet
def _read_bytes(self, num_bytes):
d = self.socket.recv(num_bytes)
num_bytes -= len(d)
if num_bytes == 0:
return d
buff = bytearray(d)
while num_bytes:
d = self.socket.recv(num_bytes)
if not d:
break
num_bytes -= len(d)
buff += d
return bytes(buff)
def _write_bytes(self, data):
return self.socket.sendall(data)
def _read_query_result(self, unbuffered=False): def _read_query_result(self, unbuffered=False):
if unbuffered: if unbuffered:
try: try:
@@ -804,17 +815,13 @@ class Connection(object):
sql = sql.encode(self.charset) sql = sql.encode(self.charset)
prelude = struct.pack('<i', len(sql)+1) + int2byte(command) prelude = struct.pack('<i', len(sql)+1) + int2byte(command)
self.wfile.write(prelude + sql) self._write_bytes(prelude + sql)
self.wfile.flush()
if DEBUG: dump_packet(prelude + sql) if DEBUG: dump_packet(prelude + sql)
def _execute_command(self, command, sql): def _execute_command(self, command, sql):
self._send_command(command, sql) self._send_command(command, sql)
def _request_authentication(self): def _request_authentication(self):
self._send_authentication()
def _send_authentication(self):
self.client_flag |= CAPABILITIES self.client_flag |= CAPABILITIES
if self.server_version.startswith('5'): if self.server_version.startswith('5'):
self.client_flag |= MULTI_RESULTS self.client_flag |= MULTI_RESULTS
@@ -836,15 +843,12 @@ class Connection(object):
if DEBUG: dump_packet(data) if DEBUG: dump_packet(data)
self.wfile.write(data) self._write_bytes(data)
self.wfile.flush()
self.socket = ssl.wrap_self.socketet(self.socket, keyfile=self.key, self.socket = ssl.wrap_self.socketet(self.socket, keyfile=self.key,
certfile=self.cert, certfile=self.cert,
ssl_version=ssl.PROTOCOL_TLSv1, ssl_version=ssl.PROTOCOL_TLSv1,
cert_reqs=ssl.CERT_REQUIRED, cert_reqs=ssl.CERT_REQUIRED,
ca_certs=self.ca) ca_certs=self.ca)
self.rfile = self.socket.makefile("rb")
self.wfile = self.socket.makefile("wb")
data = data_init + self.user+int2byte(0) + _scramble(self.password.encode(self.charset), self.salt) data = data_init + self.user+int2byte(0) + _scramble(self.password.encode(self.charset), self.salt)
@@ -857,8 +861,7 @@ class Connection(object):
if DEBUG: dump_packet(data) if DEBUG: dump_packet(data)
self.wfile.write(data) self._write_bytes(data)
self.wfile.flush()
auth_packet = MysqlPacket(self) auth_packet = MysqlPacket(self)
auth_packet.check_error() auth_packet.check_error()
@@ -874,8 +877,7 @@ class Connection(object):
data = _scramble_323(self.password.encode(self.charset), self.salt.encode(self.charset)) + int2byte(0) data = _scramble_323(self.password.encode(self.charset), self.salt.encode(self.charset)) + int2byte(0)
data = pack_int24(len(data)) + int2byte(next_packet) + data data = pack_int24(len(data)) + int2byte(next_packet) + data
self.wfile.write(data) self._write_bytes(data)
self.wfile.flush()
auth_packet = MysqlPacket(self) auth_packet = MysqlPacket(self)
auth_packet.check_error() auth_packet.check_error()
if DEBUG: auth_packet.dump() if DEBUG: auth_packet.dump()