diff --git a/pymysql/connections.py b/pymysql/connections.py index b8aaa53..f7be54c 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -845,6 +845,7 @@ class Connection(object): result.init_unbuffered_query() except: result.unbuffered_active = False + result.connection = None raise else: result = MySQLResult(self) @@ -1033,8 +1034,7 @@ class Connection(object): class MySQLResult(object): def __init__(self, connection): - from weakref import proxy - self.connection = proxy(connection) + self.connection = connection self.affected_rows = None self.insert_id = None self.server_status = None @@ -1051,13 +1051,16 @@ class MySQLResult(object): self._finish_unbuffered_query() def read(self): - first_packet = self.connection._read_packet() + try: + first_packet = self.connection._read_packet() - # TODO: use classes for different packet types? - if first_packet.is_ok_packet(): - self._read_ok_packet(first_packet) - else: - self._read_result_packet(first_packet) + # TODO: use classes for different packet types? + if first_packet.is_ok_packet(): + self._read_ok_packet(first_packet) + else: + self._read_result_packet(first_packet) + finally: + self.connection = False def init_unbuffered_query(self): self.unbuffered_active = True @@ -1066,6 +1069,7 @@ class MySQLResult(object): if first_packet.is_ok_packet(): self._read_ok_packet(first_packet) self.unbuffered_active = False + self.connection = None else: self.field_count = first_packet.read_length_encoded_integer() self._get_descriptions() @@ -1105,6 +1109,7 @@ class MySQLResult(object): packet = self.connection._read_packet() if self._check_packet_is_eof(packet): self.unbuffered_active = False + self.connection = None self.rows = None return @@ -1121,6 +1126,7 @@ class MySQLResult(object): packet = self.connection._read_packet() if self._check_packet_is_eof(packet): self.unbuffered_active = False + self.connection = None # release reference to kill cyclic reference. def _read_rowdata_packet(self): """Read a rowdata packet for each data row in the result set.""" @@ -1128,6 +1134,7 @@ class MySQLResult(object): while True: packet = self.connection._read_packet() if self._check_packet_is_eof(packet): + self.connection = None # release reference to kill cyclic reference. break rows.append(self._read_row_from_packet(packet))