Data mismatch in F.conv2d() module #362

Closed
takaomoriyama opened this issue Jan 10, 2025 · 10 comments

Describe the issue
When running an object detection task (under development in a branch), the fit process stops due to a data mismatch in F.conv2d().

  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 454, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

To Reproduce (optional, but appreciated)
Steps to reproduce the behavior:

$ conda create -n terratorch python=3.11
$ conda activate terratorch
$ git clone -b obj_det_geobench git@github.com:IBM/terratorch.git
$ cd terratorch
$ pip install -r requirements/required.txt -r requirements/dev.txt
$ pip install pycocotools
$ pip install -e .
$ cd terratorch/examples/confs
$ terratorch fit --config object_detection_vhr10.yaml

Screenshots or log output (optional)

$ terratorch fit --config object_detection_vhr10.yaml
Seed set to 0
/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/lightning/pytorch/cli.py:676: `ObjectDetectionTask.configure_optimizers` will be overridden by `MyLightningCLI.configure_optimizers`.
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Files already downloaded and verified
loading annotations into memory...
Done (t=0.30s)
creating index...
index created!
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
┏━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓
┃   ┃ Name          ┃ Type                        ┃ Params ┃ Mode  ┃
┡━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩
│ 0 │ model         │ ObjectDetectionModelWrapper │ 41.3 M │ train │
│ 1 │ train_metrics │ MetricCollection            │      0 │ train │
│ 2 │ val_metrics   │ MetricCollection            │      0 │ train │
│ 3 │ test_metrics  │ MetricCollection            │      0 │ train │
└───┴───────────────┴─────────────────────────────┴────────┴───────┘
Trainable params: 41.1 M
Non-trainable params: 222 K
Total params: 41.3 M
Total estimated model params size (MB): 165
Modules in train mode: 196
Modules in eval mode: 0


Traceback (most recent call last):
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/bin/terratorch", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/dccstor/usgs_dem/moriyama/dev/IBM/terratorch-obj_det_geobench/terratorch/__main__.py", line 9, in main
    _ = build_lightning_cli()
        ^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/dev/IBM/terratorch-obj_det_geobench/terratorch/cli_tools.py", line 429, in build_lightning_cli
    return MyLightningCLI(
           ^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/lightning/pytorch/cli.py", line 394, in __init__
    self._run_subcommand(self.subcommand)
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/lightning/pytorch/cli.py", line 701, in _run_subcommand
    fn(**fn_kwargs)
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1023, in _run_stage
    self._run_sanity_check()
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1052, in _run_sanity_check
    val_loop.run()
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/lightning/pytorch/loops/utilities.py", line 178, in _decorator
    return loop_run(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 319, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 411, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/dev/IBM/terratorch-obj_det_geobench/terratorch/tasks/object_detection_tasks.py", line 165, in validation_step
    y_hat = self(x).output
            ^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torchgeo/trainers/base.py", line 81, in forward
    return self.model(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/dev/IBM/terratorch-obj_det_geobench/terratorch/models/object_detection_model_factory.py", line 177, in forward
    return ModelOutput(self.torchvision_model(*args, **kwargs))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torchvision/models/detection/generalized_rcnn.py", line 101, in forward
    features = self.backbone(images.tensors)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torchvision/models/detection/backbone_utils.py", line 57, in forward
    x = self.body(x)
        ^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torchvision/models/_utils.py", line 69, in forward
    x = module(x)
        ^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 458, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-od2/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 454, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

I found that the input data passed to F.conv2d() is a tensor on the CPU, while the weight is a tensor on the GPU.
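
A quick way to confirm this kind of mismatch before the forward pass is to compare the device of the batch against the device of the model parameters. The helper below is purely illustrative (not part of terratorch); the batch key "image" matches the datamodule output shown later in this thread.

    import torch

    def check_devices(batch, model):
        """Print the device of the batch images and of the model parameters."""
        images = batch["image"]
        # The image entry may be a stacked tensor or a list of per-sample tensors.
        if isinstance(images, (list, tuple)):
            image_devices = {t.device for t in images}
        else:
            image_devices = {images.device}
        param_device = next(model.parameters()).device
        print(f"image device(s): {image_devices}, model device: {param_device}")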


takaomoriyama commented Jan 10, 2025

After some data tracing, I found that the transformation function defined in torchgeo.datamodules.VHR10DataModule.__init__() transforms the data and moves the image data back to the CPU.

        self.train_aug = AugPipe(
            AugmentationSequential(
                K.Normalize(mean=self.mean, std=self.std),
                K.Resize(self.patch_size),
                K.RandomHorizontalFlip(),
                K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=0.7),
                K.RandomVerticalFlip(),
                data_keys=['image', 'boxes', 'masks'],
            ),
            batch_size,
        )

https://github.com/microsoft/torchgeo/blob/5cca8e7336fc582ddc81a86e05d50b505ff74a07/torchgeo/datamodules/vhr10.py#L55-L65

Input data and output data of AugPipe().

Input
    image: list[16] of tensor(torch.Size([3, 751, 979])) (torch.float32 on cuda:0)
    boxes: list[16] of tensor(torch.Size([3, 4])) (torch.float32 on cuda:0)
    labels: list[16] of tensor(torch.Size([3])) (torch.int64 on cuda:0)
    masks: list[16] of tensor(torch.Size([3, 751, 979])) (torch.uint8 on cuda:0)

Output
    image: tensor(torch.Size([16, 3, 512, 512])) (torch.float32 on cpu) ← image is scaled to 512 x 512, then converted from a list of tensors to a single tensor
    boxes: list[16] of tensor(torch.Size([3, 4])) (torch.float32 on cuda:0)
    labels: list[16] of tensor(torch.Size([3])) (torch.int64 on cuda:0)
    masks: list[16] of tensor(torch.Size([3, 512, 512])) (torch.uint8 on cuda:0) ← masks are scaled to 512 x 512
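
For reference, a device trace like the one above can be produced by wrapping the augmentation pipeline with a small debugging helper; this is a hypothetical sketch, not code from torchgeo or terratorch:

    import torch

    class DeviceTraceAug:
        """Wrap an augmentation callable and print tensor devices before/after."""

        def __init__(self, aug):
            self.aug = aug

        def __call__(self, batch):
            self._describe(batch, "input")
            out = self.aug(batch)
            self._describe(out, "output")
            return out

        @staticmethod
        def _describe(batch, label):
            for key, value in batch.items():
                if torch.is_tensor(value):
                    print(f"{label} {key}: {tuple(value.shape)} on {value.device}")
                elif isinstance(value, (list, tuple)) and value and torch.is_tensor(value[0]):
                    print(f"{label} {key}: list[{len(value)}] on {value[0].device}")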


takaomoriyama commented Jan 10, 2025

Found that this is a known issue in Kornia: “AugmentationSequential explicitly moves the output to the CPU if data_keys is given” (kornia/kornia#3066). Data is moved to the CPU when AugmentationSequential() is given data_keys with more than one data item in it.
A fix exists (kornia/kornia#3084), but the latest code including the fix has not been released yet (as v0.7.5 or v0.8?).

Why has this bug not appeared in terratorch before?
Some datamodules in terratorch use kornia's AugmentationSequential(), but they all specify a single item in data_keys, or omit data_keys entirely.

  • sen1floods11.py: AugmentationSequential(K.Normalize(means, stds), data_keys=["image"])
  • geobench_data_module.py: AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug
  • fire_scars.py: AugmentationSequential(K.Normalize(means, stds), data_keys=["image"])
  • fire_scars.py: AugmentationSequential(K.RandomCrop(224, 224), K.Normalize(means, stds))
  • fire_scars.py: AugmentationSequential(K.Normalize(means, stds))
  • fire_scars.py: AugmentationSequential(K.RandomCrop(224, 224), K.normalize())
  • landslide4sense.py: AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug
  • openearthmap.py: AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug

Two tentative solutions.
Solution-1: Use the older kornia 0.7.3

$ pip uninstall kornia
$ pip install kornia==0.7.3
Installing collected packages: kornia
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchgeo 0.7.0.dev0 requires kornia>=0.7.4, but you have kornia 0.7.3 which is incompatible.

This works, but the dependency requirement is not fully satisfied.

Solution-2: Use the latest kornia from GitHub.

$ pip uninstall kornia
$ pip install git+https://github.com/kornia/kornia.git

This is a tentative solution until a new version of kornia (0.7.5 or 0.8) is released. Then we would pin the version in requirements/required.txt.

Joao-L-S-Almeida self-assigned this Jan 13, 2025

Joao-L-S-Almeida commented Jan 13, 2025

@takaomoriyama According to my tests, it seems that the input tensor x is not on the same device as the model.
I did:

x_ = x.to(self.device)
y_hat = self(x_).output

And that step passes.
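
In context, the change sits near the top of validation_step in terratorch/tasks/object_detection_tasks.py. Roughly (the surrounding lines are paraphrased from the traceback, not copied from the file):

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        x = batch["image"]
        # Workaround: move the images to the device the model lives on,
        # since the kornia augmentation pipeline returned them on the CPU.
        x_ = x.to(self.device)
        y_hat = self(x_).output
        ...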

Joao-L-S-Almeida commented:

There is another occurrence around terratorch/terratorch/tasks/object_detection_tasks.py:110.

takaomoriyama commented Jan 13, 2025

@Joao-L-S-Almeida Yes, x_ = x.to(self.device) is a good workaround that does not require any change in the execution environment.

takaomoriyama commented:

@Joao-L-S-Almeida Can you please elaborate on the issue you are referring to at line 110 of object_detection_tasks.py?
I understand that the result of loss = self(x, y).output from Faster R-CNN is not a model output but a dictionary of loss values, as follows (a typical reduction to a single loss is sketched after the list):

  • loss_classifier: Tensor
  • loss_box_reg: Tensor
  • loss_objectness: Tensor
  • loss_rpn_box_reg: Tensor
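
For illustration, a common way to handle such a dictionary in a Lightning training step is to sum the individual losses into one scalar. This sketch is not the actual terratorch code, and the batch keys and target layout are assumptions:

    def training_step(self, batch, batch_idx):
        x = batch["image"].to(self.device)
        y = [
            {"boxes": b.to(self.device), "labels": l.to(self.device)}
            for b, l in zip(batch["boxes"], batch["labels"])
        ]
        loss_dict = self(x, y).output        # dict: loss_classifier, loss_box_reg, ...
        loss = sum(loss_dict.values())       # reduce to a single scalar for backprop
        self.log_dict({f"train_{k}": v for k, v in loss_dict.items()})
        return loss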

Joao-L-S-Almeida commented:

The tensors x and y are also on a different device from the model. They should also be sent to self.device (usually cuda:0).

takaomoriyama commented:

@Joao-L-S-Almeida I see. I confirmed that only x is on the CPU. So I added x = x.to(self.device) after x = batch['image'] in training_step(), validation_step(), test_step(), and predict_step() in a commit, and confirmed the code ran successfully with kornia 0.7.4.
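
The pattern added in each of those methods is essentially the same two lines (illustrative; the exact surrounding code differs per method):

    x = batch["image"]
    x = x.to(self.device)  # workaround, only needed with kornia versions before the fix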

takaomoriyama commented:

Now that kornia 0.8.0 has been released on PyPI, the code should run without the workaround.

takaomoriyama commented:

@Joao-L-S-Almeida Removed the workaround in the commit and confirmed the code ran with the ordinary setup procedure described in README_Object_Detection.md. Thanks for your help.
