diff --git a/data.py b/data.py index f1cc432d..af357157 100644 --- a/data.py +++ b/data.py @@ -74,7 +74,33 @@ def get_clip_timepoints(clip_sampler, duration): all_clips_timepoints.append((start, end)) return all_clips_timepoints +def load_and_transform_thermal_data(image_paths, device): + if image_paths is None: + return None + image_outputs = [] + for image_path in image_paths: + data_transform = transforms.Compose( + [ + transforms.Resize( + 224, interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + transforms.Grayscale() + ] + ) + with open(image_path, "rb") as fopen: + image = Image.open(fopen).convert("RGB") + + image = data_transform(image).to(device) + image_outputs.append(image) + return torch.stack(image_outputs, dim=0) + def load_and_transform_vision_data(image_paths, device): if image_paths is None: return None