Skip to content

Commit f1cf70c

Browse files
Fixed the dataset not converting properly the images to RGBA when using 4 channels for training on the maskgit.
1 parent 2f99bfe commit f1cf70c

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

muse_maskgit_pytorch/dataset.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -243,13 +243,6 @@ def __init__(
243243
self.caption_pair.append(captions)
244244

245245
transform_list = [
246-
T.Lambda(
247-
lambda img: img.convert("RGBA")
248-
if img.mode != "RGBA" and alpha_channel
249-
else img
250-
if img.mode == "RGB" and not alpha_channel
251-
else img.convert("RGB")
252-
),
253246
T.Resize(image_size),
254247
]
255248
if flip:
@@ -258,6 +251,10 @@ def __init__(
258251
transform_list.append(T.CenterCrop(image_size))
259252
if random_crop:
260253
transform_list.append(T.RandomCrop(image_size, pad_if_needed=True))
254+
if alpha_channel:
255+
transform_list.append(T.Lambda(lambda img: img.convert("RGBA") if img.mode != "RGBA" else img))
256+
else:
257+
transform_list.append(T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img))
261258
transform_list.append(T.ToTensor())
262259
self.transform = T.Compose(transform_list)
263260

0 commit comments

Comments
 (0)