Merge pull request #1 from methane/remove-makefile
Don't use `socket.makefile()`.
This commit is contained in:
@@ -223,7 +223,7 @@ class MysqlPacket(object):
|
|||||||
|
|
||||||
def __recv_packet(self):
|
def __recv_packet(self):
|
||||||
"""Parse the packet header and read entire packet payload into buffer."""
|
"""Parse the packet header and read entire packet payload into buffer."""
|
||||||
packet_header = self.connection.rfile.read(4)
|
packet_header = self.connection._read_bytes(4)
|
||||||
if len(packet_header) < 4:
|
if len(packet_header) < 4:
|
||||||
raise OperationalError(2013, "Lost connection to MySQL server during query")
|
raise OperationalError(2013, "Lost connection to MySQL server during query")
|
||||||
|
|
||||||
@@ -234,7 +234,7 @@ class MysqlPacket(object):
|
|||||||
|
|
||||||
bin_length = packet_length_bin + int2byte(0) # pad little-endian number
|
bin_length = packet_length_bin + int2byte(0) # pad little-endian number
|
||||||
bytes_to_read = struct.unpack('<I', bin_length)[0]
|
bytes_to_read = struct.unpack('<I', bin_length)[0]
|
||||||
recv_data = self.connection.rfile.read(bytes_to_read)
|
recv_data = self.connection._read_bytes(bytes_to_read)
|
||||||
if len(recv_data) < bytes_to_read:
|
if len(recv_data) < bytes_to_read:
|
||||||
raise OperationalError(2013, "Lost connection to MySQL server during query")
|
raise OperationalError(2013, "Lost connection to MySQL server during query")
|
||||||
if DEBUG: dump_packet(recv_data)
|
if DEBUG: dump_packet(recv_data)
|
||||||
@@ -614,13 +614,9 @@ class Connection(object):
|
|||||||
if self.socket is None:
|
if self.socket is None:
|
||||||
raise Error("Already closed")
|
raise Error("Already closed")
|
||||||
send_data = struct.pack('<i',1) + int2byte(COM_QUIT)
|
send_data = struct.pack('<i',1) + int2byte(COM_QUIT)
|
||||||
self.wfile.write(send_data)
|
self._write_bytes(send_data)
|
||||||
self.wfile.close()
|
|
||||||
self.rfile.close()
|
|
||||||
self.socket.close()
|
self.socket.close()
|
||||||
self.socket = None
|
self.socket = None
|
||||||
self.rfile = None
|
|
||||||
self.wfile = None
|
|
||||||
|
|
||||||
def autocommit(self, value):
|
def autocommit(self, value):
|
||||||
self.autocommit_mode = value
|
self.autocommit_mode = value
|
||||||
@@ -752,8 +748,6 @@ class Connection(object):
|
|||||||
if self.no_delay:
|
if self.no_delay:
|
||||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||||
self.socket = sock
|
self.socket = sock
|
||||||
self.rfile = self.socket.makefile("rb")
|
|
||||||
self.wfile = self.socket.makefile("wb")
|
|
||||||
self._get_server_information()
|
self._get_server_information()
|
||||||
self._request_authentication()
|
self._request_authentication()
|
||||||
|
|
||||||
@@ -769,6 +763,23 @@ class Connection(object):
|
|||||||
packet.check_error()
|
packet.check_error()
|
||||||
return packet
|
return packet
|
||||||
|
|
||||||
|
def _read_bytes(self, num_bytes):
|
||||||
|
d = self.socket.recv(num_bytes)
|
||||||
|
num_bytes -= len(d)
|
||||||
|
if num_bytes == 0:
|
||||||
|
return d
|
||||||
|
buff = bytearray(d)
|
||||||
|
while num_bytes:
|
||||||
|
d = self.socket.recv(num_bytes)
|
||||||
|
if not d:
|
||||||
|
break
|
||||||
|
num_bytes -= len(d)
|
||||||
|
buff += d
|
||||||
|
return bytes(buff)
|
||||||
|
|
||||||
|
def _write_bytes(self, data):
|
||||||
|
return self.socket.sendall(data)
|
||||||
|
|
||||||
def _read_query_result(self, unbuffered=False):
|
def _read_query_result(self, unbuffered=False):
|
||||||
if unbuffered:
|
if unbuffered:
|
||||||
try:
|
try:
|
||||||
@@ -804,17 +815,13 @@ class Connection(object):
|
|||||||
sql = sql.encode(self.charset)
|
sql = sql.encode(self.charset)
|
||||||
|
|
||||||
prelude = struct.pack('<i', len(sql)+1) + int2byte(command)
|
prelude = struct.pack('<i', len(sql)+1) + int2byte(command)
|
||||||
self.wfile.write(prelude + sql)
|
self._write_bytes(prelude + sql)
|
||||||
self.wfile.flush()
|
|
||||||
if DEBUG: dump_packet(prelude + sql)
|
if DEBUG: dump_packet(prelude + sql)
|
||||||
|
|
||||||
def _execute_command(self, command, sql):
|
def _execute_command(self, command, sql):
|
||||||
self._send_command(command, sql)
|
self._send_command(command, sql)
|
||||||
|
|
||||||
def _request_authentication(self):
|
def _request_authentication(self):
|
||||||
self._send_authentication()
|
|
||||||
|
|
||||||
def _send_authentication(self):
|
|
||||||
self.client_flag |= CAPABILITIES
|
self.client_flag |= CAPABILITIES
|
||||||
if self.server_version.startswith('5'):
|
if self.server_version.startswith('5'):
|
||||||
self.client_flag |= MULTI_RESULTS
|
self.client_flag |= MULTI_RESULTS
|
||||||
@@ -836,15 +843,12 @@ class Connection(object):
|
|||||||
|
|
||||||
if DEBUG: dump_packet(data)
|
if DEBUG: dump_packet(data)
|
||||||
|
|
||||||
self.wfile.write(data)
|
self._write_bytes(data)
|
||||||
self.wfile.flush()
|
|
||||||
self.socket = ssl.wrap_self.socketet(self.socket, keyfile=self.key,
|
self.socket = ssl.wrap_self.socketet(self.socket, keyfile=self.key,
|
||||||
certfile=self.cert,
|
certfile=self.cert,
|
||||||
ssl_version=ssl.PROTOCOL_TLSv1,
|
ssl_version=ssl.PROTOCOL_TLSv1,
|
||||||
cert_reqs=ssl.CERT_REQUIRED,
|
cert_reqs=ssl.CERT_REQUIRED,
|
||||||
ca_certs=self.ca)
|
ca_certs=self.ca)
|
||||||
self.rfile = self.socket.makefile("rb")
|
|
||||||
self.wfile = self.socket.makefile("wb")
|
|
||||||
|
|
||||||
data = data_init + self.user+int2byte(0) + _scramble(self.password.encode(self.charset), self.salt)
|
data = data_init + self.user+int2byte(0) + _scramble(self.password.encode(self.charset), self.salt)
|
||||||
|
|
||||||
@@ -857,8 +861,7 @@ class Connection(object):
|
|||||||
|
|
||||||
if DEBUG: dump_packet(data)
|
if DEBUG: dump_packet(data)
|
||||||
|
|
||||||
self.wfile.write(data)
|
self._write_bytes(data)
|
||||||
self.wfile.flush()
|
|
||||||
|
|
||||||
auth_packet = MysqlPacket(self)
|
auth_packet = MysqlPacket(self)
|
||||||
auth_packet.check_error()
|
auth_packet.check_error()
|
||||||
@@ -874,8 +877,7 @@ class Connection(object):
|
|||||||
data = _scramble_323(self.password.encode(self.charset), self.salt.encode(self.charset)) + int2byte(0)
|
data = _scramble_323(self.password.encode(self.charset), self.salt.encode(self.charset)) + int2byte(0)
|
||||||
data = pack_int24(len(data)) + int2byte(next_packet) + data
|
data = pack_int24(len(data)) + int2byte(next_packet) + data
|
||||||
|
|
||||||
self.wfile.write(data)
|
self._write_bytes(data)
|
||||||
self.wfile.flush()
|
|
||||||
auth_packet = MysqlPacket(self)
|
auth_packet = MysqlPacket(self)
|
||||||
auth_packet.check_error()
|
auth_packet.check_error()
|
||||||
if DEBUG: auth_packet.dump()
|
if DEBUG: auth_packet.dump()
|
||||||
|
|||||||
Reference in New Issue
Block a user