|
| 1 | +import tempfile |
1 | 2 | import unittest |
| 3 | +from datetime import date |
| 4 | +from pathlib import Path |
| 5 | +from unittest.mock import AsyncMock, patch |
2 | 6 |
|
3 | 7 | import iblrig.ephys |
| 8 | +from iblrig.test.base import TEST_DB |
| 9 | +from iblutil.io import net |
| 10 | +from iblutil.util import Bunch |
4 | 11 |
|
5 | 12 |
|
6 | 13 | class TestFinalizeEphysSession(unittest.TestCase): |
@@ -38,3 +45,180 @@ def test_neuropixel24_micromanipulator(self): |
38 | 45 | }, |
39 | 46 | } |
40 | 47 | assert trajectories == a |
| 48 | + |
| 49 | + |
| 50 | +class TestPrepareEphysSessionNetworked(unittest.IsolatedAsyncioTestCase): |
| 51 | + """Test the main_v8_networked function.""" |
| 52 | + |
| 53 | + def setUp(self): |
| 54 | + """Set up keyboard input and settings mocks.""" |
| 55 | + # Set up keyboad input mock |
| 56 | + # When we set self.keyboard to a non-empty string, the test function should interpret this as |
| 57 | + # keyboard input and stop the session (i.e. close the communicator and return) |
| 58 | + self.keyboard = '' |
| 59 | + read_stdin = patch('iblrig.ephys.read_stdin') |
| 60 | + self.addCleanup(read_stdin.stop) |
| 61 | + read_stdin_mock = read_stdin.start() |
| 62 | + |
| 63 | + async def _stdin(): |
| 64 | + if self.keyboard: |
| 65 | + yield self.keyboard |
| 66 | + |
| 67 | + read_stdin_mock.side_effect = _stdin |
| 68 | + |
| 69 | + # Set up settings mock |
| 70 | + tmp = tempfile.TemporaryDirectory() |
| 71 | + self.tmpdir = Path(tmp.name) |
| 72 | + (local := self.tmpdir.joinpath('local')).mkdir() |
| 73 | + (remote := self.tmpdir.joinpath('remote')).mkdir() |
| 74 | + (remote_subjects := remote.joinpath('subjects')).mkdir() |
| 75 | + self.settings = Bunch( |
| 76 | + iblrig_local_data_path=local, iblrig_remote_data_path=remote, iblrig_remote_subjects_path=remote_subjects |
| 77 | + ) |
| 78 | + m = patch('iblrig.ephys.load_pydantic_yaml', return_value=self.settings) |
| 79 | + m.start() |
| 80 | + self.addCleanup(m.stop) |
| 81 | + self.addr = '192.168.0.5:99998' # Fake address of the behaviour rig |
| 82 | + |
| 83 | + async def asyncSetUp(self): |
| 84 | + """Set up communicator mock. |
| 85 | +
|
| 86 | + To side-step UDP communication, we mock the communicator and simulate the messages that |
| 87 | + would be sent by the behaviour rig, then assert that the response methods are called with |
| 88 | + the expected arguments. |
| 89 | + """ |
| 90 | + self.communicator = AsyncMock(spec=iblrig.ephys.net.app.EchoProtocol) |
| 91 | + self.communicator.is_connected = True |
| 92 | + m = patch('iblrig.ephys.get_server_communicator', return_value=(self.communicator, None)) |
| 93 | + m.start() |
| 94 | + self.addCleanup(m.stop) |
| 95 | + |
| 96 | + async def test_standard_message_sequence(self): |
| 97 | + """Test the main_v8_networked function with the usual sequence of behaviour rig messages.""" |
| 98 | + # Create some mock behaviour rig messages |
| 99 | + ref = f'{date.today()}_1_foo' |
| 100 | + info_msg = ((net.base.ExpStatus.CONNECTED, {'subject_name': 'foo'}), self.addr, net.base.ExpMessage.EXPINFO) |
| 101 | + init_msg = ([{'exp_ref': ref}], self.addr, net.base.ExpMessage.EXPINIT) |
| 102 | + start_msg = ((ref, {}), self.addr, net.base.ExpMessage.EXPSTART) |
| 103 | + status_msg = (net.base.ExpStatus.RUNNING, self.addr, net.base.ExpMessage.EXPSTATUS) |
| 104 | + # This is the order in which the messages are expected to be sent (excluding status) |
| 105 | + self.messages = (info_msg, init_msg, start_msg, status_msg) |
| 106 | + |
| 107 | + messages = self._iterate_messages() |
| 108 | + self.communicator.on_event.side_effect = lambda evt: next(messages) |
| 109 | + await iblrig.ephys.main_v8_networked('foo', debug=True) |
| 110 | + |
| 111 | + # The on_event method is awaited at first then each time a message is received |
| 112 | + self.communicator.on_event.assert_awaited_with(net.base.ExpMessage.any()) |
| 113 | + self.assertEqual(1 + len(self.messages), self.communicator.on_event.await_count) |
| 114 | + |
| 115 | + # Check that the expected methods were called with the expected arguments |
| 116 | + kwargs = dict(addr=self.addr) |
| 117 | + expected_responses = [ |
| 118 | + ('info', (net.base.ExpStatus.RUNNING, {'exp_ref': ref, 'main_sync': True}), kwargs), |
| 119 | + ('init', ({'exp_ref': ref, 'status': net.base.ExpStatus.RUNNING},), kwargs), |
| 120 | + ('start', ({'subject': 'foo', 'date': date.today(), 'sequence': 1},), kwargs), |
| 121 | + ('status', (net.base.ExpStatus.RUNNING,), kwargs), |
| 122 | + ('close', (), {}), # should be called after the last message (when keyboad input is simulated) |
| 123 | + ] |
| 124 | + # Check odd method calls as even ones are the on_event calls |
| 125 | + for expected, actual in zip(expected_responses, map(tuple, self.communicator.method_calls[1::2]), strict=False): |
| 126 | + self.assertEqual(expected, actual) |
| 127 | + |
| 128 | + # Check that the local and remote sessions were created |
| 129 | + expected = [ |
| 130 | + f'local/foo/{date.today()}/001/transfer_me.flag', |
| 131 | + f'local/foo/{date.today()}/001/_ibl_experiment.description_ephys.yaml', |
| 132 | + f'remote/subjects/foo/{date.today()}/001/_devices/[email protected]_pending', |
| 133 | + f'remote/subjects/foo/{date.today()}/001/_devices/[email protected]', |
| 134 | + ] |
| 135 | + self.assertCountEqual(map(self.tmpdir.joinpath, expected), self.tmpdir.rglob('*.*')) |
| 136 | + # Should have created the raw ephys folders |
| 137 | + self.assertEqual(2, len(list(self.tmpdir.glob(f'local/foo/{date.today()}/001/raw_ephys_data/probe??')))) |
| 138 | + |
| 139 | + async def test_abort_session(self): |
| 140 | + """Test the main_v8_networked function with misc events and user 'abort' input.""" |
| 141 | + # Create some mock behaviour rig messages where the behaviour rig runs subject 'bar' instead of 'foo' |
| 142 | + # No exception should be raised (this happens at the behaviour rig) but this should be logged |
| 143 | + ref = f'{date.today()}_1_bar' |
| 144 | + info_msg = ((net.base.ExpStatus.CONNECTED, {'subject_name': 'bar'}), self.addr, net.base.ExpMessage.EXPINFO) |
| 145 | + init_msg = ([{'exp_ref': ref}], self.addr, net.base.ExpMessage.EXPINIT) |
| 146 | + start_msg = ((ref, {}), self.addr, net.base.ExpMessage.EXPSTART) |
| 147 | + interrupt_msg = ((), self.addr, net.base.ExpMessage.EXPINTERRUPT) |
| 148 | + cleanup_msg = ((), self.addr, net.base.ExpMessage.EXPCLEANUP) |
| 149 | + self.messages = (info_msg, init_msg, start_msg, interrupt_msg, cleanup_msg) |
| 150 | + |
| 151 | + messages = self._iterate_messages(keyboard_input='ABORT\n') |
| 152 | + self.communicator.on_event.side_effect = lambda evt: next(messages) |
| 153 | + # Should log exp ref mismatch |
| 154 | + with self.assertLogs('iblrig.ephys', level='CRITICAL'), patch('builtins.input', return_value='y') as input: |
| 155 | + await iblrig.ephys.main_v8_networked('foo', debug=True) |
| 156 | + input.assert_called_once() |
| 157 | + |
| 158 | + # The on_event method is awaited at first then each time a message is received |
| 159 | + self.communicator.on_event.assert_awaited_with(net.base.ExpMessage.any()) |
| 160 | + self.assertEqual(1 + len(self.messages), self.communicator.on_event.await_count) |
| 161 | + |
| 162 | + # Check that the expected methods were called with the expected arguments |
| 163 | + kwargs = dict(addr=self.addr) |
| 164 | + ref = f'{date.today()}_1_foo' |
| 165 | + expected_responses = [ |
| 166 | + ('info', (net.base.ExpStatus.RUNNING, {'exp_ref': ref, 'main_sync': True}), kwargs), |
| 167 | + ('init', ({'exp_ref': ref, 'status': net.base.ExpStatus.RUNNING},), kwargs), |
| 168 | + ('start', ({'subject': 'foo', 'date': date.today(), 'sequence': 1},), kwargs), |
| 169 | + ('confirmed_send', ((net.base.ExpMessage.EXPINTERRUPT, {'status': net.base.ExpStatus.RUNNING}),), kwargs), |
| 170 | + ('confirmed_send', ((net.base.ExpMessage.EXPCLEANUP, {'status': net.base.ExpStatus.RUNNING}),), kwargs), |
| 171 | + ('close', (), {}), # should be called after the last message (when keyboad input is simulated) |
| 172 | + ] |
| 173 | + # Check odd method calls as even ones are the on_event calls |
| 174 | + for expected, actual in zip(expected_responses, map(tuple, self.communicator.method_calls[1::2]), strict=False): |
| 175 | + self.assertEqual(expected, actual) |
| 176 | + |
| 177 | + # Check that the local and remote sessions were removed |
| 178 | + self.assertFalse(any(self.tmpdir.rglob('*.*'))) |
| 179 | + self.assertFalse(self.tmpdir.joinpath(f'local/foo/{date.today()}/001').exists()) |
| 180 | + |
| 181 | + # Check behaviour when user does not confirm cleanup |
| 182 | + messages = self._iterate_messages(keyboard_input='ABORT\n') |
| 183 | + self.communicator.on_event.side_effect = lambda evt: next(messages) |
| 184 | + with patch('builtins.input', return_value='') as input: |
| 185 | + await iblrig.ephys.main_v8_networked('foo', debug=True) |
| 186 | + input.assert_called_once() |
| 187 | + self.assertTrue(any(self.tmpdir.rglob('*.*'))) |
| 188 | + self.assertTrue(self.tmpdir.joinpath(f'local/foo/{date.today()}/001').exists()) |
| 189 | + |
| 190 | + async def test_alyx_request(self): |
| 191 | + """Test the main_v8_networked function alyx request message.""" |
| 192 | + # Create some mock behaviour rig messages that request and provide Alyx credentials |
| 193 | + alyx_req = ((None, {}), self.addr, net.base.ExpMessage.ALYX) |
| 194 | + alyx_mes = ((TEST_DB['base_url'], {'test_user': {'token': 't0k3n'}}), self.addr, net.base.ExpMessage.ALYX) |
| 195 | + # Behaviour should be thus: |
| 196 | + # 1. Request not processed as Alyx offline by default |
| 197 | + # 2. Alyx object updated with remote token |
| 198 | + # 3. Request processed with updated Alyx object (now logged in) |
| 199 | + self.messages = (alyx_req, alyx_mes, alyx_req) |
| 200 | + |
| 201 | + messages = self._iterate_messages() |
| 202 | + self.communicator.on_event.side_effect = lambda evt: next(messages) |
| 203 | + with patch('iblrig.ephys.update_alyx_token', wraps=iblrig.ephys.update_alyx_token) as m: |
| 204 | + await iblrig.ephys.main_v8_networked('foo', debug=True) |
| 205 | + m.assert_called_once() |
| 206 | + |
| 207 | + # Check that the expected methods were called with the expected arguments |
| 208 | + self.communicator.alyx.assert_awaited_once() |
| 209 | + (alyx,), addr = self.communicator.alyx.call_args |
| 210 | + self.assertTrue(alyx.is_logged_in) |
| 211 | + self.assertEqual(TEST_DB['base_url'], alyx.base_url) |
| 212 | + self.assertEqual({'token': 't0k3n'}, alyx._token) |
| 213 | + |
| 214 | + def _iterate_messages(self, keyboard_input='\n'): |
| 215 | + """Yield behaviour rig UDP messages with added side effect simulating keyboard input after.""" |
| 216 | + # When first called we shouold not have awaited any methods on the communicator yet |
| 217 | + for method in ('info', 'init', 'start', 'status', 'alyx', 'confirmed_send'): |
| 218 | + getattr(self.communicator, method).assert_not_awaited() |
| 219 | + # Yeild the messages in order |
| 220 | + for msg in self.messages: # noqa: UP028 |
| 221 | + yield msg # ruff complains here but the suggested `yield from` does not work |
| 222 | + # After the last message is processed, terminate by simulating keyboard input |
| 223 | + self.keyboard = keyboard_input |
| 224 | + yield self.messages[-1] |
0 commit comments