Description
Hi!
When I try to add a list of cont_labels to marugoto.mil crossval, the program breaks.
python -m marugoto.mil crossval \
--clini-excel $cliniPath \
--slide-csv $slidePath \
--feature-dir $featurePath \
--target-label $target \
--output-path $outPath \
--n_splits 3 \
--cont_labels "PLT,INR,aPTT,BILIRUBIN,gGT,CHE"
The error I get is: "AttributeError: 'tuple' object has no attribute 'shape'" (full error message below)
The point at which the program breaks is in marugoto/mil/_mil.py, line 146:
batch = train_dl.one_batch()
Some additional behaviors I noticed:
- Adding --cont_labels to the main marugoto branch works fine.
- The ds variable created in marugoto/mil/data.py, line 139, is formatted differently on the main and attMIL branches.
def _make_multi_input_dataset(
    *,
    bags: Sequence[Iterable[Path]],
    targets: Tuple[FunctionTransformer, Sequence[Any]],
    add_features: Iterable[Tuple[Any, Sequence[Any]]],
    bag_size: Optional[int] = None
) -> MapDataset:
    target_enc, targs = targets
    assert len(bags) == len(targs), \
        'number of bags and ground truths does not match!'
    for i, (_, vals) in enumerate(add_features):
        assert len(vals) == len(targs), \
            f'number of additional attributes #{i} and ground truths does not match!'

    bag_ds = BagDataset(bags, bag_size=bag_size)
    add_ds = MapDataset(
        _splat_concat,
        *[
            EncodedDataset(enc, vals)
            for enc, vals in add_features
        ])
    targ_ds = EncodedDataset(target_enc, targs)

    ############ !!! Different behavior spotted here !!!
    ds = MapDataset(
        _attach_add_to_bag_and_zip_with_targ,
        bag_ds,
        add_ds,
        targ_ds,
    )
    ############

    return ds
- Indexing ds[0] in attMIL gives [(features_tensor, int_ninstances), [[np.array]], [tensor_target]], where the np.array holds the raw, untransformed values of the cont_labels features.

- Indexing ds[0] in main gives (tensor_features, int_n_instances, tensor_target). The cont_labels are transformed and concatenated to tensor_features, resulting in shape n_instances × (n_features + n_cont_labels).
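To compare the two item layouts without dumping whole tensors, a small helper along these lines could be used. It is only a sketch: `describe` is a hypothetical name, not part of marugoto, and it relies on duck typing (anything with a `.shape` attribute is treated as an array or tensor).

```python
from typing import Any


def describe(item: Any) -> Any:
    """Summarise the nesting, types and shapes of a dataset item."""
    # Recurse into lists and tuples, keeping the container type visible.
    if isinstance(item, (list, tuple)):
        kind = list if isinstance(item, list) else tuple
        return kind(describe(x) for x in item)
    # Tensors and numpy arrays both expose .shape and .dtype.
    shape = getattr(item, "shape", None)
    if shape is not None:
        dtype = getattr(item, "dtype", "?")
        return f"{type(item).__name__}{tuple(shape)} dtype={dtype}"
    # Plain scalars and anything else: report the type name only.
    return type(item).__name__
```

Calling `describe(ds[0])` on each branch should make the difference visible at a glance, including the dtype of the array that carries the cont_labels values.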

Error Message:
Traceback (most recent call last):
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 179, in create_batch
try: return (fa_collate,fa_convert)[self.prebatched](b)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 52, in fa_collate
else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 52, in <listcomp>
else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 52, in fa_collate
else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 52, in <listcomp>
else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 52, in fa_collate
else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 52, in <listcomp>
else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 51, in fa_collate
return (default_collate(t) if isinstance(b, _collate_types)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
return collate(batch, collate_fn_map=default_collate_fn_map)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 120, in collate
return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 170, in collate_numpy_array_fn
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/mnt/bulk/tsorznechay/LiverHVPG/Python_Modules/marugoto-regression/marugoto/mil/__main__.py", line 5, in <module>
Fire({
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fire/core.py", line 466, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/mnt/bulk/tsorznechay/LiverHVPG/Python_Modules/marugoto-regression/marugoto/mil/helpers.py", line 359, in categorical_crossval
learn = _crossval_train(
File "/mnt/bulk/tsorznechay/LiverHVPG/Python_Modules/marugoto-regression/marugoto/mil/helpers.py", line 420, in _crossval_train
learn = train(
File "/mnt/bulk/tsorznechay/LiverHVPG/Python_Modules/marugoto-regression/marugoto/mil/_mil.py", line 147, in train
batch = train_dl.one_batch()
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 187, in one_batch
with self.fake_l.no_multiproc(): res = first(self)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastcore/basics.py", line 660, in first
return next(x, None)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 127, in __iter__
for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
data = self._next_data()
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 671, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 43, in fetch
data = next(self.dataset_iter)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 138, in create_batches
yield from map(self.do_batch, self.chunkify(res))
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 183, in do_batch
def do_batch(self, b): return self.retain(self.create_batch(self.before_batch(b)), b)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 181, in create_batch
if not self.prebatched: collate_error(e,b)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 75, in collate_error
if i == 0: shape_a, type_a = item[idx].shape, item[idx].__class__.__name__
AttributeError: 'tuple' object has no attribute 'shape'
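
The first exception in the traceback above is the TypeError raised by default_collate ("found object"); the AttributeError only appears while fastai tries to report it. That TypeError can probably be reproduced in isolation with the sketch below. The array values are made up, and the assumption (inferred from the "found object" message, not checked against the attMIL code) is that the cont_labels values reach the collate function as an object-dtype numpy array.

```python
import numpy as np
from torch.utils.data import default_collate  # public in recent PyTorch versions

# Made-up stand-ins for the raw cont_labels values of two patients.
# dtype=object is an assumption based on the "found object" error message.
raw = [
    np.array([[210.0, 1.1, 31.0]], dtype=object),
    np.array([[180.0, 1.3, 35.0]], dtype=object),
]


def try_collate(samples):
    """Return the collated batch, or the TypeError if collation fails."""
    try:
        return default_collate(samples)
    except TypeError as err:
        return err


# Object-dtype arrays are rejected by default_collate.
result_obj = try_collate(raw)

# The same values cast to a numeric dtype collate into a single tensor.
result_float = try_collate([a.astype(np.float32) for a in raw])

print(type(result_obj).__name__)
print(tuple(result_float.shape))
```

If this matches what happens in attMIL, the additional features would need to be encoded to a numeric dtype (as main appears to do) before they reach the DataLoader.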