Skip to content

Commit 1382c71

Browse files
Merge pull request #61 from ZeroCool940711/dev
Fixed the dataset not converting properly the images to RGBA when using 4 channels for training on the maskgit.
2 parents 2f99bfe + 9477603 commit 1382c71

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

muse_maskgit_pytorch/dataset.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -243,13 +243,6 @@ def __init__(
243243
self.caption_pair.append(captions)
244244

245245
transform_list = [
246-
T.Lambda(
247-
lambda img: img.convert("RGBA")
248-
if img.mode != "RGBA" and alpha_channel
249-
else img
250-
if img.mode == "RGB" and not alpha_channel
251-
else img.convert("RGB")
252-
),
253246
T.Resize(image_size),
254247
]
255248
if flip:
@@ -258,6 +251,10 @@ def __init__(
258251
transform_list.append(T.CenterCrop(image_size))
259252
if random_crop:
260253
transform_list.append(T.RandomCrop(image_size, pad_if_needed=True))
254+
if alpha_channel:
255+
transform_list.append(T.Lambda(lambda img: img.convert("RGBA") if img.mode != "RGBA" else img))
256+
else:
257+
transform_list.append(T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img))
261258
transform_list.append(T.ToTensor())
262259
self.transform = T.Compose(transform_list)
263260

train_muse_maskgit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,7 @@
426426

427427
@dataclass
428428
class Arguments:
429+
total_params: Optional[int] = None
429430
only_save_last_checkpoint: bool = False
430431
validation_image_scale: float = 1.0
431432
no_center_crop: bool = False
@@ -493,7 +494,6 @@ class Arguments:
493494
debug: bool = False
494495
config_path: Optional[str] = None
495496
attention_type: str = "flash"
496-
total_params: Optional[int] = None
497497

498498

499499
def main():
@@ -714,7 +714,7 @@ def main():
714714

715715
# load the maskgit transformer from disk if we have previously trained one
716716
with accelerator.main_process_first():
717-
if args.resume_path:
717+
if args.resume_path is not None and len(args.resume_path) > 1:
718718
load = True
719719

720720
accelerator.print("Loading Muse MaskGit...")

0 commit comments

Comments
 (0)