attMIL-regression breaks when adding cont_labels #15

@tsorz

Hi!

When I pass a list of cont_labels to marugoto.mil crossval, the program crashes.

python -m marugoto.mil crossval \
    --clini-excel $cliniPath \
    --slide-csv $slidePath \
    --feature-dir $featurePath \
    --target-label $target \
    --output-path $outPath \
    --n_splits 3 \
    --cont_labels "PLT,INR,aPTT,BILIRUBIN,gGT,CHE"

The error I get is: "AttributeError: 'tuple' object has no attribute 'shape'" (full error message below)
The program breaks in marugoto/mil/_mil.py, line 147 (the line number reported in the traceback):
batch = train_dl.one_batch()

Some additional behaviors I noticed:

  • Adding --cont_labels to the main marugoto branch works fine.
  • The ds variable created in marugoto/mil/data.py, line 139, is formatted differently on the main and attMIL branches:
def _make_multi_input_dataset(
    *,
    bags: Sequence[Iterable[Path]],
    targets: Tuple[FunctionTransformer, Sequence[Any]],
    add_features: Iterable[Tuple[Any, Sequence[Any]]],
    bag_size: Optional[int] = None
) -> MapDataset:
    target_enc, targs = targets
    assert len(bags) == len(targs), \
        'number of bags and ground truths does not match!'
    for i, (_, vals) in enumerate(add_features):
        assert len(vals) == len(targs), \
            f'number of additional attributes #{i} and ground truths does not match!'

    bag_ds = BagDataset(bags, bag_size=bag_size)

    add_ds = MapDataset(
        _splat_concat,
        *[
            EncodedDataset(enc, vals)
            for enc, vals in add_features
        ])

    targ_ds = EncodedDataset(target_enc, targs)
    ############ !!! Different behavior spotted here !!!
    ds = MapDataset(
        _attach_add_to_bag_and_zip_with_targ,
        bag_ds,
        add_ds,
        targ_ds,
    ) 
    ############
    return ds
  • Indexing ds[0] in attMIL gives [(features_tensor, int_n_instances), [[np.array]], [tensor_target]], where the np.array contains the raw values of the cont_labels features.

  • Indexing ds[0] in main gives (tensor_features, int_n_instances, tensor_target). The cont_labels are transformed and concatenated onto tensor_features, giving a tensor of shape (n_instances, n_features + n_cont_labels).
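
The shape described for main can be reproduced with a minimal sketch. It uses NumPy arrays as stand-ins for the tensors, the sizes are made up, and `attach_add_to_bag` is a hypothetical helper written for illustration, not the function from marugoto.

```python
import numpy as np


def attach_add_to_bag(bag_feats: np.ndarray, add_feats: np.ndarray) -> np.ndarray:
    """Append one vector of per-patient features to every instance of a bag.

    bag_feats: (n_instances, n_features)
    add_feats: (n_cont_labels,), already encoded as numbers
    returns:   (n_instances, n_features + n_cont_labels)
    """
    n_instances = bag_feats.shape[0]
    # Repeat the additional features once per instance so the rows line up.
    tiled = np.tile(add_feats.astype(bag_feats.dtype), (n_instances, 1))
    return np.concatenate([bag_feats, tiled], axis=1)


# Made-up sizes: 4 instances, 8 features each, 6 continuous labels.
bag = np.zeros((4, 8), dtype=np.float32)
cont = np.arange(6, dtype=np.float32)
combined = attach_add_to_bag(bag, cont)
print(combined.shape)  # (4, 14)
```

With this layout each sample is a flat tuple of arrays plus an integer, which is the kind of structure the main branch reportedly hands to the data loader.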

Error Message:

Traceback (most recent call last):
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 179, in create_batch
try: return (fa_collate,fa_convert)[self.prebatched](b)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 52, in fa_collate
else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 52, in <listcomp>
else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 52, in fa_collate
else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 52, in <listcomp>
else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 52, in fa_collate
else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 52, in <listcomp>
else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 51, in fa_collate
return (default_collate(t) if isinstance(b, _collate_types)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
return collate(batch, collate_fn_map=default_collate_fn_map)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 120, in collate
return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 170, in collate_numpy_array_fn
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/mnt/bulk/tsorznechay/LiverHVPG/Python_Modules/marugoto-regression/marugoto/mil/__main__.py", line 5, in <module>
Fire({
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fire/core.py", line 466, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/mnt/bulk/tsorznechay/LiverHVPG/Python_Modules/marugoto-regression/marugoto/mil/helpers.py", line 359, in categorical_crossval
learn = _crossval_train(
File "/mnt/bulk/tsorznechay/LiverHVPG/Python_Modules/marugoto-regression/marugoto/mil/helpers.py", line 420, in _crossval_train
learn = train(
File "/mnt/bulk/tsorznechay/LiverHVPG/Python_Modules/marugoto-regression/marugoto/mil/_mil.py", line 147, in train
batch = train_dl.one_batch()
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 187, in one_batch
with self.fake_l.no_multiproc(): res = first(self)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastcore/basics.py", line 660, in first
return next(x, None)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 127, in __iter__
for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
data = self._next_data()
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 671, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 43, in fetch
data = next(self.dataset_iter)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 138, in create_batches
yield from map(self.do_batch, self.chunkify(res))
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 183, in do_batch
def do_batch(self, b): return self.retain(self.create_batch(self.before_batch(b)), b)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 181, in create_batch
if not self.prebatched: collate_error(e,b)
File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 75, in collate_error
if i == 0: shape_a, type_a = item[idx].shape, item[idx].__class__.__name__
AttributeError: 'tuple' object has no attribute 'shape'
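
The first exception in the chain looks like the informative one: default_collate found a NumPy array with dtype object, and the AttributeError is only raised afterwards, while fastai's collate_error tries to build a more detailed message. One way to locate such an element is to walk a single sample and report every object-dtype leaf. The helper below is a hypothetical debugging sketch, not part of marugoto or fastai, and the sample it inspects is a made-up stand-in that mirrors the nesting reported for ds[0] on the attMIL branch.

```python
import numpy as np


def find_object_leaves(sample, path="ds[0]"):
    """Yield (path, array) for every object-dtype array in a nested sample."""
    if isinstance(sample, (list, tuple)):
        # Recurse into containers, extending the index path as we go.
        for i, item in enumerate(sample):
            yield from find_object_leaves(item, f"{path}[{i}]")
    elif isinstance(sample, np.ndarray) and sample.dtype == object:
        yield path, sample


# Made-up stand-in: (features, n_instances), [[raw cont_labels]], [target].
sample = [
    (np.zeros((4, 8), dtype=np.float32), 4),
    [[np.array([250, 1.1, 32.5], dtype=object)]],
    [np.float32(12.0)],
]

hits = list(find_object_leaves(sample))
for path, arr in hits:
    print(path, arr.dtype)  # ds[0][1][0][0] object
```

If the reported values are all numeric, casting them with `arr.astype(np.float32)` before the dataset is built would give default_collate a dtype it accepts. Whether that alone restores the behavior seen on main is an assumption that has not been verified here.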
