Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 45 additions & 47 deletions image_datasets/control_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,56 +74,54 @@ def __len__(self):
return 999999

def __getitem__(self, idx):
    """Return one randomly chosen training sample.

    The ``idx`` argument is ignored: every attempt draws a fresh random
    index into ``self.images``. A sample that fails to load is reported on
    stdout and replaced by another random draw, up to 10 attempts.

    Returns:
        ``(img, prompt, control_img)`` when ``self.cached_text_embeddings``
        is ``None`` (``prompt`` is ``" "`` when the caption is dropped), or
        ``(img, prompt_embeds, prompt_embeds_mask, control_img)`` when text
        embeddings are cached.

    Raises:
        RuntimeError: if every attempt failed; chained to the last error
            so the underlying cause is visible in the traceback.
    """
    max_attempts = 10
    last_error = None
    for _ in range(max_attempts):
        try:
            idx = random.randint(0, len(self.images) - 1)
            img_path = self.images[idx]
            img_name = os.path.basename(img_path)
            # Caption / text-embedding key: image file name with a .txt extension.
            txt = os.path.splitext(img_name)[0] + ".txt"

            # --- image load (cached or not) ---
            if self.cached_image_embeddings is None:
                img = Image.open(img_path).convert("RGB")
                img = image_resize(img, self.img_size)
                # Scale 0..255 pixel values to [-1, 1] and move channels first.
                img = torch.from_numpy((np.array(img) / 127.5) - 1).permute(2, 0, 1)
            else:
                img = self.cached_image_embeddings[img_name]

            # --- control image load ---
            if self.cached_control_image_embeddings is None:
                # NOTE(review): the control image is read from the same path
                # as the target image -- confirm this is intended.
                control_img = Image.open(img_path).convert("RGB")
                control_img = image_resize(control_img, self.img_size)
                control_img = torch.from_numpy((np.array(control_img) / 127.5) - 1).permute(2, 0, 1)
            else:
                control_img = self.cached_control_image_embeddings[img_name]

            # --- raw caption text ---
            if self.cached_text_embeddings is None:
                txt_path = os.path.join(os.path.dirname(img_path), txt)
                if not os.path.exists(txt_path):
                    raise FileNotFoundError(f"Caption file not found: {txt_path}")
                # Context manager so the caption file handle is always closed.
                with open(txt_path, encoding="utf-8") as caption_file:
                    prompt = caption_file.read()
                if throw_one(self.caption_dropout_rate):
                    return img, " ", control_img
                return img, prompt, control_img

            # --- cached text embedding ---
            if throw_one(self.caption_dropout_rate):
                key = txt + "empty_embedding"
            else:
                key = txt
            if key not in self.cached_text_embeddings:
                raise KeyError(f"Missing embedding key: {key}")
            emb = self.cached_text_embeddings[key]
            return img, emb["prompt_embeds"], emb["prompt_embeds_mask"], control_img

        except Exception as e:
            # Best-effort loading: report the failure and try another sample.
            print(f"Error loading sample (try again): {e}")
            last_error = e
    raise RuntimeError("Too many dataset loading errors") from last_error



def loader(train_batch_size, num_workers, **args):
Expand Down