Hi guys, I was trying to train the TID network using the train.py script with the tid_train.yaml config. I downloaded the DSEC dataset and placed it in the folders described in the README. During training, the run crashes after some iterations with an error in retrieval_fn.py:
idnet/idn/utils/retrieval_fn.py
Lines 45 to 48 in 6e9ade0

    def retreival_pred_nextflow_seq(out, batch):
        fmask = namedtuple("masked_frame", ["frame", "mask"])
        return (out["flow_next_trajectory"], [fmask(x["flow_gt_next"],
                                                    x["flow_gt_next_valid_mask"]) for x in batch])

It seems that the retrieval function expects a flow_gt_next for every element of the batch, but when I check the dataloader here:
idnet/idn/loader/loader_dsec.py
Lines 396 to 404 in 6e9ade0

    if self.load_gt:
        if index + 2 < len(self.flow_png):
            output['flow_gt_next'] = [torch.tensor(
                x) for x in self.load_flow(self.flow_png[index + 2])]
            output['flow_gt_next'][0] = torch.moveaxis(
                output['flow_gt_next'][0], -1, 0)
            output['flow_gt_next'][1] = torch.unsqueeze(
                output['flow_gt_next'][1], 0)
    return output

it will occasionally load a sample at the end of the recording for which there is no flow_gt_next.
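To illustrate the failure mode, here is a minimal sketch that does not use torch or the real loader. The function name retrieve_next_flow and the placeholder string values are made up for illustration; only the two dictionary keys are taken from the snippets above.

```python
from collections import namedtuple

# Same structure as the namedtuple built in the retrieval function.
MaskedFrame = namedtuple("masked_frame", ["frame", "mask"])


def retrieve_next_flow(batch):
    """Hypothetical stand-in: index both keys for every sample in the batch."""
    return [MaskedFrame(x["flow_gt_next"], x["flow_gt_next_valid_mask"])
            for x in batch]


# A sample in the middle of a recording has both keys (placeholder values).
complete = {"flow_gt_next": "flow", "flow_gt_next_valid_mask": "mask"}
# A sample at the end of a recording, where the loader skipped the key.
last_of_recording = {}

missing_key = None
try:
    retrieve_next_flow([complete, last_of_recording])
except KeyError as err:
    missing_key = err.args[0]

print("missing key:", missing_key)
```

A single sample without the key is enough to fail the whole batch, which would explain why the crash only shows up after some iterations.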
By changing the +1 to a -1 in line 524 here, the training runs without errors:
idnet/idn/loader/loader_dsec.py
Lines 520 to 534 in 6e9ade0

    def get_continuous_sequences(self):
        continuous_seq_idcs = []
        if self.sequence_length > 1:
            for i in range(len(self.timestamps_flow)-self.sequence_length+1):
                diff = self.timestamps_flow[i +
                                            self.sequence_length-1] - self.timestamps_flow[i]
                if diff < np.max([100000 * (self.sequence_length-1) + 1000, 101000]):
                    continuous_seq_idcs.append(i)
        else:
            for i in range(len(self.timestamps_flow)-1):
                diff = self.timestamps_flow[i+1] - self.timestamps_flow[i]
                if diff < np.max([100000 * (self.sequence_length-1) + 1000, 101000]):
                    continuous_seq_idcs.append(i)
        return continuous_seq_idcs

Could it be that this was a typo in the code or am I missing something?
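For reference, here is a small sketch of the index arithmetic behind the change. It rests on assumptions I have not verified in the rest of the loader: that a sequence starting at i loads samples i to i + sequence_length - 1, that each of those samples goes through the index + 2 < len(self.flow_png) check, and that flow_png and timestamps_flow have the same length. The helper names are hypothetical.

```python
def sequence_starts(num_flow, sequence_length, offset):
    """Mirror range(len(timestamps_flow) - sequence_length + offset)."""
    return range(num_flow - sequence_length + offset)


def missing_next_flow(num_flow, sequence_length, offset):
    """List (start, index) pairs whose sample would lack flow_gt_next."""
    missing = []
    for start in sequence_starts(num_flow, sequence_length, offset):
        # Assumption: the sequence covers `sequence_length` consecutive samples.
        for index in range(start, start + sequence_length):
            if not index + 2 < num_flow:  # same guard as in the loader
                missing.append((start, index))
    return missing


# Toy recording with 10 flow maps and sequences of length 4.
with_plus_one = missing_next_flow(10, 4, offset=+1)
with_minus_one = missing_next_flow(10, 4, offset=-1)

print(with_plus_one)   # [(5, 8), (6, 8), (6, 9)]
print(with_minus_one)  # []
```

Under these assumptions, the +1 bound lets the last two start indices reach samples with no next flow map, while the -1 bound excludes exactly those. This ignores the timestamp gap check, which only removes further indices.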