diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index aea5bf99..6ddea9cc 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -24,4 +24,4 @@ jobs:
         pip install -e .
     - name: Test with pytest
       run: |
-        pytest tests
+        pytest -s tests
diff --git a/.gitignore b/.gitignore
index b2e2d4d6..f85d05b3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,6 @@ site/*
 build/*
 dist/*
 *.egg-info
-*.ipynb_checkpoints
\ No newline at end of file
+*.ipynb_checkpoints
+**/*pt
+**/*pth
diff --git a/tests/manufactured-finetune_prithvi_swin_B.yaml b/tests/manufactured-finetune_prithvi_swin_B.yaml
new file mode 100644
index 00000000..b7498d1b
--- /dev/null
+++ b/tests/manufactured-finetune_prithvi_swin_B.yaml
@@ -0,0 +1,150 @@
+# lightning.pytorch==2.1.1
+seed_everything: 42
+trainer:
+  accelerator: auto
+  strategy: auto
+  devices: auto
+  num_nodes: 1
+  # precision: 16-mixed
+  logger:
+    class_path: TensorBoardLogger
+    init_args:
+      save_dir: tests/
+      name: all_ecos_random
+  callbacks:
+    - class_path: RichProgressBar
+    - class_path: LearningRateMonitor
+      init_args:
+        logging_interval: epoch
+    - class_path: EarlyStopping
+      init_args:
+        monitor: val/loss
+        patience: 100
+  max_epochs: 5
+  check_val_every_n_epoch: 1
+  log_every_n_steps: 20
+  enable_checkpointing: true
+  default_root_dir: tests/
+data:
+  class_path: GenericNonGeoPixelwiseRegressionDataModule
+  init_args:
+    batch_size: 2
+    num_workers: 4
+    train_transform:
+      - class_path: albumentations.HorizontalFlip
+        init_args:
+          p: 0.5
+      - class_path: albumentations.Rotate
+        init_args:
+          limit: 30
+          border_mode: 0 # cv2.BORDER_CONSTANT
+          value: 0
+          # mask_value: 1
+          p: 0.5
+      - class_path: ToTensorV2
+    dataset_bands:
+      - 0
+      - BLUE
+      - GREEN
+      - RED
+      - NIR_NARROW
+      - SWIR_1
+      - SWIR_2
+      - 1
+      - 2
+      - 3
+      - 4
+    output_bands:
+      - BLUE
+      - GREEN
+      - RED
+      - NIR_NARROW
+      - SWIR_1
+      - SWIR_2
+    rgb_indices:
+      - 2
+      - 1
+      - 0
+    train_data_root: tests/
+    train_label_data_root: tests/
+    val_data_root: tests/
+    val_label_data_root: tests/
+    test_data_root: tests/
+    test_label_data_root: tests/
+    img_grep: "regression*input*.tif"
+    label_grep: "regression*label*.tif"
+    means:
+      - 547.36707
+      - 898.5121
+      - 1020.9082
+      - 2665.5352
+      - 2340.584
+      - 1610.1407
+    stds:
+      - 411.4701
+      - 558.54065
+      - 815.94025
+      - 812.4403
+      - 1113.7145
+      - 1067.641
+    no_label_replace: -1
+    no_data_replace: 0
+
+model:
+  class_path: terratorch.tasks.PixelwiseRegressionTask
+  init_args:
+    model_args:
+      decoder: UperNetDecoder
+      pretrained: true
+      backbone: prithvi_swin_B
+      backbone_pretrained_cfg_overlay:
+        file: tests/prithvi_swin_B.pt
+      backbone_drop_path_rate: 0.3
+      # backbone_window_size: 8
+      decoder_channels: 256
+      in_channels: 6
+      bands:
+        - BLUE
+        - GREEN
+        - RED
+        - NIR_NARROW
+        - SWIR_1
+        - SWIR_2
+      num_frames: 1
+      head_dropout: 0.5708022831486758
+      head_final_act: torch.nn.ReLU
+      head_learned_upscale_layers: 2
+    loss: rmse
+    #aux_heads:
+    #  - name: aux_head
+    #    decoder: IdentityDecoder
+    #    decoder_args:
+    #      decoder_out_index: 2
+    #      head_dropout: 0.5
+    #      head_channel_list:
+    #        - 64
+    #      head_final_act: torch.nn.ReLU
+    #aux_loss:
+    #  aux_head: 0.4
+    ignore_index: -1
+    freeze_backbone: true
+    freeze_decoder: false
+    model_factory: PrithviModelFactory
+
+    # uncomment this block for tiled inference
+    # tiled_inference_parameters:
+    #  h_crop: 224
+    #  h_stride: 192
+    #  w_crop: 224
+    #  w_stride: 192
+    #  average_patches: true
+optimizer:
+  class_path: torch.optim.AdamW
+  init_args:
+    lr: 0.00013524680528283027
+    weight_decay: 0.047782217873995426
+lr_scheduler:
+  class_path: ReduceLROnPlateau
+  init_args:
+    monitor: val/loss
+
diff --git a/tests/manufactured-finetune_prithvi_swin_L.yaml b/tests/manufactured-finetune_prithvi_swin_L.yaml
new file mode 100644
index 00000000..8619ffbf
--- /dev/null
+++ b/tests/manufactured-finetune_prithvi_swin_L.yaml
@@ -0,0 +1,150 @@
+# lightning.pytorch==2.1.1
+seed_everything: 42
+trainer:
+  accelerator: auto
+  strategy: auto
+  devices: auto
+  num_nodes: 1
+  # precision: 16-mixed
+  logger:
+    class_path: TensorBoardLogger
+    init_args:
+      save_dir: tests/
+      name: all_ecos_random
+  callbacks:
+    - class_path: RichProgressBar
+    - class_path: LearningRateMonitor
+      init_args:
+        logging_interval: epoch
+    - class_path: EarlyStopping
+      init_args:
+        monitor: val/loss
+        patience: 100
+  max_epochs: 5
+  check_val_every_n_epoch: 1
+  log_every_n_steps: 20
+  enable_checkpointing: true
+  default_root_dir: tests/
+data:
+  class_path: GenericNonGeoPixelwiseRegressionDataModule
+  init_args:
+    batch_size: 2
+    num_workers: 4
+    train_transform:
+      - class_path: albumentations.HorizontalFlip
+        init_args:
+          p: 0.5
+      - class_path: albumentations.Rotate
+        init_args:
+          limit: 30
+          border_mode: 0 # cv2.BORDER_CONSTANT
+          value: 0
+          # mask_value: 1
+          p: 0.5
+      - class_path: ToTensorV2
+    dataset_bands:
+      - 0
+      - BLUE
+      - GREEN
+      - RED
+      - NIR_NARROW
+      - SWIR_1
+      - SWIR_2
+      - 1
+      - 2
+      - 3
+      - 4
+    output_bands:
+      - BLUE
+      - GREEN
+      - RED
+      - NIR_NARROW
+      - SWIR_1
+      - SWIR_2
+    rgb_indices:
+      - 2
+      - 1
+      - 0
+    train_data_root: tests/
+    train_label_data_root: tests/
+    val_data_root: tests/
+    val_label_data_root: tests/
+    test_data_root: tests/
+    test_label_data_root: tests/
+    img_grep: "regression*input*.tif"
+    label_grep: "regression*label*.tif"
+    means:
+      - 547.36707
+      - 898.5121
+      - 1020.9082
+      - 2665.5352
+      - 2340.584
+      - 1610.1407
+    stds:
+      - 411.4701
+      - 558.54065
+      - 815.94025
+      - 812.4403
+      - 1113.7145
+      - 1067.641
+    no_label_replace: -1
+    no_data_replace: 0
+
+model:
+  class_path: terratorch.tasks.PixelwiseRegressionTask
+  init_args:
+    model_args:
+      decoder: UperNetDecoder
+      pretrained: true
+      backbone: prithvi_swin_L
+      backbone_pretrained_cfg_overlay:
+        file: tests/prithvi_swin_L.pt
+      backbone_drop_path_rate: 0.3
+      # backbone_window_size: 8
+      decoder_channels: 64
+      in_channels: 6
+      bands:
+        - BLUE
+        - GREEN
+        - RED
+        - NIR_NARROW
+        - SWIR_1
+        - SWIR_2
+      num_frames: 1
+      head_dropout: 0.5708022831486758
+      head_final_act: torch.nn.ReLU
+      head_learned_upscale_layers: 2
+    loss: rmse
+    #aux_heads:
+    #  - name: aux_head
+    #    decoder: IdentityDecoder
+    #    decoder_args:
+    #      decoder_out_index: 2
+    #      head_dropout: 0.5
+    #      head_channel_list:
+    #        - 64
+    #      head_final_act: torch.nn.ReLU
+    #aux_loss:
+    #  aux_head: 0.4
+    ignore_index: -1
+    freeze_backbone: true
+    freeze_decoder: false
+    model_factory: PrithviModelFactory
+
+    # uncomment this block for tiled inference
+    # tiled_inference_parameters:
+    #  h_crop: 224
+    #  h_stride: 192
+    #  w_crop: 224
+    #  w_stride: 192
+    #  average_patches: true
+optimizer:
+  class_path: torch.optim.AdamW
+  init_args:
+    lr: 0.00013524680528283027
+    weight_decay: 0.047782217873995426
+lr_scheduler:
+  class_path: ReduceLROnPlateau
+  init_args:
+    monitor: val/loss
+
diff --git a/tests/manufactured-finetune_prithvi_vit_100.yaml b/tests/manufactured-finetune_prithvi_vit_100.yaml
new file mode 100644
index 00000000..8ee70a9c
--- /dev/null
+++ b/tests/manufactured-finetune_prithvi_vit_100.yaml
@@ -0,0 +1,150 @@
+# lightning.pytorch==2.1.1
+seed_everything: 42
+trainer:
+  accelerator: auto
+  strategy: auto
+  devices: auto
+  num_nodes: 1
+  # precision: 16-mixed
+  logger:
+    class_path: TensorBoardLogger
+    init_args:
+      save_dir: tests/
+      name: all_ecos_random
+  callbacks:
+    - class_path: RichProgressBar
+    - class_path: LearningRateMonitor
+      init_args:
+        logging_interval: epoch
+    - class_path: EarlyStopping
+      init_args:
+        monitor: val/loss
+        patience: 100
+  max_epochs: 5
+  check_val_every_n_epoch: 1
+  log_every_n_steps: 20
+  enable_checkpointing: true
+  default_root_dir: tests/
+data:
+  class_path: GenericNonGeoPixelwiseRegressionDataModule
+  init_args:
+    batch_size: 2
+    num_workers: 4
+    train_transform:
+      - class_path: albumentations.HorizontalFlip
+        init_args:
+          p: 0.5
+      - class_path: albumentations.Rotate
+        init_args:
+          limit: 30
+          border_mode: 0 # cv2.BORDER_CONSTANT
+          value: 0
+          # mask_value: 1
+          p: 0.5
+      - class_path: ToTensorV2
+    dataset_bands:
+      - 0
+      - BLUE
+      - GREEN
+      - RED
+      - NIR_NARROW
+      - SWIR_1
+      - SWIR_2
+      - 1
+      - 2
+      - 3
+      - 4
+    output_bands:
+      - BLUE
+      - GREEN
+      - RED
+      - NIR_NARROW
+      - SWIR_1
+      - SWIR_2
+    rgb_indices:
+      - 2
+      - 1
+      - 0
+    train_data_root: tests/
+    train_label_data_root: tests/
+    val_data_root: tests/
+    val_label_data_root: tests/
+    test_data_root: tests/
+    test_label_data_root: tests/
+    img_grep: "regression*input*.tif"
+    label_grep: "regression*label*.tif"
+    means:
+      - 547.36707
+      - 898.5121
+      - 1020.9082
+      - 2665.5352
+      - 2340.584
+      - 1610.1407
+    stds:
+      - 411.4701
+      - 558.54065
+      - 815.94025
+      - 812.4403
+      - 1113.7145
+      - 1067.641
+    no_label_replace: -1
+    no_data_replace: 0
+
+model:
+  class_path: terratorch.tasks.PixelwiseRegressionTask
+  init_args:
+    model_args:
+      decoder: UperNetDecoder
+      pretrained: true
+      backbone: prithvi_vit_100
+      backbone_pretrained_cfg_overlay:
+        file: tests/prithvi_vit_100.pt
+      backbone_drop_path_rate: 0.3
+      # backbone_window_size: 8
+      decoder_channels: 64
+      in_channels: 6
+      bands:
+        - BLUE
+        - GREEN
+        - RED
+        - NIR_NARROW
+        - SWIR_1
+        - SWIR_2
+      num_frames: 1
+      head_dropout: 0.5708022831486758
+      head_final_act: torch.nn.ReLU
+      head_learned_upscale_layers: 2
+    loss: rmse
+    #aux_heads:
+    #  - name: aux_head
+    #    decoder: IdentityDecoder
+    #    decoder_args:
+    #      decoder_out_index: 2
+    #      head_dropout: 0.5
+    #      head_channel_list:
+    #        - 64
+    #      head_final_act: torch.nn.ReLU
+    #aux_loss:
+    #  aux_head: 0.4
+    ignore_index: -1
+    freeze_backbone: true
+    freeze_decoder: false
+    model_factory: PrithviModelFactory
+
+    # uncomment this block for tiled inference
+    # tiled_inference_parameters:
+    #  h_crop: 224
+    #  h_stride: 192
+    #  w_crop: 224
+    #  w_stride: 192
+    #  average_patches: true
+optimizer:
+  class_path: torch.optim.AdamW
+  init_args:
+    lr: 0.00013524680528283027
+    weight_decay: 0.047782217873995426
+lr_scheduler:
+  class_path: ReduceLROnPlateau
+  init_args:
+    monitor: val/loss
+
diff --git a/tests/manufactured-finetune_prithvi_vit_300.yaml b/tests/manufactured-finetune_prithvi_vit_300.yaml
new file mode 100644
index 00000000..1994f0d1
--- /dev/null
+++ b/tests/manufactured-finetune_prithvi_vit_300.yaml
@@ -0,0 +1,150 @@
+# lightning.pytorch==2.1.1
+seed_everything: 42
+trainer:
+  accelerator: auto
+  strategy: auto
+  devices: auto
+  num_nodes: 1
+  # precision: 16-mixed
+  logger:
+    class_path: TensorBoardLogger
+    init_args:
+      save_dir: tests/
+      name: all_ecos_random
+  callbacks:
+    - class_path: RichProgressBar
+    - class_path: LearningRateMonitor
+      init_args:
+        logging_interval: epoch
+    - class_path: EarlyStopping
+      init_args:
+        monitor: val/loss
+        patience: 100
+  max_epochs: 5
+  check_val_every_n_epoch: 1
+  log_every_n_steps: 20
+  enable_checkpointing: true
+  default_root_dir: tests/
+data:
+  class_path: GenericNonGeoPixelwiseRegressionDataModule
+  init_args:
+    batch_size: 2
+    num_workers: 4
+    train_transform:
+      - class_path: albumentations.HorizontalFlip
+        init_args:
+          p: 0.5
+      - class_path: albumentations.Rotate
+        init_args:
+          limit: 30
+          border_mode: 0 # cv2.BORDER_CONSTANT
+          value: 0
+          # mask_value: 1
+          p: 0.5
+      - class_path: ToTensorV2
+    dataset_bands:
+      - 0
+      - BLUE
+      - GREEN
+      - RED
+      - NIR_NARROW
+      - SWIR_1
+      - SWIR_2
+      - 1
+      - 2
+      - 3
+      - 4
+    output_bands:
+      - BLUE
+      - GREEN
+      - RED
+      - NIR_NARROW
+      - SWIR_1
+      - SWIR_2
+    rgb_indices:
+      - 2
+      - 1
+      - 0
+    train_data_root: tests/
+    train_label_data_root: tests/
+    val_data_root: tests/
+    val_label_data_root: tests/
+    test_data_root: tests/
+    test_label_data_root: tests/
+    img_grep: "regression*input*.tif"
+    label_grep: "regression*label*.tif"
+    means:
+      - 547.36707
+      - 898.5121
+      - 1020.9082
+      - 2665.5352
+      - 2340.584
+      - 1610.1407
+    stds:
+      - 411.4701
+      - 558.54065
+      - 815.94025
+      - 812.4403
+      - 1113.7145
+      - 1067.641
+    no_label_replace: -1
+    no_data_replace: 0
+
+model:
+  class_path: terratorch.tasks.PixelwiseRegressionTask
+  init_args:
+    model_args:
+      decoder: UperNetDecoder
+      pretrained: true
+      backbone: prithvi_vit_300
+      backbone_pretrained_cfg_overlay:
+        file: tests/prithvi_vit_300.pt
+      backbone_drop_path_rate: 0.3
+      # backbone_window_size: 8
+      decoder_channels: 64
+      in_channels: 6
+      bands:
+        - BLUE
+        - GREEN
+        - RED
+        - NIR_NARROW
+        - SWIR_1
+        - SWIR_2
+      num_frames: 1
+      head_dropout: 0.5708022831486758
+      head_final_act: torch.nn.ReLU
+      head_learned_upscale_layers: 2
+    loss: rmse
+    #aux_heads:
+    #  - name: aux_head
+    #    decoder: IdentityDecoder
+    #    decoder_args:
+    #      decoder_out_index: 2
+    #      head_dropout: 0.5
+    #      head_channel_list:
+    #        - 64
+    #      head_final_act: torch.nn.ReLU
+    #aux_loss:
+    #  aux_head: 0.4
+    ignore_index: -1
+    freeze_backbone: true
+    freeze_decoder: false
+    model_factory: PrithviModelFactory
+
+    # uncomment this block for tiled inference
+    # tiled_inference_parameters:
+    #  h_crop: 224
+    #  h_stride: 192
+    #  w_crop: 224
+    #  w_stride: 192
+    #  average_patches: true
+optimizer:
+  class_path: torch.optim.AdamW
+  init_args:
+    lr: 0.00013524680528283027
+    weight_decay: 0.047782217873995426
+lr_scheduler:
+  class_path: ReduceLROnPlateau
+  init_args:
+    monitor: val/loss
+
diff --git a/tests/regression_test_input_1.tif b/tests/regression_test_input_1.tif
new file mode 100644
index 00000000..5db7593c
Binary files /dev/null and b/tests/regression_test_input_1.tif differ
diff --git a/tests/regression_test_label_1.tif b/tests/regression_test_label_1.tif
new file mode 100644
index 00000000..521be614
Binary files /dev/null and b/tests/regression_test_label_1.tif differ
diff --git a/tests/test_backbones.py b/tests/test_backbones.py
index 6b8c36b9..57d42f01 100644
--- a/tests/test_backbones.py
+++ b/tests/test_backbones.py
@@ -1,8 +1,9 @@
 import pytest
 import timm
 import torch
-
+import importlib
 import terratorch  # noqa: F401
+import os
 
 NUM_CHANNELS = 6
 NUM_FRAMES = 3
@@ -49,7 +50,6 @@ def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal):
     backbone = timm.create_model(model_name, pretrained=False, num_frames=NUM_FRAMES)
     backbone(input_224_multitemporal)
 
-
 #def test_swin_models_accept_non_divisible_by_patch_size(input_386):
 #    backbone = timm.create_model("prithvi_swin_90_us", pretrained=False, num_frames=NUM_FRAMES)
 #    backbone(input_386)
diff --git a/tests/test_finetune.py b/tests/test_finetune.py
new file mode 100644
index 00000000..4812d5f3
--- /dev/null
+++ b/tests/test_finetune.py
@@ -0,0 +1,31 @@
+import pytest
+import timm
+import torch
+import importlib
+import terratorch
+import subprocess
+import os
+
+from terratorch.models.backbones.prithvi_vit import checkpoint_filter_fn as checkpoint_filter_fn_vit
+from terratorch.models.backbones.prithvi_swin import checkpoint_filter_fn as checkpoint_filter_fn_swin
+
+@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
+def test_finetune_multiple_backbones(model_name):
+
+    model_instance = timm.create_model(model_name)
+    pretrained_bands = [0, 1, 2, 3, 4, 5]
+    model_bands = [0, 1, 2, 3, 4, 5]
+
+    state_dict = model_instance.state_dict()
+
+    torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))
+
+    # Running the terratorch CLI
+    command_str = f"terratorch fit -c tests/manufactured-finetune_{model_name}.yaml"
+
+    command_out = subprocess.run(command_str, shell=True)
+
+    assert not command_out.returncode
+
+
+
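
Note (not part of the patch): the new test can also be exercised by hand. The sketch below is a minimal standalone version of what tests/test_finetune.py automates for one backbone; it assumes terratorch is installed with its CLI available and that it is run from the repository root so the relative paths in the manufactured configs resolve. The model name, checkpoint path, and CLI invocation are taken directly from the test and configs above.

    import os
    import subprocess

    import timm
    import torch
    import terratorch  # noqa: F401  # imported for its side effect of registering the prithvi_* backbones with timm

    model_name = "prithvi_swin_B"

    # Save a randomly initialised checkpoint where the config's
    # backbone_pretrained_cfg_overlay.file entry expects to find it.
    backbone = timm.create_model(model_name)
    torch.save(backbone.state_dict(), os.path.join("tests", model_name + ".pt"))

    # Drive the manufactured fine-tuning config through the terratorch CLI,
    # exactly as the test does, and check that the run succeeded.
    result = subprocess.run(
        f"terratorch fit -c tests/manufactured-finetune_{model_name}.yaml",
        shell=True,
    )
    assert result.returncode == 0

The checkpoints written to tests/ in this way are what the new **/*pt and **/*pth entries in .gitignore keep out of version control.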