Merge pull request #1 from methane/remove-makefile

Don't use `socket.makefile()`.
This commit is contained in:
Marcel Rodrigues
2013-08-29 06:13:15 -07:00

View File

@@ -223,7 +223,7 @@ class MysqlPacket(object):
def __recv_packet(self):
"""Parse the packet header and read entire packet payload into buffer."""
packet_header = self.connection.rfile.read(4)
packet_header = self.connection._read_bytes(4)
if len(packet_header) < 4:
raise OperationalError(2013, "Lost connection to MySQL server during query")
@@ -234,7 +234,7 @@ class MysqlPacket(object):
bin_length = packet_length_bin + int2byte(0) # pad little-endian number
bytes_to_read = struct.unpack('<I', bin_length)[0]
recv_data = self.connection.rfile.read(bytes_to_read)
recv_data = self.connection._read_bytes(bytes_to_read)
if len(recv_data) < bytes_to_read:
raise OperationalError(2013, "Lost connection to MySQL server during query")
if DEBUG: dump_packet(recv_data)
@@ -614,13 +614,9 @@ class Connection(object):
if self.socket is None:
raise Error("Already closed")
send_data = struct.pack('<i',1) + int2byte(COM_QUIT)
self.wfile.write(send_data)
self.wfile.close()
self.rfile.close()
self._write_bytes(send_data)
self.socket.close()
self.socket = None
self.rfile = None
self.wfile = None
def autocommit(self, value):
self.autocommit_mode = value
@@ -752,8 +748,6 @@ class Connection(object):
if self.no_delay:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.socket = sock
self.rfile = self.socket.makefile("rb")
self.wfile = self.socket.makefile("wb")
self._get_server_information()
self._request_authentication()
@@ -769,6 +763,23 @@ class Connection(object):
packet.check_error()
return packet
def _read_bytes(self, num_bytes):
d = self.socket.recv(num_bytes)
num_bytes -= len(d)
if num_bytes == 0:
return d
buff = bytearray(d)
while num_bytes:
d = self.socket.recv(num_bytes)
if not d:
break
num_bytes -= len(d)
buff += d
return bytes(buff)
def _write_bytes(self, data):
return self.socket.sendall(data)
def _read_query_result(self, unbuffered=False):
if unbuffered:
try:
@@ -804,17 +815,13 @@ class Connection(object):
sql = sql.encode(self.charset)
prelude = struct.pack('<i', len(sql)+1) + int2byte(command)
self.wfile.write(prelude + sql)
self.wfile.flush()
self._write_bytes(prelude + sql)
if DEBUG: dump_packet(prelude + sql)
def _execute_command(self, command, sql):
self._send_command(command, sql)
def _request_authentication(self):
self._send_authentication()
def _send_authentication(self):
self.client_flag |= CAPABILITIES
if self.server_version.startswith('5'):
self.client_flag |= MULTI_RESULTS
@@ -836,15 +843,12 @@ class Connection(object):
if DEBUG: dump_packet(data)
self.wfile.write(data)
self.wfile.flush()
self._write_bytes(data)
self.socket = ssl.wrap_self.socketet(self.socket, keyfile=self.key,
certfile=self.cert,
ssl_version=ssl.PROTOCOL_TLSv1,
cert_reqs=ssl.CERT_REQUIRED,
ca_certs=self.ca)
self.rfile = self.socket.makefile("rb")
self.wfile = self.socket.makefile("wb")
data = data_init + self.user+int2byte(0) + _scramble(self.password.encode(self.charset), self.salt)
@@ -857,8 +861,7 @@ class Connection(object):
if DEBUG: dump_packet(data)
self.wfile.write(data)
self.wfile.flush()
self._write_bytes(data)
auth_packet = MysqlPacket(self)
auth_packet.check_error()
@@ -874,8 +877,7 @@ class Connection(object):
data = _scramble_323(self.password.encode(self.charset), self.salt.encode(self.charset)) + int2byte(0)
data = pack_int24(len(data)) + int2byte(next_packet) + data
self.wfile.write(data)
self.wfile.flush()
self._write_bytes(data)
auth_packet = MysqlPacket(self)
auth_packet.check_error()
if DEBUG: auth_packet.dump()