Skip to content

Commit 446a465

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 9533f48 commit 446a465

File tree

1 file changed

+38
-12
lines changed

1 file changed

+38
-12
lines changed

infer_vae.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1-
import argparse, glob, hashlib
2-
import os, random, re, shutil
1+
import argparse
2+
import glob
3+
import hashlib
4+
import os
5+
import random
6+
import re
7+
import shutil
38
from dataclasses import dataclass
49
from datetime import datetime
510
from typing import Optional
@@ -362,7 +367,7 @@ def main():
362367
channels=args.channels,
363368
layers=args.layers,
364369
discr_layers=args.discr_layers,
365-
).to('cpu' if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
370+
).to("cpu" if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
366371

367372
if args.latest_checkpoint:
368373
accelerator.print("Finding latest checkpoint...")
@@ -427,7 +432,7 @@ def main():
427432
args.seq_len = vae.get_encoded_fmap_size(args.image_size) ** 2
428433

429434
# move vae to device
430-
vae = vae.to('cpu' if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
435+
vae = vae.to("cpu" if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
431436

432437
# Use the parameters() method to get an iterator over all the learnable parameters of the model
433438
total_params = sum(p.numel() for p in vae.parameters())
@@ -458,7 +463,9 @@ def main():
458463
)
459464

460465
_, ids, _ = vae.encode(
461-
dataset[image_id][None].to('cpu' if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
466+
dataset[image_id][None].to(
467+
"cpu" if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}"
468+
)
462469
)
463470
recon = vae.decode_from_ids(ids)
464471
save_image(recon, f"{args.results_dir}/outputs/output.{str(args.input_image).split('.')[-1]}")
@@ -471,7 +478,9 @@ def main():
471478
save_image(dataset[image_id], f"{args.results_dir}/outputs/input.png")
472479

473480
_, ids, _ = vae.encode(
474-
dataset[image_id][None].to('cpu' if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
481+
dataset[image_id][None].to(
482+
"cpu" if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}"
483+
)
475484
)
476485
recon = vae.decode_from_ids(ids)
477486
save_image(recon, f"{args.results_dir}/outputs/output.png")
@@ -490,7 +499,13 @@ def main():
490499
if not args.use_paintmind:
491500
# encode
492501
_, ids, _ = vae.encode(
493-
dataset[i][None].to('cpu' if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
502+
dataset[i][None].to(
503+
"cpu"
504+
if args.cpu
505+
else accelerator.device
506+
if args.gpu == 0
507+
else f"cuda:{args.gpu}"
508+
)
494509
)
495510
# decode
496511
recon = vae.decode_from_ids(ids)
@@ -499,7 +514,13 @@ def main():
499514
else:
500515
# encode
501516
encoded, _, _ = vae.encode(
502-
dataset[i][None].to('cpu' if args.cpu else accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
517+
dataset[i][None].to(
518+
"cpu"
519+
if args.cpu
520+
else accelerator.device
521+
if args.gpu == 0
522+
else f"cuda:{args.gpu}"
523+
)
503524
)
504525

505526
# decode
@@ -531,10 +552,15 @@ def main():
531552
os.remove(f"{output_dir}/input.png")
532553
os.remove(f"{output_dir}/output.png")
533554
else:
534-
os.makedirs(os.path.join(output_dir, 'originals'), exist_ok=True)
535-
shutil.move(f"{output_dir}/input.png", f"{os.path.join(output_dir, 'originals')}/input_{now}.png")
536-
shutil.move(f"{output_dir}/output.png", f"{os.path.join(output_dir, 'originals')}/output_{now}.png")
537-
555+
os.makedirs(os.path.join(output_dir, "originals"), exist_ok=True)
556+
shutil.move(
557+
f"{output_dir}/input.png",
558+
f"{os.path.join(output_dir, 'originals')}/input_{now}.png",
559+
)
560+
shutil.move(
561+
f"{output_dir}/output.png",
562+
f"{os.path.join(output_dir, 'originals')}/output_{now}.png",
563+
)
538564

539565
del _
540566
del ids

0 commit comments

Comments
 (0)