Merge pull request #1 from methane/remove-makefile
Don't use `socket.makefile()`.
This commit is contained in:
@@ -223,7 +223,7 @@ class MysqlPacket(object):
|
||||
|
||||
def __recv_packet(self):
|
||||
"""Parse the packet header and read entire packet payload into buffer."""
|
||||
packet_header = self.connection.rfile.read(4)
|
||||
packet_header = self.connection._read_bytes(4)
|
||||
if len(packet_header) < 4:
|
||||
raise OperationalError(2013, "Lost connection to MySQL server during query")
|
||||
|
||||
@@ -234,7 +234,7 @@ class MysqlPacket(object):
|
||||
|
||||
bin_length = packet_length_bin + int2byte(0) # pad little-endian number
|
||||
bytes_to_read = struct.unpack('<I', bin_length)[0]
|
||||
recv_data = self.connection.rfile.read(bytes_to_read)
|
||||
recv_data = self.connection._read_bytes(bytes_to_read)
|
||||
if len(recv_data) < bytes_to_read:
|
||||
raise OperationalError(2013, "Lost connection to MySQL server during query")
|
||||
if DEBUG: dump_packet(recv_data)
|
||||
@@ -614,13 +614,9 @@ class Connection(object):
|
||||
if self.socket is None:
|
||||
raise Error("Already closed")
|
||||
send_data = struct.pack('<i',1) + int2byte(COM_QUIT)
|
||||
self.wfile.write(send_data)
|
||||
self.wfile.close()
|
||||
self.rfile.close()
|
||||
self._write_bytes(send_data)
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
self.rfile = None
|
||||
self.wfile = None
|
||||
|
||||
def autocommit(self, value):
|
||||
self.autocommit_mode = value
|
||||
@@ -752,8 +748,6 @@ class Connection(object):
|
||||
if self.no_delay:
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
self.socket = sock
|
||||
self.rfile = self.socket.makefile("rb")
|
||||
self.wfile = self.socket.makefile("wb")
|
||||
self._get_server_information()
|
||||
self._request_authentication()
|
||||
|
||||
@@ -769,6 +763,23 @@ class Connection(object):
|
||||
packet.check_error()
|
||||
return packet
|
||||
|
||||
def _read_bytes(self, num_bytes):
|
||||
d = self.socket.recv(num_bytes)
|
||||
num_bytes -= len(d)
|
||||
if num_bytes == 0:
|
||||
return d
|
||||
buff = bytearray(d)
|
||||
while num_bytes:
|
||||
d = self.socket.recv(num_bytes)
|
||||
if not d:
|
||||
break
|
||||
num_bytes -= len(d)
|
||||
buff += d
|
||||
return bytes(buff)
|
||||
|
||||
def _write_bytes(self, data):
|
||||
return self.socket.sendall(data)
|
||||
|
||||
def _read_query_result(self, unbuffered=False):
|
||||
if unbuffered:
|
||||
try:
|
||||
@@ -804,17 +815,13 @@ class Connection(object):
|
||||
sql = sql.encode(self.charset)
|
||||
|
||||
prelude = struct.pack('<i', len(sql)+1) + int2byte(command)
|
||||
self.wfile.write(prelude + sql)
|
||||
self.wfile.flush()
|
||||
self._write_bytes(prelude + sql)
|
||||
if DEBUG: dump_packet(prelude + sql)
|
||||
|
||||
def _execute_command(self, command, sql):
|
||||
self._send_command(command, sql)
|
||||
|
||||
def _request_authentication(self):
|
||||
self._send_authentication()
|
||||
|
||||
def _send_authentication(self):
|
||||
self.client_flag |= CAPABILITIES
|
||||
if self.server_version.startswith('5'):
|
||||
self.client_flag |= MULTI_RESULTS
|
||||
@@ -836,15 +843,12 @@ class Connection(object):
|
||||
|
||||
if DEBUG: dump_packet(data)
|
||||
|
||||
self.wfile.write(data)
|
||||
self.wfile.flush()
|
||||
self._write_bytes(data)
|
||||
self.socket = ssl.wrap_self.socketet(self.socket, keyfile=self.key,
|
||||
certfile=self.cert,
|
||||
ssl_version=ssl.PROTOCOL_TLSv1,
|
||||
cert_reqs=ssl.CERT_REQUIRED,
|
||||
ca_certs=self.ca)
|
||||
self.rfile = self.socket.makefile("rb")
|
||||
self.wfile = self.socket.makefile("wb")
|
||||
|
||||
data = data_init + self.user+int2byte(0) + _scramble(self.password.encode(self.charset), self.salt)
|
||||
|
||||
@@ -857,8 +861,7 @@ class Connection(object):
|
||||
|
||||
if DEBUG: dump_packet(data)
|
||||
|
||||
self.wfile.write(data)
|
||||
self.wfile.flush()
|
||||
self._write_bytes(data)
|
||||
|
||||
auth_packet = MysqlPacket(self)
|
||||
auth_packet.check_error()
|
||||
@@ -874,8 +877,7 @@ class Connection(object):
|
||||
data = _scramble_323(self.password.encode(self.charset), self.salt.encode(self.charset)) + int2byte(0)
|
||||
data = pack_int24(len(data)) + int2byte(next_packet) + data
|
||||
|
||||
self.wfile.write(data)
|
||||
self.wfile.flush()
|
||||
self._write_bytes(data)
|
||||
auth_packet = MysqlPacket(self)
|
||||
auth_packet.check_error()
|
||||
if DEBUG: auth_packet.dump()
|
||||
|
||||
Reference in New Issue
Block a user