Skip to content

Commit 2f99bfe

Browse files
Merge pull request #60 from ZeroCool940711/dev
Fixed the infer_vae.py script so it would retain the alpha channel whens saving the grid if we are using `--channels 4`, previosly it was generating properly the reconstructed image with alpha channel but the grid was converting the image to RGB instead of RGBA which removed the alpha channel, now it properly handles both RGB and RGBA images.
2 parents 6e74200 + c15221c commit 2f99bfe

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

infer_vae.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,9 @@ def main():
446446
os.makedirs(f"{args.results_dir}/outputs", exist_ok=True)
447447

448448
save_image(
449-
dataset[image_id], f"{args.results_dir}/outputs/input.{str(args.input_image).split('.')[-1]}"
449+
dataset[image_id],
450+
f"{args.results_dir}/outputs/input.{str(args.input_image).split('.')[-1]}",
451+
format="PNG",
450452
)
451453

452454
_, ids, _ = vae.encode(
@@ -505,7 +507,8 @@ def main():
505507

506508
# Create horizontal grid with input and output images
507509
grid_image = PIL.Image.new(
508-
"RGB", (input_image.width + output_image.width, input_image.height)
510+
"RGB" if args.channels == 3 else "RGBA",
511+
(input_image.width + output_image.width, input_image.height),
509512
)
510513
grid_image.paste(input_image, (0, 0))
511514
grid_image.paste(output_image, (input_image.width, 0))
@@ -515,7 +518,7 @@ def main():
515518
hash = hashlib.sha1(input_image.tobytes()).hexdigest()
516519

517520
filename = f"{hash}_{now}-{os.path.basename(args.vae_path)}.png"
518-
grid_image.save(f"{output_dir}/{filename}")
521+
grid_image.save(f"{output_dir}/{filename}", format="PNG")
519522

520523
# Remove input and output images after the grid was made.
521524
os.remove(f"{output_dir}/input.png")

0 commit comments

Comments
 (0)