1- import argparse , glob , hashlib
2- import os , random , re , shutil
1+ import argparse
2+ import glob
3+ import hashlib
4+ import os
5+ import random
6+ import re
7+ import shutil
38from dataclasses import dataclass
49from datetime import datetime
510from typing import Optional
@@ -362,7 +367,7 @@ def main():
362367 channels = args .channels ,
363368 layers = args .layers ,
364369 discr_layers = args .discr_layers ,
365- ).to (' cpu' if args .cpu else accelerator .device if args .gpu == 0 else f"cuda:{ args .gpu } " )
370+ ).to (" cpu" if args .cpu else accelerator .device if args .gpu == 0 else f"cuda:{ args .gpu } " )
366371
367372 if args .latest_checkpoint :
368373 accelerator .print ("Finding latest checkpoint..." )
@@ -427,7 +432,7 @@ def main():
427432 args .seq_len = vae .get_encoded_fmap_size (args .image_size ) ** 2
428433
429434 # move vae to device
430- vae = vae .to (' cpu' if args .cpu else accelerator .device if args .gpu == 0 else f"cuda:{ args .gpu } " )
435+ vae = vae .to (" cpu" if args .cpu else accelerator .device if args .gpu == 0 else f"cuda:{ args .gpu } " )
431436
432437 # Use the parameters() method to get an iterator over all the learnable parameters of the model
433438 total_params = sum (p .numel () for p in vae .parameters ())
@@ -458,7 +463,9 @@ def main():
458463 )
459464
460465 _ , ids , _ = vae .encode (
461- dataset [image_id ][None ].to ('cpu' if args .cpu else accelerator .device if args .gpu == 0 else f"cuda:{ args .gpu } " )
466+ dataset [image_id ][None ].to (
467+ "cpu" if args .cpu else accelerator .device if args .gpu == 0 else f"cuda:{ args .gpu } "
468+ )
462469 )
463470 recon = vae .decode_from_ids (ids )
464471 save_image (recon , f"{ args .results_dir } /outputs/output.{ str (args .input_image ).split ('.' )[- 1 ]} " )
@@ -471,7 +478,9 @@ def main():
471478 save_image (dataset [image_id ], f"{ args .results_dir } /outputs/input.png" )
472479
473480 _ , ids , _ = vae .encode (
474- dataset [image_id ][None ].to ('cpu' if args .cpu else accelerator .device if args .gpu == 0 else f"cuda:{ args .gpu } " )
481+ dataset [image_id ][None ].to (
482+ "cpu" if args .cpu else accelerator .device if args .gpu == 0 else f"cuda:{ args .gpu } "
483+ )
475484 )
476485 recon = vae .decode_from_ids (ids )
477486 save_image (recon , f"{ args .results_dir } /outputs/output.png" )
@@ -490,7 +499,13 @@ def main():
490499 if not args .use_paintmind :
491500 # encode
492501 _ , ids , _ = vae .encode (
493- dataset [i ][None ].to ('cpu' if args .cpu else accelerator .device if args .gpu == 0 else f"cuda:{ args .gpu } " )
502+ dataset [i ][None ].to (
503+ "cpu"
504+ if args .cpu
505+ else accelerator .device
506+ if args .gpu == 0
507+ else f"cuda:{ args .gpu } "
508+ )
494509 )
495510 # decode
496511 recon = vae .decode_from_ids (ids )
@@ -499,7 +514,13 @@ def main():
499514 else :
500515 # encode
501516 encoded , _ , _ = vae .encode (
502- dataset [i ][None ].to ('cpu' if args .cpu else accelerator .device if args .gpu == 0 else f"cuda:{ args .gpu } " )
517+ dataset [i ][None ].to (
518+ "cpu"
519+ if args .cpu
520+ else accelerator .device
521+ if args .gpu == 0
522+ else f"cuda:{ args .gpu } "
523+ )
503524 )
504525
505526 # decode
@@ -531,10 +552,15 @@ def main():
531552 os .remove (f"{ output_dir } /input.png" )
532553 os .remove (f"{ output_dir } /output.png" )
533554 else :
534- os .makedirs (os .path .join (output_dir , 'originals' ), exist_ok = True )
535- shutil .move (f"{ output_dir } /input.png" , f"{ os .path .join (output_dir , 'originals' )} /input_{ now } .png" )
536- shutil .move (f"{ output_dir } /output.png" , f"{ os .path .join (output_dir , 'originals' )} /output_{ now } .png" )
537-
555+ os .makedirs (os .path .join (output_dir , "originals" ), exist_ok = True )
556+ shutil .move (
557+ f"{ output_dir } /input.png" ,
558+ f"{ os .path .join (output_dir , 'originals' )} /input_{ now } .png" ,
559+ )
560+ shutil .move (
561+ f"{ output_dir } /output.png" ,
562+ f"{ os .path .join (output_dir , 'originals' )} /output_{ now } .png" ,
563+ )
538564
539565 del _
540566 del ids
0 commit comments