Skip to content

Commit

Permalink
Rewrote tests to use pytest exclusively
Browse files Browse the repository at this point in the history
  • Loading branch information
terricain committed Apr 25, 2020
1 parent 99c6198 commit 819d4ec
Show file tree
Hide file tree
Showing 15 changed files with 2,520 additions and 2,640 deletions.
19 changes: 6 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,6 @@ def pytest_generate_tests(metafunc):
if 'loop_type' in metafunc.fixturenames:
loop_type = ['asyncio', 'uvloop'] if uvloop else ['asyncio']
metafunc.parametrize("loop_type", loop_type)
#
# if 'mysql_tag' in metafunc.fixturenames:
# tags = set(metafunc.config.option.mysql_tag)
# if not tags:
# tags = ['5.6', '8.0']
# elif 'all' in tags:
# tags = ['5.6', '5.7', '8.0']
# else:
# tags = list(tags)
# metafunc.parametrize("mysql_tag", tags, scope='session')


# This is here unless someone fixes the generate_tests bit
Expand Down Expand Up @@ -172,7 +162,10 @@ def f(**kw):
yield f

for conn in connections:
loop.run_until_complete(conn.ensure_closed())
try:
loop.run_until_complete(conn.ensure_closed())
except ConnectionResetError:
pass


@pytest.yield_fixture
Expand Down Expand Up @@ -248,13 +241,13 @@ def mysql_server(unused_port, docker, session_id,
tls_cnf = os.path.join(os.path.dirname(__file__),
'ssl_resources', 'tls.cnf')

ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
ctx.check_hostname = False
ctx.load_verify_locations(cafile=ca_file)
# ctx.verify_mode = ssl.CERT_NONE

container_args = dict(
image='mysql:{}'.format(mysql_tag),
image='{}:{}'.format(mysql_image, mysql_tag),
name='aiomysql-test-server-{}-{}'.format(mysql_tag, session_id),
ports=[3306],
detach=True,
Expand Down
4 changes: 2 additions & 2 deletions tests/fixtures/my.cnf.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ port = {port}
host = {host}
password = {password}
database = {db}
socket = /var/run/mysqld/mysqld.sock
default-character-set = utf8

[client_with_unix_socket]
user = {user}
port = {port}
host = {host}
password = {password}
database = {db}
socket = /var/run/mysqld/mysqld.sock
default-character-set = utf8
240 changes: 113 additions & 127 deletions tests/sa/test_sa_compiled_cache.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import asyncio
from aiomysql import sa
import pytest
from sqlalchemy import bindparam
from sqlalchemy import MetaData, Table, Column, Integer, String

import os
import unittest
from aiomysql import sa

from sqlalchemy import MetaData, Table, Column, Integer, String

meta = MetaData()
tbl = Table('sa_tbl_cache_test', meta,
Expand All @@ -14,125 +12,113 @@
Column('val', String(255)))


class TestCompiledCache(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
self.host = os.environ.get('MYSQL_HOST', 'localhost')
self.port = int(os.environ.get('MYSQL_PORT', 3306))
self.user = os.environ.get('MYSQL_USER', 'root')
self.db = os.environ.get('MYSQL_DB', 'test_pymysql')
self.password = os.environ.get('MYSQL_PASSWORD', '')
self.engine = self.loop.run_until_complete(self.make_engine())
self.loop.run_until_complete(self.start())

def tearDown(self):
self.engine.terminate()
self.loop.run_until_complete(self.engine.wait_closed())
self.loop.close()

async def make_engine(self, **kwargs):
return (await sa.create_engine(db=self.db,
user=self.user,
password=self.password,
host=self.host,
port=self.port,
loop=self.loop,
minsize=10,
**kwargs))

async def start(self):
async with self.engine.acquire() as conn:
tx = await conn.begin()
await conn.execute("DROP TABLE IF EXISTS "
"sa_tbl_cache_test")
await conn.execute("CREATE TABLE sa_tbl_cache_test"
"(id serial, val varchar(255))")
await conn.execute(tbl.insert().values(val='some_val_1'))
await conn.execute(tbl.insert().values(val='some_val_2'))
await conn.execute(tbl.insert().values(val='some_val_3'))
await tx.commit()

def test_cache(self):
async def go():
cache = dict()
engine = await self.make_engine(compiled_cache=cache)
async with engine.acquire() as conn:
# check select with params not added to cache
q = tbl.select().where(tbl.c.val == 'some_val_1')
cursor = await conn.execute(q)
row = await cursor.fetchone()
self.assertEqual('some_val_1', row.val)
self.assertEqual(0, len(cache))

# check select with bound params added to cache
select_by_val = tbl.select().where(
tbl.c.val == bindparam('value')
)
cursor = await conn.execute(
select_by_val, {'value': 'some_val_3'}
)
row = await cursor.fetchone()
self.assertEqual('some_val_3', row.val)
self.assertEqual(1, len(cache))

cursor = await conn.execute(
select_by_val, value='some_val_2'
)
row = await cursor.fetchone()
self.assertEqual('some_val_2', row.val)
self.assertEqual(1, len(cache))

select_all = tbl.select()
cursor = await conn.execute(select_all)
rows = await cursor.fetchall()
self.assertEqual(3, len(rows))
self.assertEqual(2, len(cache))

# check insert with bound params not added to cache
await conn.execute(tbl.insert().values(val='some_val_4'))
self.assertEqual(2, len(cache))

# check insert with bound params added to cache
q = tbl.insert().values(val=bindparam('value'))
await conn.execute(q, value='some_val_5')
self.assertEqual(3, len(cache))

await conn.execute(q, value='some_val_6')
self.assertEqual(3, len(cache))

await conn.execute(q, {'value': 'some_val_7'})
self.assertEqual(3, len(cache))

cursor = await conn.execute(select_all)
rows = await cursor.fetchall()
self.assertEqual(7, len(rows))
self.assertEqual(3, len(cache))

# check update with params not added to cache
q = tbl.update().where(
tbl.c.val == 'some_val_1'
).values(val='updated_val_1')
await conn.execute(q)
self.assertEqual(3, len(cache))
cursor = await conn.execute(
select_by_val, value='updated_val_1'
)
row = await cursor.fetchone()
self.assertEqual('updated_val_1', row.val)

# check update with bound params added to cache
q = tbl.update().where(
tbl.c.val == bindparam('value')
).values(val=bindparam('update'))
await conn.execute(
q, value='some_val_2', update='updated_val_2'
)
self.assertEqual(4, len(cache))
cursor = await conn.execute(
select_by_val, value='updated_val_2'
)
row = await cursor.fetchone()
self.assertEqual('updated_val_2', row.val)

self.loop.run_until_complete(go())
@pytest.fixture()
def make_engine(mysql_params, connection):
async def _make_engine(**kwargs):
return (await sa.create_engine(db=mysql_params['db'],
user=mysql_params['user'],
password=mysql_params['password'],
host=mysql_params['host'],
port=mysql_params['port'],
minsize=10,
**kwargs))

return _make_engine


async def start(engine):
async with engine.acquire() as conn:
tx = await conn.begin()
await conn.execute("DROP TABLE IF EXISTS "
"sa_tbl_cache_test")
await conn.execute("CREATE TABLE sa_tbl_cache_test"
"(id serial, val varchar(255))")
await conn.execute(tbl.insert().values(val='some_val_1'))
await conn.execute(tbl.insert().values(val='some_val_2'))
await conn.execute(tbl.insert().values(val='some_val_3'))
await tx.commit()


@pytest.mark.run_loop
async def test_dialect(make_engine):
cache = dict()
engine = await make_engine(compiled_cache=cache)
await start(engine)

async with engine.acquire() as conn:
# check select with params not added to cache
q = tbl.select().where(tbl.c.val == 'some_val_1')
cursor = await conn.execute(q)
row = await cursor.fetchone()
assert 'some_val_1' == row.val
assert 0 == len(cache)

# check select with bound params added to cache
select_by_val = tbl.select().where(
tbl.c.val == bindparam('value')
)
cursor = await conn.execute(
select_by_val, {'value': 'some_val_3'}
)
row = await cursor.fetchone()
assert 'some_val_3' == row.val
assert 1 == len(cache)

cursor = await conn.execute(
select_by_val, value='some_val_2'
)
row = await cursor.fetchone()
assert 'some_val_2' == row.val
assert 1 == len(cache)

select_all = tbl.select()
cursor = await conn.execute(select_all)
rows = await cursor.fetchall()
assert 3 == len(rows)
assert 2 == len(cache)

# check insert with bound params not added to cache
await conn.execute(tbl.insert().values(val='some_val_4'))
assert 2 == len(cache)

# check insert with bound params added to cache
q = tbl.insert().values(val=bindparam('value'))
await conn.execute(q, value='some_val_5')
assert 3 == len(cache)

await conn.execute(q, value='some_val_6')
assert 3 == len(cache)

await conn.execute(q, {'value': 'some_val_7'})
assert 3 == len(cache)

cursor = await conn.execute(select_all)
rows = await cursor.fetchall()
assert 7 == len(rows)
assert 3 == len(cache)

# check update with params not added to cache
q = tbl.update().where(
tbl.c.val == 'some_val_1'
).values(val='updated_val_1')
await conn.execute(q)
assert 3 == len(cache)
cursor = await conn.execute(
select_by_val, value='updated_val_1'
)
row = await cursor.fetchone()
assert 'updated_val_1' == row.val

# check update with bound params added to cache
q = tbl.update().where(
tbl.c.val == bindparam('value')
).values(val=bindparam('update'))
await conn.execute(
q, value='some_val_2', update='updated_val_2'
)
assert 4 == len(cache)
cursor = await conn.execute(
select_by_val, value='updated_val_2'
)
row = await cursor.fetchone()
assert 'updated_val_2' == row.val
Loading

0 comments on commit 819d4ec

Please sign in to comment.