Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MySQL 8.0 tests, properly close timed out connections #660

Merged
merged 7 commits into from
Jan 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ To be included in 1.0.0 (unreleased)

* Don't send sys.argv[0] as program_name to MySQL server by default #620
* Allow running process as anonymous uid #587
* Fix timed out MySQL 8.0 connections raising InternalError rather than OperationalError #660
* Fix timed out MySQL 8.0 connections being returned from Pool #660
* Ensure connections are properly closed before raising an OperationalError when the server connection is lost #660
* Ensure connections are properly closed before raising an InternalError when packet sequence numbers are out of sync #660


0.0.22 (2021-11-14)
Expand Down
73 changes: 67 additions & 6 deletions aiomysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pymysql.constants import SERVER_STATUS
from pymysql.constants import CLIENT
from pymysql.constants import COMMAND
from pymysql.constants import CR
from pymysql.constants import FIELD_TYPE
from pymysql.util import byte2int, int2byte
from pymysql.converters import (escape_item, encoders, decoders,
Expand Down Expand Up @@ -79,6 +80,57 @@ async def _connect(*args, **kwargs):
return conn


async def _open_connection(host=None, port=None, **kwds):
"""This is based on asyncio.open_connection, allowing us to use a custom
StreamReader.

`limit` arg has been removed as we don't currently use it.
"""
loop = asyncio.events.get_running_loop()
reader = _StreamReader(loop=loop)
protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
transport, _ = await loop.create_connection(
lambda: protocol, host, port, **kwds)
writer = asyncio.StreamWriter(transport, protocol, reader, loop)
return reader, writer


async def _open_unix_connection(path=None, **kwds):
"""This is based on asyncio.open_unix_connection, allowing us to use a custom
StreamReader.

`limit` arg has been removed as we don't currently use it.
"""
loop = asyncio.events.get_running_loop()

reader = _StreamReader(loop=loop)
protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
transport, _ = await loop.create_unix_connection(
lambda: protocol, path, **kwds)
writer = asyncio.StreamWriter(transport, protocol, reader, loop)
return reader, writer


class _StreamReader(asyncio.StreamReader):
"""This StreamReader exposes whether EOF was received, allowing us to
discard the associated connection instead of returning it from the pool
when checking free connections in Pool._fill_free_pool().

`limit` arg has been removed as we don't currently use it.
"""
def __init__(self, loop=None):
self._eof_received = False
super().__init__(loop=loop)

def feed_eof(self) -> None:
self._eof_received = True
super().feed_eof()

@property
def eof_received(self):
return self._eof_received


class Connection:
"""Representation of a socket with a mysql server.

Expand Down Expand Up @@ -471,21 +523,21 @@ async def set_charset(self, charset):

async def _connect(self):
# TODO: Set close callback
# raise OperationalError(2006,
# raise OperationalError(CR.CR_SERVER_GONE_ERROR,
# "MySQL server has gone away (%r)" % (e,))
try:
if self._unix_socket and self._host in ('localhost', '127.0.0.1'):
self._reader, self._writer = await \
asyncio.wait_for(
asyncio.open_unix_connection(
_open_unix_connection(
self._unix_socket),
timeout=self.connect_timeout)
self.host_info = "Localhost via UNIX socket: " + \
self._unix_socket
else:
self._reader, self._writer = await \
asyncio.wait_for(
asyncio.open_connection(
_open_connection(
self._host,
self._port),
timeout=self.connect_timeout)
Expand Down Expand Up @@ -570,6 +622,13 @@ async def _read_packet(self, packet_type=MysqlPacket):
# we increment in both write_packet and read_packet. The count
# is reset at new COMMAND PHASE.
if packet_number != self._next_seq_id:
self.close()
if packet_number == 0:
# MySQL 8.0 sends error packet with seqno==0 when shutdown
raise OperationalError(
CR.CR_SERVER_LOST,
"Lost connection to MySQL server during query")

raise InternalError(
"Packet sequence number wrong - got %d expected %d" %
(packet_number, self._next_seq_id))
Expand Down Expand Up @@ -597,10 +656,12 @@ async def _read_bytes(self, num_bytes):
data = await self._reader.readexactly(num_bytes)
except asyncio.IncompleteReadError as e:
msg = "Lost connection to MySQL server during query"
raise OperationalError(2013, msg) from e
self.close()
raise OperationalError(CR.CR_SERVER_LOST, msg) from e
except (IOError, OSError) as e:
msg = "Lost connection to MySQL server during query (%s)" % (e,)
raise OperationalError(2013, msg) from e
self.close()
raise OperationalError(CR.CR_SERVER_LOST, msg) from e
return data

def _write_bytes(self, data):
Expand Down Expand Up @@ -704,7 +765,7 @@ async def _request_authentication(self):
# TCP connection not at start. Passing in a socket to
# open_connection will cause it to negotiate TLS on an existing
# connection not initiate a new one.
self._reader, self._writer = await asyncio.open_connection(
self._reader, self._writer = await _open_connection(
sock=raw_sock, ssl=self._ssl_context,
server_hostname=self._host
)
Expand Down
10 changes: 9 additions & 1 deletion aiomysql/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ async def _acquire(self):
await self._cond.wait()

async def _fill_free_pool(self, override_min):
# iterate over free connections and remove timeouted ones
# iterate over free connections and remove timed out ones
free_size = len(self._free)
n = 0
while n < free_size:
Expand All @@ -152,6 +152,14 @@ async def _fill_free_pool(self, override_min):
self._free.pop()
conn.close()

# On MySQL 8.0 a timed out connection sends an error packet before
# closing the connection, preventing us from relying on at_eof().
# This relies on our custom StreamReader, as eof_received is not
# present in asyncio.StreamReader.
elif conn._reader.eof_received:
Nothing4You marked this conversation as resolved.
Show resolved Hide resolved
self._free.pop()
conn.close()

elif (self._recycle > -1 and
self._loop.time() - conn.last_usage > self._recycle):
self._free.pop()
Expand Down