-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
FAILED tests/test_datasets.py::TestSen1Floods11NonGeo::test_dataset_sample - KeyError: 'location' #310
Comments
I updated the branch to the most recent commit (ea44763), defined PYTHONPATH, and ran the tests.
Still getting it (Python 3.12.7): `FAILED tests/test_datasets.py::TestSen1Floods11NonGeo::test_dataset_sample - KeyError: 'location'`. This is what I did: `git clone git@github.com:IBM/terratorch.git`
The following script was generated by an AI agent to help reproduce the issue:
# terratorch/reproduce.py
import os
import torch
import pandas as pd
import geopandas as gpd
import numpy as np
from pathlib import Path
from albumentations.pytorch import ToTensorV2
import albumentations as A
from terratorch.datasets.sen1floods11 import Sen1Floods11NonGeo
def _write_mock_split(data_root):
    """Create the mock directory tree and a train split that lists one chip, ``mock_file``."""
    required_dirs = [
        f"{data_root}/v1.1/splits/flood_handlabeled",
        f"{data_root}/v1.1/data/flood_events/HandLabeled/S2Hand",
        f"{data_root}/v1.1/data/flood_events/HandLabeled/LabelHand",
    ]
    # makedirs creates every parent (including data_root itself), so no separate
    # call for data_root is needed. ``directory`` avoids shadowing builtin ``dir``.
    for directory in required_dirs:
        os.makedirs(directory, exist_ok=True)
    split_file_path = f"{data_root}/v1.1/splits/flood_handlabeled/flood_train_data.txt"
    with open(split_file_path, "w") as f:
        f.write("mock_file")


def _write_mock_rasters(data_root):
    """Write an all-zero 32x32 uint8 image GeoTIFF and a 1-band all-zero label GeoTIFF.

    The image gets one band per entry in ``Sen1Floods11NonGeo.all_band_names``.
    """
    import rasterio
    from rasterio.transform import from_origin

    hand_dir = f"{data_root}/v1.1/data/flood_events/HandLabeled"
    # Creation options shared by both files; only the band count differs.
    profile = dict(
        driver="GTiff",
        height=32,
        width=32,
        dtype=np.uint8,
        crs="+proj=latlong",
        transform=from_origin(0, 0, 1, 1),
    )
    num_bands = len(Sen1Floods11NonGeo.all_band_names)
    with rasterio.open(
        f"{hand_dir}/S2Hand/mock_file_S2Hand.tif", "w", count=num_bands, **profile
    ) as dst:
        # Passing a 3-D (bands, rows, cols) array with no index writes every band in order.
        dst.write(np.zeros((num_bands, 32, 32), dtype=np.uint8))
    with rasterio.open(
        f"{hand_dir}/LabelHand/mock_file_LabelHand.tif", "w", count=1, **profile
    ) as dst:
        dst.write(np.zeros((32, 32), dtype=np.uint8), 1)


def _write_mock_metadata(data_root):
    """Write a one-row metadata GeoJSON next to the ``v1.1`` data folders."""
    # NOTE(review): the location column is deliberately named 'wrong_key' rather
    # than 'location'; presumably this is what reproduces the reported
    # KeyError: 'location' when use_metadata=True. Confirm against the dataset code.
    metadata = gpd.GeoDataFrame({
        "wrong_key": ["mock_location"],
        "s2_date": [pd.to_datetime("2022-01-01")],
        "geometry": gpd.points_from_xy([0], [0]),
    })
    # set_geometry returns a new frame (it is not in-place by default), so rebind.
    metadata = metadata.set_geometry("geometry")
    metadata.to_file(f"{data_root}/v1.1/Sen1Floods11_Metadata.geojson", driver="GeoJSON")


def _check_sample(sample, num_bands):
    """Assert the keys, tensor types, dtypes and shapes of one dataset sample."""
    assert "image" in sample, "Sample does not contain 'image'"
    assert "mask" in sample, "Sample does not contain 'mask'"
    assert "location_coords" in sample, "Sample does not contain 'location_coords'"
    assert "temporal_coords" in sample, "Sample does not contain 'temporal_coords'"
    assert isinstance(sample["image"], torch.Tensor), "'image' is not a torch.Tensor"
    assert isinstance(sample["mask"], torch.Tensor), "'mask' is not a torch.Tensor"
    assert isinstance(sample["location_coords"], torch.Tensor), "'location_coords' is not a torch.Tensor"
    assert isinstance(sample["temporal_coords"], torch.Tensor), "'temporal_coords' is not a torch.Tensor"
    assert sample["image"].dtype == torch.float32, "'image' does not have dtype torch.float32"
    assert sample["mask"].dtype == torch.long, "'mask' does not have dtype torch.long"
    assert sample["location_coords"].dtype == torch.float32, "'location_coords' does not have dtype torch.float32"
    assert sample["temporal_coords"].dtype == torch.float32, "'temporal_coords' does not have dtype torch.float32"
    assert sample["image"].shape == (num_bands, 32, 32), f"'image' has incorrect shape: {sample['image'].shape}"
    assert sample["mask"].shape == (32, 32), f"'mask' has incorrect shape: {sample['mask'].shape}"


def test_dataset_sample():
    """Build a minimal mock Sen1Floods11 tree and check one ``Sen1Floods11NonGeo`` sample.

    This is a reproduction script for the ``KeyError: 'location'`` reported against
    ``TestSen1Floods11NonGeo::test_dataset_sample``. It does the following:

    - Writes a one-entry train split, an image/label GeoTIFF pair and a metadata
      GeoJSON under ``terratorch/tests/mock_data/sen1floods``. The path is
      relative to the current working directory.
    - Loads sample 0 with ``use_metadata=True``.
    - Checks the sample's keys, types, dtypes and shapes.

    Raises:
        AssertionError: if a check fails. Any other exception raised while building
            the mock data or loading the sample is also re-raised as an
            AssertionError whose message names the original exception type. The
            original exception is kept as ``__cause__``.
    """
    data_root = "terratorch/tests/mock_data/sen1floods"
    try:
        transform = A.Compose([
            A.Resize(32, 32),
            ToTensorV2(),
        ])
        _write_mock_split(data_root)
        _write_mock_rasters(data_root)
        _write_mock_metadata(data_root)
        dataset = Sen1Floods11NonGeo(
            data_root=data_root, split="train", transform=transform, use_metadata=True
        )
        sample = dataset[0]
        _check_sample(sample, num_bands=len(dataset.bands))
    except AssertionError:
        # Our own check failed: propagate it unchanged rather than re-wrapping it.
        raise
    except Exception as err:
        # Keep the exception type in the message (a bare KeyError only prints
        # "'location'") and chain it so the original traceback stays attached.
        raise AssertionError(f"{type(err).__name__}: {err}") from err
    print("Test passed successfully with no errors!")
# How to run: python3 terratorch/reproduce.py
# Expected Result: (left blank in the original report)
if __name__ == "__main__":
    test_dataset_sample()
Thank you for your valuable contribution to this project; we appreciate your feedback! Please respond with an emoji if you find this script helpful, and feel free to comment below if any improvements are needed. Best regards from an AI Agent!
Describe the issue
When running the tests on main, we get:
FAILED tests/test_datasets.py::TestSen1Floods11NonGeo::test_dataset_sample - KeyError: 'location'
The text was updated successfully, but these errors were encountered: