
Commit 6b272b6

fix integration test connection mock
1 parent 41b6fc7 commit 6b272b6


2 files changed: +31 -61 lines changed


integrations/acquisition/rvdss/test_scenarios.py

Lines changed: 22 additions & 52 deletions
@@ -2,7 +2,6 @@
 # standard library
 import unittest
 from unittest.mock import MagicMock, patch
-import mock
 from copy import copy
 
 # first party
@@ -11,12 +10,10 @@
 import delphi.operations.secrets as secrets
 from delphi_utils import get_structured_logger
 
-
 # third party
 import mysql.connector
 import pandas as pd
 from pathlib import Path
-import pdb
 
 # py3tester coverage target (equivalent to `import *`)
 # __test_target__ = 'delphi.epidata.acquisition.covid_hosp.facility.update'
@@ -25,7 +22,7 @@
 
 class AcquisitionTests(unittest.TestCase):
   logger = get_structured_logger()
-
+
   def setUp(self):
     """Perform per-test setup."""
 
@@ -54,76 +51,51 @@ def setUp(self):
     epidata_cnx.commit()
     epidata_cur.close()
     #epidata_cnx.close()
-
+
     # make connection and cursor available to test cases
     self.cnx = epidata_cnx
-    #self.cur = epidata_cnx.cursor()
-
+    self.cur = epidata_cnx.cursor()
+
   def tearDown(self):
     """Perform per-test teardown."""
-    #self.cur.close()
+    self.cur.close()
     self.cnx.close()
 
-  @mock.patch("mysql.connector.connect")
-  def test_rvdss_repiratory_detections(self,mock_sql):
+  @patch("mysql.connector.connect")
+  def test_rvdss_repiratory_detections(self, mock_sql):
+    connection_mock = MagicMock()
+
     TEST_DIR = Path(__file__).parent.parent.parent.parent
-    detection_data = pd.read_csv(str(TEST_DIR) + "/testdata/acquisition/rvdss/RVD_CurrentWeekTable_Formatted.csv")
+    detection_data = pd.read_csv(str(TEST_DIR) + "/testdata/acquisition/rvdss/RVD_CurrentWeekTable_Formatted.csv")
     detection_data['time_type'] = "week"
-    detection_subset = detection_data[(detection_data['geo_value'].isin(['nl', 'nb'])) & (detection_data['time_value'].isin([202408-31, 20240907])) ]
+    detection_subset = detection_data[(detection_data['geo_value'].isin(['nl', 'nb'])) & (detection_data['time_value'].isin([20240831, 20240907])) ]
 
-    connection_mock = MagicMock()
     # make sure the data does not yet exist
     with self.subTest(name='no data yet'):
       response = Epidata.rvdss(geo_type='province',
                                time_values= [202435, 202436],
                                geo_value = ['nl','nb'])
       self.assertEqual(response['result'], -2, response)
 
-
     # acquire sample data into local database
-    with self.subTest(name='first acquisition'):
-      #mock_sql.cursor.return_value = self.cnx.cursor()
+    with self.subTest(name='first acquisition'):
+      # When the MagicMock connection's `cursor()` method is called, return
+      # a real cursor made from the current open connection `cnx`.
       connection_mock.cursor.return_value = self.cnx.cursor()
+      # Commit via the current open connection `cnx`, from which the cursor
+      # is derived
+      connection_mock.commit = self.cnx.commit
       mock_sql.return_value = connection_mock
-
-      rvdss_cols_subset = [col for col in detection_subset.columns if col in rvdss_cols]
-      pdb.set_trace()
-      update(detection_subset,self.logger)
-
-      response = Epidata.rvdss(geo_type='province',
-                               time_values= [202435, 202436],
-                               geo_value = ['nl','nb'])
-
-      self.assertEqual(response['result'], 1)
-
-    with self.subTest(name='first acquisition2'):
-      #mock_sql.cursor.return_value = self.cnx.cursor()
-      connection_mock.cursor.return_value = self.cnx.cursor()
-      mock_sql.return_value = connection_mock
-
-      rvdss_cols_subset = [col for col in detection_subset.columns if col in rvdss_cols]
-      update(detection_subset,self.logger)
-
-      response = Epidata.rvdss(geo_type='province',
-                               time_values= [202435, 202436],
-                               geo_value = ['nl','nb'])
-
-      self.assertEqual(response['result'], 1)
-
-    with self.subTest(name='first acquisition3'):
-      #mock_sql.cursor.return_value = self.cnx.cursor()
-      connection_mock.cursor.return_value = self.cnx.cursor()
-      mock_sql.return_value = connection_mock
-
-      rvdss_cols_subset = [col for col in detection_subset.columns if col in rvdss_cols]
-      update(detection_subset,self.logger)
-
+
+      update(detection_subset, self.logger)
+
       response = Epidata.rvdss(geo_type='province',
                                time_values= [202435, 202436],
                                geo_value = ['nl','nb'])
-
+
       self.assertEqual(response['result'], 1)
 
+
     # # make sure the data now exists
     # with self.subTest(name='initial data checks'):
     #   expected_spotchecks = {
@@ -164,5 +136,3 @@ def test_rvdss_repiratory_detections(self,mock_sql):
     #     '450822', Epidata.range(20200101, 20210101))
     #   self.assertEqual(response['result'], 1)
     #   self.assertEqual(len(response['epidata']), 2)
-    pass
-
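For reference, the connection-mock pattern this commit settles on can be sketched in isolation. The snippet below is an illustration only, not part of the commit: acquire_example, the example table, and real_cnx are hypothetical stand-ins, while the patch/MagicMock wiring mirrors the test above (cursor() hands back a real cursor from the test's open connection, and commit is bound to that same connection, so writes made by the code under test land in, and are committed to, the test database).

from unittest.mock import MagicMock, patch

import mysql.connector


def acquire_example():
  # Stand-in for code under test that opens its own database connection.
  cnx = mysql.connector.connect(user="u", password="p", database="epidata")
  cur = cnx.cursor()
  cur.execute("INSERT INTO example (x) VALUES (1)")  # hypothetical table
  cnx.commit()
  cur.close()
  cnx.close()


@patch("mysql.connector.connect")
def run_against_test_db(mock_connect, real_cnx):
  # `patch` injects `mock_connect` as the first argument.
  connection_mock = MagicMock()
  # cursor() returns a real cursor from the test's open connection, so the
  # INSERT issued by acquire_example() hits the real test database.
  connection_mock.cursor.return_value = real_cnx.cursor()
  # commit() is routed to the same real connection the cursor came from.
  connection_mock.commit = real_cnx.commit
  mock_connect.return_value = connection_mock
  acquire_example()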

src/acquisition/rvdss/database.py

Lines changed: 9 additions & 9 deletions
@@ -76,36 +76,36 @@ def get_num_rows(cursor):
     pass
   return num
 
-def update(data,logger):
+def update(data, logger):
   # connect to the database
   u, p = secrets.db.epi
   cnx = mysql.connector.connect(user=u, password=p, database="epidata")
   cur = cnx.cursor()
 
   rvdss_cols_subset = [col for col in data.columns if col in rvdss_cols]
   data = data.to_dict(orient = "records")
-
+
   field_names = ", ".join(f"`{name}`" for name in rvdss_cols_subset)
   field_values = ", ".join(f"%({name})s" for name in rvdss_cols_subset)
-
+
   #check rvdss for new and/or revised data
   sql = f"""
     INSERT INTO rvdss ({field_names})
     VALUES ({field_values})
   """
-
+
   # keep track of how many rows were added
   rows_before = get_num_rows(cur)
   total_rows = 0
-
-  #insert data
+
+  #insert data
   cur.executemany(sql, data)
-  cnx.commit()
-
+
   # keep track of how many rows were added
   rows_after = get_num_rows(cur)
   logger.info(f"Inserted {int(rows_after - rows_before)}/{int(total_rows)} row(s) into table rvdss")
-
+
   # cleanup
+  cnx.commit()
   cur.close()
   cnx.close()
