Skip to content

Commit

Permalink
Add query timeout tests for #428 (#728)
Browse files Browse the repository at this point in the history
  • Loading branch information
Nothing4You authored Feb 6, 2022
1 parent 6fd00fd commit 82a6195
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 2 deletions.
79 changes: 78 additions & 1 deletion tests/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from aiomysql import ProgrammingError, Cursor, InterfaceError
from aiomysql import ProgrammingError, Cursor, InterfaceError, OperationalError
from aiomysql.cursors import RE_INSERT_VALUES


Expand Down Expand Up @@ -354,3 +354,80 @@ async def test_executemany_percentage(connection_creator):
await cur.executemany(q, [(3, 4), (5, 6)])
assert cur._last_executed.endswith(b"(3, 4),(5, 6)"), \
"executemany with %% not in one query"


@pytest.mark.run_loop
async def test_max_execution_time(mysql_server, connection_creator):
conn = await connection_creator()
await _prepare(conn)
async with conn.cursor() as cur:
# MySQL MAX_EXECUTION_TIME takes ms
# MariaDB max_statement_time takes seconds as int/float, introduced in 10.1

# this will sleep 0.01 seconds per row
if mysql_server["db_type"] == "mysql":
sql = """
SELECT /*+ MAX_EXECUTION_TIME(2000) */
name, sleep(0.01) FROM tbl
"""
else:
sql = """
SET STATEMENT max_statement_time=2 FOR
SELECT name, sleep(0.01) FROM tbl
"""

await cur.execute(sql)
# unlike SSCursor, Cursor returns a tuple of tuples here
assert (await cur.fetchall()) == (
("a", 0),
("b", 0),
("c", 0),
)

if mysql_server["db_type"] == "mysql":
sql = """
SELECT /*+ MAX_EXECUTION_TIME(2000) */
name, sleep(0.01) FROM tbl
"""
else:
sql = """
SET STATEMENT max_statement_time=2 FOR
SELECT name, sleep(0.01) FROM tbl
"""
await cur.execute(sql)
assert (await cur.fetchone()) == ("a", 0)

# this discards the previous unfinished query
await cur.execute("SELECT 1")
assert (await cur.fetchone()) == (1,)

if mysql_server["db_type"] == "mysql":
sql = """
SELECT /*+ MAX_EXECUTION_TIME(1) */
name, sleep(1) FROM tbl
"""
else:
sql = """
SET STATEMENT max_statement_time=0.001 FOR
SELECT name, sleep(1) FROM tbl
"""
with pytest.raises(OperationalError) as cm:
# in a buffered cursor this should reliably raise an
# OperationalError
await cur.execute(sql)

if mysql_server["db_type"] == "mysql":
# this constant was only introduced in MySQL 5.7, not sure
# what was returned before, may have been ER_QUERY_INTERRUPTED

# this constant is pending a new PyMySQL release
# assert cm.value.args[0] == pymysql.constants.ER.QUERY_TIMEOUT
assert cm.value.args[0] == 3024
else:
# this constant is pending a new PyMySQL release
# assert cm.value.args[0] == pymysql.constants.ER.STATEMENT_TIMEOUT
assert cm.value.args[0] == 1969

# connection should still be fine at this point
await cur.execute("SELECT 1")
assert (await cur.fetchone()) == (1,)
109 changes: 108 additions & 1 deletion tests/test_sscursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from pymysql import NotSupportedError

from aiomysql import ProgrammingError, InterfaceError
from aiomysql import ProgrammingError, InterfaceError, OperationalError
from aiomysql.cursors import SSCursor


Expand Down Expand Up @@ -199,3 +199,110 @@ async def test_sscursor_discarded_result(connection):
await cursor.execute("select 2")
ret = await cursor.fetchone()
assert (2,) == ret


@pytest.mark.skip(
reason="see aio-libs/aiomysql#428, "
"this gets stuck until aio-libs/aiomysql#646 is merged",
)
@pytest.mark.run_loop
async def test_max_execution_time(mysql_server, connection):
conn = connection

async with connection.cursor() as cur:
await cur.execute("DROP TABLE IF EXISTS tbl;")

await cur.execute(
"""
CREATE TABLE tbl (
id MEDIUMINT NOT NULL AUTO_INCREMENT,
name VARCHAR(255) NOT NULL,
PRIMARY KEY (id));
"""
)

for i in [(1, "a"), (2, "b"), (3, "c")]:
await cur.execute("INSERT INTO tbl VALUES(%s, %s)", i)

await conn.commit()

async with conn.cursor(SSCursor) as cur:
# MySQL MAX_EXECUTION_TIME takes ms
# MariaDB max_statement_time takes seconds as int/float, introduced in 10.1

# this will sleep 0.01 seconds per row
if mysql_server["db_type"] == "mysql":
sql = """
SELECT /*+ MAX_EXECUTION_TIME(2000) */
name, sleep(0.01) FROM tbl
"""
else:
sql = """
SET STATEMENT max_statement_time=2 FOR
SELECT name, sleep(0.01) FROM tbl
"""

await cur.execute(sql)
# unlike Cursor, SSCursor returns a list of tuples here

assert (await cur.fetchall()) == [
("a", 0),
("b", 0),
("c", 0),
]

if mysql_server["db_type"] == "mysql":
sql = """
SELECT /*+ MAX_EXECUTION_TIME(2000) */
name, sleep(0.01) FROM tbl
"""
else:
sql = """
SET STATEMENT max_statement_time=2 FOR
SELECT name, sleep(0.01) FROM tbl
"""
await cur.execute(sql)
assert (await cur.fetchone()) == ("a", 0)

# this discards the previous unfinished query and raises an
# incomplete unbuffered query warning
with pytest.warns(UserWarning):
await cur.execute("SELECT 1")
assert (await cur.fetchone()) == (1,)

# SSCursor will not read the EOF packet until we try to read
# another row. Skipping this will raise an incomplete unbuffered
# query warning in the next cur.execute().
assert (await cur.fetchone()) is None

if mysql_server["db_type"] == "mysql":
sql = """
SELECT /*+ MAX_EXECUTION_TIME(1) */
name, sleep(1) FROM tbl
"""
else:
sql = """
SET STATEMENT max_statement_time=0.001 FOR
SELECT name, sleep(1) FROM tbl
"""
with pytest.raises(OperationalError) as cm:
# in an unbuffered cursor the OperationalError may not show up
# until fetching the entire result
await cur.execute(sql)
await cur.fetchall()

if mysql_server["db_type"] == "mysql":
# this constant was only introduced in MySQL 5.7, not sure
# what was returned before, may have been ER_QUERY_INTERRUPTED

# this constant is pending a new PyMySQL release
# assert cm.value.args[0] == pymysql.constants.ER.QUERY_TIMEOUT
assert cm.value.args[0] == 3024
else:
# this constant is pending a new PyMySQL release
# assert cm.value.args[0] == pymysql.constants.ER.STATEMENT_TIMEOUT
assert cm.value.args[0] == 1969

# connection should still be fine at this point
await cur.execute("SELECT 1")
assert (await cur.fetchone()) == (1,)

0 comments on commit 82a6195

Please sign in to comment.