Skip to content

Commit c9bc812

Browse files
committed
Add test for ephys.main_v8_networked
1 parent 0901d1b commit c9bc812

File tree

2 files changed

+208
-23
lines changed

2 files changed

+208
-23
lines changed

iblrig/ephys.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
import argparse
2-
import string
3-
import datetime
42
import asyncio
3+
import datetime
4+
import logging
5+
import string
56
from pathlib import Path
67

78
import numpy as np
8-
from one.alf.io import next_num_folder
9-
from one.api import OneAlyx
10-
from iblutil.io import net
119

1210
from iblatlas import atlas
1311
from iblrig.base_tasks import EmptySession
12+
from iblrig.net import get_server_communicator, read_stdin, update_alyx_token
13+
from iblrig.path_helper import load_pydantic_yaml
14+
from iblrig.pydantic_definitions import RigSettings
1415
from iblrig.transfer_experiments import EphysCopier
15-
from iblrig.net import get_server_communicator, update_alyx_token, read_stdin
16+
from iblutil.io import net
1617
from iblutil.util import setup_logger
18+
from one.alf.io import next_num_folder
19+
from one.api import OneAlyx
1720

1821

1922
def prepare_ephys_session_cmd():
@@ -30,7 +33,7 @@ def prepare_ephys_session_cmd():
3033
help='the service URI to listen to messages on. pass ":<port>" to specify port only.',
3134
)
3235
args = parser.parse_args()
33-
setup_logger(name='iblrig', level='DEBUG' if args.debug else 'INFO')
36+
setup_logger(name=__name__, level='DEBUG' if args.debug else 'INFO')
3437
if args.service_uri:
3538
asyncio.run(main_v8_networked(args.subject_name, args.debug, args.nprobes, args.service_uri))
3639
else:
@@ -97,8 +100,7 @@ def neuropixel24_micromanipulator_coordinates(ref_shank, pname, ba=None, shank_s
97100
async def main_v8_networked(mouse, debug=False, n_probes=2, service_uri=None):
98101
# from iblrig.base_tasks import EmptySession
99102

100-
101-
log = setup_logger(name='iblrig', level=10 if debug else 20)
103+
log = logging.getLogger(__name__)
102104

103105
# if PARAMS.get('PROBE_TYPE_00', '3B') != '3B' or PARAMS.get('PROBE_TYPE_01', '3B') != '3B':
104106
# raise NotImplementedError('Only 3B probes supported.')
@@ -109,8 +111,6 @@ async def main_v8_networked(mouse, debug=False, n_probes=2, service_uri=None):
109111
# session = EmptySession(subject=mouse, interactive=False, iblrig_settings=iblrig_settings)
110112
# session_path = session.paths.SESSION_FOLDER
111113
# FIXME The following should be done by the EmptySession class
112-
from iblrig.path_helper import load_pydantic_yaml
113-
from iblrig.pydantic_definitions import RigSettings
114114
iblrig_settings = load_pydantic_yaml(RigSettings)
115115
date = datetime.datetime.now().date().isoformat()
116116
num = next_num_folder(iblrig_settings.iblrig_local_data_path / mouse / date)
@@ -119,15 +119,15 @@ async def main_v8_networked(mouse, debug=False, n_probes=2, service_uri=None):
119119
raw_data_folder.mkdir(parents=True, exist_ok=True)
120120

121121
log.info('Created %s', raw_data_folder)
122-
REMOTE_SUBJECT_FOLDER = iblrig_settings.iblrig_remote_subjects_path
122+
remote_subject_folder = iblrig_settings.iblrig_remote_subjects_path
123123

124124
for n in range(n_probes):
125125
probe_folder = raw_data_folder / f'probe{n:02}'
126126
probe_folder.mkdir(exist_ok=True)
127127
log.info('Created %s', probe_folder)
128128

129129
# Save the stub files locally and in the remote repo for future copy script to use
130-
copier = EphysCopier(session_path=session_path, remote_subjects_folder=REMOTE_SUBJECT_FOLDER)
130+
copier = EphysCopier(session_path=session_path, remote_subjects_folder=remote_subject_folder)
131131
communicator, _ = await get_server_communicator(service_uri, 'neuropixel')
132132
copier.initialize_experiment(nprobes=n_probes)
133133

@@ -156,24 +156,24 @@ async def main_v8_networked(mouse, debug=False, n_probes=2, service_uri=None):
156156
for d in raw_data_folder.iterdir(): # Remove probe folders
157157
d.rmdir()
158158
raw_data_folder.rmdir() # remove collection
159+
# Remove remote exp description file
160+
log.debug('Removing %s', copier.file_remote_experiment_description)
161+
copier.file_remote_experiment_description.unlink()
162+
copier.file_remote_experiment_description.with_suffix('.status_pending').unlink()
159163
# Delete whole session folder?
160164
session_files = list(session_path.rglob('*'))
161-
if len(session_files) == 1 and session_files[0].name.startswith(
162-
'_ibl_experiment.description'):
165+
if len(session_files) == 1 and session_files[0].name.startswith('_ibl_experiment.description'):
163166
ans = input(f'Remove empty session {"/".join(session_path.parts[-3:])}? [y/N]\n')
164167
if (ans.strip().lower() or 'n')[0] == 'y':
165168
log.warning('Removing %s', session_path)
166169
log.debug('Removing %s', session_files[0])
167170
session_files[0].unlink()
168171
session_path.rmdir()
169-
# Remove remote exp description file
170-
log.debug('Removing %s', copier.file_remote_experiment_description)
171-
copier.file_remote_experiment_description.unlink()
172172
else:
173173
session_path.joinpath('transfer_me.flag').touch()
174174
communicator.close()
175-
for task in tasks:
176-
task.cancel()
175+
for t in tasks:
176+
t.cancel()
177177
tasks.clear()
178178
return
179179
case 'remote':
@@ -182,7 +182,7 @@ async def main_v8_networked(mouse, debug=False, n_probes=2, service_uri=None):
182182
log.error('Remote communicator closed')
183183
else:
184184
data, addr, event = task.result()
185-
S = net.base.ExpMessage
185+
S = net.base.ExpMessage # noqa
186186
match event:
187187
case S.EXPINFO:
188188
reponse_data = {'exp_ref': one.dict2ref(exp_ref), 'main_sync': True}
@@ -192,8 +192,9 @@ async def main_v8_networked(mouse, debug=False, n_probes=2, service_uri=None):
192192
case S.EXPINIT:
193193
expected = one.dict2ref(exp_ref)
194194
# TODO Make assertion
195-
if 'exp_ref' in data and data['exp_ref'] != expected:
196-
log.critical('Experiment reference mismatch! Expected %s, got %s', expected, data['exp_ref'])
195+
remote_ref = (data[0] or {}).get('exp_ref') if any(data) else None
196+
if remote_ref and remote_ref != expected:
197+
log.critical('Experiment reference mismatch! Expected %s, got %s', expected, remote_ref)
197198
data = {'exp_ref': one.dict2ref(exp_ref), 'status': net.base.ExpStatus.RUNNING}
198199
await communicator.init(data, addr=addr)
199200
case S.EXPSTART:

iblrig/test/test_ephys.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1+
import tempfile
12
import unittest
3+
from datetime import date
4+
from pathlib import Path
5+
from unittest.mock import AsyncMock, patch
26

37
import iblrig.ephys
8+
from iblrig.test.base import TEST_DB
9+
from iblutil.io import net
10+
from iblutil.util import Bunch
411

512

613
class TestFinalizeEphysSession(unittest.TestCase):
@@ -38,3 +45,180 @@ def test_neuropixel24_micromanipulator(self):
3845
},
3946
}
4047
assert trajectories == a
48+
49+
50+
class TestPrepareEphysSessionNetworked(unittest.IsolatedAsyncioTestCase):
51+
"""Test the main_v8_networked function."""
52+
53+
def setUp(self):
54+
"""Set up keyboard input and settings mocks."""
55+
# Set up keyboad input mock
56+
# When we set self.keyboard to a non-empty string, the test function should interpret this as
57+
# keyboard input and stop the session (i.e. close the communicator and return)
58+
self.keyboard = ''
59+
read_stdin = patch('iblrig.ephys.read_stdin')
60+
self.addCleanup(read_stdin.stop)
61+
read_stdin_mock = read_stdin.start()
62+
63+
async def _stdin():
64+
if self.keyboard:
65+
yield self.keyboard
66+
67+
read_stdin_mock.side_effect = _stdin
68+
69+
# Set up settings mock
70+
tmp = tempfile.TemporaryDirectory()
71+
self.tmpdir = Path(tmp.name)
72+
(local := self.tmpdir.joinpath('local')).mkdir()
73+
(remote := self.tmpdir.joinpath('remote')).mkdir()
74+
(remote_subjects := remote.joinpath('subjects')).mkdir()
75+
self.settings = Bunch(
76+
iblrig_local_data_path=local, iblrig_remote_data_path=remote, iblrig_remote_subjects_path=remote_subjects
77+
)
78+
m = patch('iblrig.ephys.load_pydantic_yaml', return_value=self.settings)
79+
m.start()
80+
self.addCleanup(m.stop)
81+
self.addr = '192.168.0.5:99998' # Fake address of the behaviour rig
82+
83+
async def asyncSetUp(self):
84+
"""Set up communicator mock.
85+
86+
To side-step UDP communication, we mock the communicator and simulate the messages that
87+
would be sent by the behaviour rig, then assert that the response methods are called with
88+
the expected arguments.
89+
"""
90+
self.communicator = AsyncMock(spec=iblrig.ephys.net.app.EchoProtocol)
91+
self.communicator.is_connected = True
92+
m = patch('iblrig.ephys.get_server_communicator', return_value=(self.communicator, None))
93+
m.start()
94+
self.addCleanup(m.stop)
95+
96+
async def test_standard_message_sequence(self):
97+
"""Test the main_v8_networked function with the usual sequence of behaviour rig messages."""
98+
# Create some mock behaviour rig messages
99+
ref = f'{date.today()}_1_foo'
100+
info_msg = ((net.base.ExpStatus.CONNECTED, {'subject_name': 'foo'}), self.addr, net.base.ExpMessage.EXPINFO)
101+
init_msg = ([{'exp_ref': ref}], self.addr, net.base.ExpMessage.EXPINIT)
102+
start_msg = ((ref, {}), self.addr, net.base.ExpMessage.EXPSTART)
103+
status_msg = (net.base.ExpStatus.RUNNING, self.addr, net.base.ExpMessage.EXPSTATUS)
104+
# This is the order in which the messages are expected to be sent (excluding status)
105+
self.messages = (info_msg, init_msg, start_msg, status_msg)
106+
107+
messages = self._iterate_messages()
108+
self.communicator.on_event.side_effect = lambda evt: next(messages)
109+
await iblrig.ephys.main_v8_networked('foo', debug=True)
110+
111+
# The on_event method is awaited at first then each time a message is received
112+
self.communicator.on_event.assert_awaited_with(net.base.ExpMessage.any())
113+
self.assertEqual(1 + len(self.messages), self.communicator.on_event.await_count)
114+
115+
# Check that the expected methods were called with the expected arguments
116+
kwargs = dict(addr=self.addr)
117+
expected_responses = [
118+
('info', (net.base.ExpStatus.RUNNING, {'exp_ref': ref, 'main_sync': True}), kwargs),
119+
('init', ({'exp_ref': ref, 'status': net.base.ExpStatus.RUNNING},), kwargs),
120+
('start', ({'subject': 'foo', 'date': date.today(), 'sequence': 1},), kwargs),
121+
('status', (net.base.ExpStatus.RUNNING,), kwargs),
122+
('close', (), {}), # should be called after the last message (when keyboad input is simulated)
123+
]
124+
# Check odd method calls as even ones are the on_event calls
125+
for expected, actual in zip(expected_responses, map(tuple, self.communicator.method_calls[1::2]), strict=False):
126+
self.assertEqual(expected, actual)
127+
128+
# Check that the local and remote sessions were created
129+
expected = [
130+
f'local/foo/{date.today()}/001/transfer_me.flag',
131+
f'local/foo/{date.today()}/001/_ibl_experiment.description_ephys.yaml',
132+
f'remote/subjects/foo/{date.today()}/001/_devices/[email protected]_pending',
133+
f'remote/subjects/foo/{date.today()}/001/_devices/[email protected]',
134+
]
135+
self.assertCountEqual(map(self.tmpdir.joinpath, expected), self.tmpdir.rglob('*.*'))
136+
# Should have created the raw ephys folders
137+
self.assertEqual(2, len(list(self.tmpdir.glob(f'local/foo/{date.today()}/001/raw_ephys_data/probe??'))))
138+
139+
async def test_abort_session(self):
140+
"""Test the main_v8_networked function with misc events and user 'abort' input."""
141+
# Create some mock behaviour rig messages where the behaviour rig runs subject 'bar' instead of 'foo'
142+
# No exception should be raised (this happens at the behaviour rig) but this should be logged
143+
ref = f'{date.today()}_1_bar'
144+
info_msg = ((net.base.ExpStatus.CONNECTED, {'subject_name': 'bar'}), self.addr, net.base.ExpMessage.EXPINFO)
145+
init_msg = ([{'exp_ref': ref}], self.addr, net.base.ExpMessage.EXPINIT)
146+
start_msg = ((ref, {}), self.addr, net.base.ExpMessage.EXPSTART)
147+
interrupt_msg = ((), self.addr, net.base.ExpMessage.EXPINTERRUPT)
148+
cleanup_msg = ((), self.addr, net.base.ExpMessage.EXPCLEANUP)
149+
self.messages = (info_msg, init_msg, start_msg, interrupt_msg, cleanup_msg)
150+
151+
messages = self._iterate_messages(keyboard_input='ABORT\n')
152+
self.communicator.on_event.side_effect = lambda evt: next(messages)
153+
# Should log exp ref mismatch
154+
with self.assertLogs('iblrig.ephys', level='CRITICAL'), patch('builtins.input', return_value='y') as input:
155+
await iblrig.ephys.main_v8_networked('foo', debug=True)
156+
input.assert_called_once()
157+
158+
# The on_event method is awaited at first then each time a message is received
159+
self.communicator.on_event.assert_awaited_with(net.base.ExpMessage.any())
160+
self.assertEqual(1 + len(self.messages), self.communicator.on_event.await_count)
161+
162+
# Check that the expected methods were called with the expected arguments
163+
kwargs = dict(addr=self.addr)
164+
ref = f'{date.today()}_1_foo'
165+
expected_responses = [
166+
('info', (net.base.ExpStatus.RUNNING, {'exp_ref': ref, 'main_sync': True}), kwargs),
167+
('init', ({'exp_ref': ref, 'status': net.base.ExpStatus.RUNNING},), kwargs),
168+
('start', ({'subject': 'foo', 'date': date.today(), 'sequence': 1},), kwargs),
169+
('confirmed_send', ((net.base.ExpMessage.EXPINTERRUPT, {'status': net.base.ExpStatus.RUNNING}),), kwargs),
170+
('confirmed_send', ((net.base.ExpMessage.EXPCLEANUP, {'status': net.base.ExpStatus.RUNNING}),), kwargs),
171+
('close', (), {}), # should be called after the last message (when keyboad input is simulated)
172+
]
173+
# Check odd method calls as even ones are the on_event calls
174+
for expected, actual in zip(expected_responses, map(tuple, self.communicator.method_calls[1::2]), strict=False):
175+
self.assertEqual(expected, actual)
176+
177+
# Check that the local and remote sessions were removed
178+
self.assertFalse(any(self.tmpdir.rglob('*.*')))
179+
self.assertFalse(self.tmpdir.joinpath(f'local/foo/{date.today()}/001').exists())
180+
181+
# Check behaviour when user does not confirm cleanup
182+
messages = self._iterate_messages(keyboard_input='ABORT\n')
183+
self.communicator.on_event.side_effect = lambda evt: next(messages)
184+
with patch('builtins.input', return_value='') as input:
185+
await iblrig.ephys.main_v8_networked('foo', debug=True)
186+
input.assert_called_once()
187+
self.assertTrue(any(self.tmpdir.rglob('*.*')))
188+
self.assertTrue(self.tmpdir.joinpath(f'local/foo/{date.today()}/001').exists())
189+
190+
async def test_alyx_request(self):
191+
"""Test the main_v8_networked function alyx request message."""
192+
# Create some mock behaviour rig messages that request and provide Alyx credentials
193+
alyx_req = ((None, {}), self.addr, net.base.ExpMessage.ALYX)
194+
alyx_mes = ((TEST_DB['base_url'], {'test_user': {'token': 't0k3n'}}), self.addr, net.base.ExpMessage.ALYX)
195+
# Behaviour should be thus:
196+
# 1. Request not processed as Alyx offline by default
197+
# 2. Alyx object updated with remote token
198+
# 3. Request processed with updated Alyx object (now logged in)
199+
self.messages = (alyx_req, alyx_mes, alyx_req)
200+
201+
messages = self._iterate_messages()
202+
self.communicator.on_event.side_effect = lambda evt: next(messages)
203+
with patch('iblrig.ephys.update_alyx_token', wraps=iblrig.ephys.update_alyx_token) as m:
204+
await iblrig.ephys.main_v8_networked('foo', debug=True)
205+
m.assert_called_once()
206+
207+
# Check that the expected methods were called with the expected arguments
208+
self.communicator.alyx.assert_awaited_once()
209+
(alyx,), addr = self.communicator.alyx.call_args
210+
self.assertTrue(alyx.is_logged_in)
211+
self.assertEqual(TEST_DB['base_url'], alyx.base_url)
212+
self.assertEqual({'token': 't0k3n'}, alyx._token)
213+
214+
def _iterate_messages(self, keyboard_input='\n'):
215+
"""Yield behaviour rig UDP messages with added side effect simulating keyboard input after."""
216+
# When first called we shouold not have awaited any methods on the communicator yet
217+
for method in ('info', 'init', 'start', 'status', 'alyx', 'confirmed_send'):
218+
getattr(self.communicator, method).assert_not_awaited()
219+
# Yeild the messages in order
220+
for msg in self.messages: # noqa: UP028
221+
yield msg # ruff complains here but the suggested `yield from` does not work
222+
# After the last message is processed, terminate by simulating keyboard input
223+
self.keyboard = keyboard_input
224+
yield self.messages[-1]

0 commit comments

Comments
 (0)