diff --git a/image_datasets/control_dataset.py b/image_datasets/control_dataset.py index 650799f..0fad995 100644 --- a/image_datasets/control_dataset.py +++ b/image_datasets/control_dataset.py @@ -74,56 +74,54 @@ def __len__(self): return 999999 def __getitem__(self, idx): - try: - idx = random.randint(0, len(self.images) - 1) - if self.cached_image_embeddings is None: - img = Image.open(self.images[idx]).convert('RGB') - if self.random_ratio: - ratio = random.choice(["16:9", "default", "1:1", "4:3"]) - if ratio != "default": - img = crop_to_aspect_ratio(img, ratio) - img = image_resize(img, self.img_size) - w, h = img.size - new_w = (w // 32) * 32 - new_h = (h // 32) * 32 - img = img.resize((new_w, new_h)) - img = torch.from_numpy((np.array(img) / 127.5) - 1) - img = img.permute(2, 0, 1) - else: - img = self.cached_image_embeddings[self.images[idx].split('/')[-1]] - if self.cached_control_image_embeddings is None: - img = Image.open(self.images[idx]).convert('RGB') - if self.random_ratio: - ratio = random.choice(["16:9", "default", "1:1", "4:3"]) - if ratio != "default": - img = crop_to_aspect_ratio(img, ratio) - img = image_resize(img, self.img_size) - w, h = img.size - new_w = (w // 32) * 32 - new_h = (h // 32) * 32 - img = img.resize((new_w, new_h)) - img = torch.from_numpy((np.array(img) / 127.5) - 1) - img = img.permute(2, 0, 1) - else: - control_img = self.cached_control_image_embeddings[self.images[idx].split('/')[-1]] - - txt_path = self.images[idx].split('.')[0] + '.' + self.caption_type - if self.cached_text_embeddings is None: - prompt = open(txt_path, encoding='utf-8').read() - if throw_one(self.caption_dropout_rate): - return img, " ", control_img + for _ in range(10): + try: + idx = random.randint(0, len(self.images) - 1) + img_path = self.images[idx] + img_name = os.path.basename(img_path) + base = os.path.splitext(img_name)[0] + txt = base + ".txt" + + # --- image load (cached or not) --- + if self.cached_image_embeddings is None: + img = Image.open(img_path).convert("RGB") + img = image_resize(img, self.img_size) + img = torch.from_numpy((np.array(img) / 127.5) - 1).permute(2, 0, 1) + else: + img = self.cached_image_embeddings[img_name] + + # --- control image load --- + if self.cached_control_image_embeddings is None: + control_img = Image.open(img_path).convert("RGB") + control_img = image_resize(control_img, self.img_size) + control_img = torch.from_numpy((np.array(control_img) / 127.5) - 1).permute(2, 0, 1) else: + control_img = self.cached_control_image_embeddings[img_name] + + # --- text embedding --- + if self.cached_text_embeddings is None: + txt_path = os.path.join(os.path.dirname(img_path), txt) + if not os.path.exists(txt_path): + raise FileNotFoundError(f"Caption file not found: {txt_path}") + prompt = open(txt_path, encoding="utf-8").read() + if throw_one(self.caption_dropout_rate): + return img, " ", control_img return img, prompt, control_img - else: - txt = txt_path.split('/')[-1] - if throw_one(self.caption_dropout_rate): - return img, self.cached_text_embeddings[txt + 'empty_embedding']['prompt_embeds'], self.cached_text_embeddings[txt + 'empty_embedding']['prompt_embeds_mask'], control_img else: - return img, self.cached_text_embeddings[txt]['prompt_embeds'], self.cached_text_embeddings[txt]['prompt_embeds_mask'], control_img - - except Exception as e: - print(e) - return self.__getitem__(random.randint(0, len(self.images) - 1)) + if throw_one(self.caption_dropout_rate): + key = txt + "empty_embedding" + else: + key = txt + if key not in self.cached_text_embeddings: + raise KeyError(f"Missing embedding key: {key}") + emb = self.cached_text_embeddings[key] + return img, emb["prompt_embeds"], emb["prompt_embeds_mask"], control_img + + except Exception as e: + print(f"Error loading sample (try again): {e}") + continue + raise RuntimeError("Too many dataset loading errors") + def loader(train_batch_size, num_workers, **args):