diff --git a/.dev_scripts/diff_images.py b/.dev_scripts/diff_images.py index e21cae214e9..c0376eeba26 100644 --- a/.dev_scripts/diff_images.py +++ b/.dev_scripts/diff_images.py @@ -14,16 +14,14 @@ def calc_images_mean_L1(image1_path, image2_path): image2 = read_image_int16(image2_path) assert image1.shape == image2.shape - mean_L1 = np.abs(image1 - image2).mean() - return mean_L1 + return np.abs(image1 - image2).mean() def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('image1_path') parser.add_argument('image2_path') - args = parser.parse_args() - return args + return parser.parse_args() if __name__ == '__main__': diff --git a/backend/invoke_ai_web_server.py b/backend/invoke_ai_web_server.py index ddff39a7bc5..41c8c9aa601 100644 --- a/backend/invoke_ai_web_server.py +++ b/backend/invoke_ai_web_server.py @@ -45,8 +45,8 @@ def setup_flask(self): mimetypes.add_type('application/javascript', '.js') mimetypes.add_type('text/css', '.css') # Socket IO - logger = True if args.web_verbose else False - engineio_logger = True if args.web_verbose else False + logger = bool(args.web_verbose) + engineio_logger = bool(args.web_verbose) max_http_buffer_size = 10000000 socketio_args = { @@ -162,7 +162,7 @@ def setup_app(self): def load_socketio_listeners(self, socketio): @socketio.on('requestSystemConfig') def handle_request_capabilities(): - print(f'>> System config requested') + print('>> System config requested') config = self.get_system_config() socketio.emit('systemConfig', config) @@ -372,7 +372,7 @@ def handle_run_postprocessing( @socketio.on('cancel') def handle_cancel(): - print(f'>> Cancel processing requested') + print('>> Cancel processing requested') self.canceled.set() # TODO: I think this needs a safety mechanism. @@ -693,7 +693,6 @@ def image_done(image, seed, first_seed): raise except CanceledException: self.socketio.emit('processingCanceled') - pass except Exception as e: print(e) self.socketio.emit('error', {'message': (str(e))}) @@ -753,9 +752,7 @@ def parameters_to_generated_image_metadata(self, parameters): } ) - rfc_dict['postprocessing'] = ( - postprocessing if len(postprocessing) > 0 else None - ) + rfc_dict['postprocessing'] = postprocessing or None # semantic drift rfc_dict['sampler'] = parameters['sampler_name'] @@ -877,16 +874,15 @@ def save_result_image( seed = 'unknown_seed' - if 'image' in metadata: - if 'seed' in metadata['image']: - seed = metadata['image']['seed'] + if 'image' in metadata and 'seed' in metadata['image']: + seed = metadata['image']['seed'] filename = f'{prefix}.{seed}' if step_index: filename += f'.{step_index}' if postprocessing: - filename += f'.postprocessed' + filename += '.postprocessed' filename += '.png' diff --git a/backend/modules/parameters.py b/backend/modules/parameters.py index 0fae7ef729d..d697105aba8 100644 --- a/backend/modules/parameters.py +++ b/backend/modules/parameters.py @@ -18,7 +18,7 @@ def parameters_to_command(params): Converts dict of parameters into a `invoke.py` REPL command. 
""" - switches = list() + switches = [] if "prompt" in params: switches.append(f'"{params["prompt"]}"') @@ -35,7 +35,7 @@ def parameters_to_command(params): if "sampler_name" in params: switches.append(f'-A {params["sampler_name"]}') if "seamless" in params and params["seamless"] == True: - switches.append(f"--seamless") + switches.append("--seamless") if "init_img" in params and len(params["init_img"]) > 0: switches.append(f'-I {params["init_img"]}') if "init_mask" in params and len(params["init_mask"]) > 0: @@ -45,7 +45,7 @@ def parameters_to_command(params): if "strength" in params and "init_img" in params: switches.append(f'-f {params["strength"]}') if "fit" in params and params["fit"] == True: - switches.append(f"--fit") + switches.append("--fit") if "gfpgan_strength" in params and params["gfpgan_strength"]: switches.append(f'-G {params["gfpgan_strength"]}') if "upscale" in params and params["upscale"]: diff --git a/backend/modules/parse_seed_weights.py b/backend/modules/parse_seed_weights.py index 7e15d4e166e..115f60ab689 100644 --- a/backend/modules/parse_seed_weights.py +++ b/backend/modules/parse_seed_weights.py @@ -33,11 +33,11 @@ def parse_seed_weights(seed_weights): return False # Seed must be 0 or above - if not seed >= 0: + if seed < 0: return False # Weight must be between 0 and 1 - if not (weight >= 0 and weight <= 1): + if weight < 0 or weight > 1: return False # This pair is valid diff --git a/backend/server.py b/backend/server.py index cc0996dc664..6617e6b6808 100644 --- a/backend/server.py +++ b/backend/server.py @@ -52,9 +52,7 @@ precision = opt.precision free_gpu_mem = opt.free_gpu_mem embedding_path = opt.embedding_path -additional_allowed_origins = ( - opt.cors if opt.cors else [] -) # additional CORS allowed origins +additional_allowed_origins = opt.cors or [] model = "stable-diffusion-1.4" """ @@ -90,8 +88,8 @@ def serve(path): return send_from_directory(app.static_folder, "index.html") -logger = True if verbose else False -engineio_logger = True if verbose else False +logger = bool(verbose) +engineio_logger = bool(verbose) # default 1,000,000, needs to be higher for socketio to accept larger images max_http_buffer_size = 10000000 @@ -131,9 +129,7 @@ class CanceledException(Exception): gfpgan, codeformer = restoration.load_face_restore_models() esrgan = restoration.load_esrgan() - # coreformer.process(self, image, strength, device, seed=None, fidelity=0.75) - -except (ModuleNotFoundError, ImportError): +except ImportError: print(traceback.format_exc(), file=sys.stderr) print(">> You may need to install the ESRGAN and/or GFPGAN modules") @@ -185,7 +181,7 @@ class CanceledException(Exception): @socketio.on("requestSystemConfig") def handle_request_capabilities(): - print(f">> System config requested") + print(">> System config requested") config = get_system_config() socketio.emit("systemConfig", config) @@ -195,7 +191,7 @@ def handle_request_images(page=1, offset=0, last_mtime=None): chunk_size = 50 if last_mtime: - print(f">> Latest images requested") + print(">> Latest images requested") else: print( f">> Page {page} of images requested (page size {chunk_size} offset {offset})" @@ -231,7 +227,7 @@ def handle_request_images(page=1, offset=0, last_mtime=None): "images": image_array, "nextPage": page, "offset": offset, - "onlyNewImages": True if last_mtime else False, + "onlyNewImages": bool(last_mtime), }, ) @@ -389,7 +385,7 @@ def handle_run_gfpgan_event(original_image, gfpgan_parameters): @socketio.on("cancel") def handle_cancel(): - print(f">> Cancel processing 
requested") + print(">> Cancel processing requested") canceled.set() socketio.emit("processingCanceled") @@ -520,7 +516,7 @@ def parameters_to_generated_image_metadata(parameters): } ) - rfc_dict["postprocessing"] = postprocessing if len(postprocessing) > 0 else None + rfc_dict["postprocessing"] = postprocessing or None # semantic drift rfc_dict["sampler"] = parameters["sampler_name"] @@ -582,25 +578,22 @@ def save_image( seed = "unknown_seed" - if "image" in metadata: - if "seed" in metadata["image"]: - seed = metadata["image"]["seed"] + if "image" in metadata and "seed" in metadata["image"]: + seed = metadata["image"]["seed"] filename = f"{prefix}.{seed}" if step_index: filename += f".{step_index}" if postprocessing: - filename += f".postprocessed" + filename += ".postprocessed" filename += ".png" - path = pngwriter.save_image_and_prompt_to_png( + return pngwriter.save_image_and_prompt_to_png( image=image, dream_prompt=command, metadata=metadata, name=filename ) - return path - def calculate_real_steps(steps, strength, has_init_image): return math.floor(strength * steps) if has_init_image else steps diff --git a/ldm/data/imagenet.py b/ldm/data/imagenet.py index d155f6d6aea..b6d5945f34e 100644 --- a/ldm/data/imagenet.py +++ b/ldm/data/imagenet.py @@ -28,13 +28,13 @@ def synset2idx(path_to_yaml='data/index_synset.yaml'): with open(path_to_yaml) as f: di2s = yaml.load(f) - return dict((v, k) for k, v in di2s.items()) + return {v: k for k, v in di2s.items()} class ImageNetBase(Dataset): def __init__(self, config=None): self.config = config or OmegaConf.create() - if not type(self.config) == dict: + if type(self.config) != dict: self.config = OmegaConf.to_container(self.config) self.keep_orig_class_label = self.config.get( 'keep_orig_class_label', False @@ -56,56 +56,49 @@ def _prepare(self): raise NotImplementedError() def _filter_relpaths(self, relpaths): - ignore = set( - [ - 'n06596364_9591.JPEG', - ] - ) - relpaths = [ - rpath for rpath in relpaths if not rpath.split('/')[-1] in ignore - ] - if 'sub_indices' in self.config: - indices = str_to_indices(self.config['sub_indices']) - synsets = give_synsets_from_indices( - indices, path_to_yaml=self.idx2syn - ) # returns a list of strings - self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) - files = [] - for rpath in relpaths: - syn = rpath.split('/')[0] - if syn in synsets: - files.append(rpath) - return files - else: + ignore = {'n06596364_9591.JPEG'} + relpaths = [rpath for rpath in relpaths if rpath.split('/')[-1] not in ignore] + if 'sub_indices' not in self.config: return relpaths + indices = str_to_indices(self.config['sub_indices']) + synsets = give_synsets_from_indices( + indices, path_to_yaml=self.idx2syn + ) # returns a list of strings + self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) + files = [] + for rpath in relpaths: + syn = rpath.split('/')[0] + if syn in synsets: + files.append(rpath) + return files def _prepare_synset_to_human(self): SIZE = 2655750 - URL = 'https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1' self.human_dict = os.path.join(self.root, 'synset_human.txt') if ( not os.path.exists(self.human_dict) - or not os.path.getsize(self.human_dict) == SIZE + or os.path.getsize(self.human_dict) != SIZE ): + URL = 'https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1' download(URL, self.human_dict) def _prepare_idx_to_synset(self): - URL = 'https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1' self.idx2syn = os.path.join(self.root, 'index_synset.yaml') if not 
os.path.exists(self.idx2syn): + URL = 'https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1' download(URL, self.idx2syn) def _prepare_human_to_integer_label(self): - URL = 'https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1' self.human2integer = os.path.join( self.root, 'imagenet1000_clsidx_to_labels.txt' ) if not os.path.exists(self.human2integer): + URL = 'https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1' download(URL, self.human2integer) with open(self.human2integer, 'r') as f: lines = f.read().splitlines() assert len(lines) == 1000 - self.human2integer_dict = dict() + self.human2integer_dict = {} for line in lines: value, key = line.split(':') self.human2integer_dict[key] = int(value) @@ -116,22 +109,20 @@ def _load(self): l1 = len(self.relpaths) self.relpaths = self._filter_relpaths(self.relpaths) print( - 'Removed {} files from filelist during filtering.'.format( - l1 - len(self.relpaths) - ) + f'Removed {l1 - len(self.relpaths)} files from filelist during filtering.' ) + self.synsets = [p.split('/')[0] for p in self.relpaths] self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] unique_synsets = np.unique(self.synsets) - class_dict = dict( - (synset, i) for i, synset in enumerate(unique_synsets) + class_dict = {synset: i for i, synset in enumerate(unique_synsets)} + self.class_labels = ( + [self.synset2idx[s] for s in self.synsets] + if self.keep_orig_class_label + else [class_dict[s] for s in self.synsets] ) - if not self.keep_orig_class_label: - self.class_labels = [class_dict[s] for s in self.synsets] - else: - self.class_labels = [self.synset2idx[s] for s in self.synsets] with open(self.human_dict, 'r') as f: human_dict = f.read().splitlines() @@ -191,21 +182,21 @@ def _prepare(self): ) if not tdu.is_prepared(self.root): # prep - print('Preparing dataset {} in {}'.format(self.NAME, self.root)) + print(f'Preparing dataset {self.NAME} in {self.root}') datadir = self.datadir if not os.path.exists(datadir): path = os.path.join(self.root, self.FILES[0]) if ( not os.path.exists(path) - or not os.path.getsize(path) == self.SIZES[0] + or os.path.getsize(path) != self.SIZES[0] ): import academictorrents as at atpath = at.get(self.AT_HASH, datastore=self.root) assert atpath == path - print('Extracting {} to {}'.format(path, datadir)) + print(f'Extracting {path} to {datadir}') os.makedirs(datadir, exist_ok=True) with tarfile.open(path, 'r:') as tar: tar.extractall(path=datadir) @@ -263,21 +254,21 @@ def _prepare(self): ) if not tdu.is_prepared(self.root): # prep - print('Preparing dataset {} in {}'.format(self.NAME, self.root)) + print(f'Preparing dataset {self.NAME} in {self.root}') datadir = self.datadir if not os.path.exists(datadir): path = os.path.join(self.root, self.FILES[0]) if ( not os.path.exists(path) - or not os.path.getsize(path) == self.SIZES[0] + or os.path.getsize(path) != self.SIZES[0] ): import academictorrents as at atpath = at.get(self.AT_HASH, datastore=self.root) assert atpath == path - print('Extracting {} to {}'.format(path, datadir)) + print(f'Extracting {path} to {datadir}') os.makedirs(datadir, exist_ok=True) with tarfile.open(path, 'r:') as tar: tar.extractall(path=datadir) @@ -285,7 +276,7 @@ def _prepare(self): vspath = os.path.join(self.root, self.FILES[1]) if ( not os.path.exists(vspath) - or not os.path.getsize(vspath) == self.SIZES[1] + or os.path.getsize(vspath) != self.SIZES[1] ): download(self.VS_URL, vspath) @@ -402,7 +393,7 @@ def __getitem__(self, i): example = self.base[i] image = 
Image.open(example['file_path_']) - if not image.mode == 'RGB': + if image.mode != 'RGB': image = image.convert('RGB') image = np.array(image).astype(np.uint8) diff --git a/ldm/data/lsun.py b/ldm/data/lsun.py index 4a7ecb147ef..69ab73e5030 100644 --- a/ldm/data/lsun.py +++ b/ldm/data/lsun.py @@ -21,12 +21,13 @@ def __init__( self.image_paths = f.read().splitlines() self._length = len(self.image_paths) self.labels = { - 'relative_file_path_': [l for l in self.image_paths], + 'relative_file_path_': list(self.image_paths), 'file_path_': [ os.path.join(self.data_root, l) for l in self.image_paths ], } + self.size = size self.interpolation = { 'linear': PIL.Image.LINEAR, @@ -40,9 +41,9 @@ def __len__(self): return self._length def __getitem__(self, i): - example = dict((k, self.labels[k][i]) for k in self.labels) + example = {k: self.labels[k][i] for k in self.labels} image = Image.open(example['file_path_']) - if not image.mode == 'RGB': + if image.mode != 'RGB': image = image.convert('RGB') # default to score-sde preprocessing diff --git a/ldm/data/personalized.py b/ldm/data/personalized.py index 8d9573fbc62..2dc01965b7c 100644 --- a/ldm/data/personalized.py +++ b/ldm/data/personalized.py @@ -153,10 +153,9 @@ def __len__(self): return self._length def __getitem__(self, i): - example = {} image = Image.open(self.image_paths[i % self.num_images]) - if not image.mode == 'RGB': + if image.mode != 'RGB': image = image.convert('RGB') placeholder_string = self.placeholder_token @@ -174,8 +173,7 @@ def __getitem__(self, i): placeholder_string ) - example['caption'] = text - + example = {'caption': text} # default to score-sde preprocessing img = np.array(image).astype(np.uint8) diff --git a/ldm/data/personalized_style.py b/ldm/data/personalized_style.py index 118d5be9919..ea1526d22b3 100644 --- a/ldm/data/personalized_style.py +++ b/ldm/data/personalized_style.py @@ -126,10 +126,9 @@ def __len__(self): return self._length def __getitem__(self, i): - example = {} image = Image.open(self.image_paths[i % self.num_images]) - if not image.mode == 'RGB': + if image.mode != 'RGB': image = image.convert('RGB') if self.per_image_tokens and np.random.uniform() < 0.25: @@ -141,8 +140,7 @@ def __getitem__(self, i): self.placeholder_token ) - example['caption'] = text - + example = {'caption': text} # default to score-sde preprocessing img = np.array(image).astype(np.uint8) diff --git a/ldm/generate.py b/ldm/generate.py index f9fc364cf3c..0705d58b700 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -257,7 +257,7 @@ def prompt2image( catch_interrupts = False, hires_fix = False, **args, - ): # eat up additional cruft + ): # eat up additional cruft """ ldm.generate.prompt2image() is the common entry point for txt2img() and img2img() It takes the following arguments: @@ -313,7 +313,7 @@ def process_image(image,seed): # will instantiate the model or return it from cache model = self.load_model() - + for m in model.modules(): if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): m.padding_mode = 'circular' if seamless else m._orig_padding_mode @@ -330,18 +330,20 @@ def process_image(image,seed): 0.0 <= perlin <= 1.0 ), '--perlin must be in [0.0, 1.0]' assert ( - (embiggen == None and embiggen_tiles == None) or ( - (embiggen != None or embiggen_tiles != None) and init_img != None) + embiggen is None + and embiggen_tiles is None + or ((embiggen != None or embiggen_tiles != None) and init_img != None) ), 'Embiggen requires an init/input image to be specified' + if len(with_variations) > 0 or variation_amount > 1.0: 
assert seed is not None,\ - 'seed must be specified when using with_variations' + 'seed must be specified when using with_variations' if variation_amount == 0.0: assert iterations == 1,\ - 'when using --with_variations, multiple iterations are only possible when using --variation_amount' + 'when using --with_variations, multiple iterations are only possible when using --variation_amount' assert all(0 <= weight <= 1 for _, weight in with_variations),\ - f'variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}' + f'variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}' width, height, _ = self._resolution_check(width, height, log=True) @@ -353,7 +355,7 @@ def process_image(image,seed): if self._has_cuda(): torch.cuda.reset_peak_memory_stats() - results = list() + results = [] init_image = None mask_image = None @@ -439,19 +441,21 @@ def process_image(image,seed): ) if self._has_cuda(): print( - f'>> Max VRAM used for this generation:', + '>> Max VRAM used for this generation:', '%4.2fG.' % (torch.cuda.max_memory_allocated() / 1e9), 'Current VRAM utilization:', '%4.2fG' % (torch.cuda.memory_allocated() / 1e9), ) + self.session_peakmem = max( self.session_peakmem, torch.cuda.max_memory_allocated() ) print( - f'>> Max VRAM used since script start: ', + '>> Max VRAM used since script start: ', '%4.2fG' % (self.session_peakmem / 1e9), ) + return results # this needs to be generalized to all sorts of postprocessors, which should be wrapped @@ -486,8 +490,7 @@ def apply_postprocessor( # try to reuse the same filename prefix as the original file. # we take everything up to the first period prefix = None - m = re.match('^([^.]+)\.',os.path.basename(image_path)) - if m: + if m := re.match('^([^.]+)\.', os.path.basename(image_path)): prefix = m.groups()[0] # face fixers and esrgan take an Image, but embiggen takes a path @@ -521,9 +524,10 @@ def apply_postprocessor( elif tool == 'outcrop': from ldm.invoke.restoration.outcrop import Outcrop - extend_instructions = {} - for direction,pixels in _pairwise(opt.outcrop): - extend_instructions[direction]=int(pixels) + extend_instructions = { + direction: int(pixels) + for direction, pixels in _pairwise(opt.outcrop) + } restorer = Outcrop(image,self,) return restorer.process ( @@ -565,9 +569,9 @@ def apply_postprocessor( image_callback = callback, prefix = prefix ) - + elif tool is None: - print(f'* please provide at least one postprocessing option, such as -G or -U') + print('* please provide at least one postprocessing option, such as -G or -U') return None else: print(f'* postprocessing tool {tool} is not yet supported') @@ -601,7 +605,7 @@ def _make_images( self._transparency_check_and_warning(image, mask) # this returns a torch tensor init_mask = self._create_init_mask(image, width, height, fit=fit) - + if (image.width * image.height) > (self.width * self.height): print(">> This input is larger than your defaults. 
If you run out of memory, please use a smaller image.") @@ -660,8 +664,10 @@ def load_model(self): model = self._load_model_from_config(self.config, self.weights) if self.embedding_path is not None: model.embedding_manager.load( - self.embedding_path, self.precision == 'float32' or self.precision == 'autocast' + self.embedding_path, + self.precision in ['float32', 'autocast'], ) + self.model = model.to(self.device) # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here self.model.cond_stage_model.device = self.device @@ -710,7 +716,7 @@ def upscale_and_reconstruct(self, image_callback = None, prefix = None, ): - + for r in image_list: image, seed = r try: @@ -813,9 +819,7 @@ def _load_model_from_config(self, config, weights): # usage statistics toc = time.time() - print( - f'>> Model loaded in', '%4.2fs' % (toc - tic) - ) + print('>> Model loaded in', '%4.2fs' % (toc - tic)) if self._has_cuda(): print( '>> Max VRAM used to load the model:', @@ -910,8 +914,7 @@ def _check_for_erasure(self, image): for x in range(width): if pixdata[x, y][3] == 0: r, g, b, _ = pixdata[x, y] - if (r, g, b) != (0, 0, 0) and \ - (r, g, b) != (255, 255, 255): + if (r, g, b) not in [(0, 0, 0), (255, 255, 255)]: colored += 1 return colored == 0 @@ -928,9 +931,7 @@ def _transparency_check_and_warning(self,image, mask): def _squeeze_image(self, image): x, y, resize_needed = self._resolution_check(image.width, image.height) - if resize_needed: - return InitImageResizer(image).resize(x, y) - return image + return InitImageResizer(image).resize(x, y) if resize_needed else image def _fit_image(self, image, max_dimensions): w, h = max_dimensions @@ -941,8 +942,6 @@ def _fit_image(self, image, max_dimensions): h = None # by setting h to none, we tell InitImageResizer to fit into the width and calculate height elif image.height > image.width: w = None # ditto for w - else: - pass # note that InitImageResizer does the multiple of 64 truncation internally image = InitImageResizer(image).resize(w, h) print( @@ -973,12 +972,12 @@ def _cached_sha256(self,path,data): dirname = os.path.dirname(path) basename = os.path.basename(path) base, _ = os.path.splitext(basename) - hashpath = os.path.join(dirname,base+'.sha256') + hashpath = os.path.join(dirname, f'{base}.sha256') if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath): with open(hashpath) as f: hash = f.read() return hash - print(f'>> Calculating sha256 hash of weights file') + print('>> Calculating sha256 hash of weights file') tic = time.time() sha = hashlib.sha256() sha.update(data) diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py index 09073679c03..5d8d83d9248 100644 --- a/ldm/invoke/args.py +++ b/ldm/invoke/args.py @@ -126,15 +126,13 @@ def _format_usage(self, usage, actions, groups, prefix): if usage is not None: usage = usage % dict(prog=self._prog) - # if no optionals or positionals are available, usage is just prog - elif usage is None and not actions: + elif not actions: usage = 'invoke>' - elif usage is None: + else: prog='invoke>' # build full usage string action_usage = self._format_actions_usage(actions, groups) # NEW usage = ' '.join([s for s in [prog, action_usage] if s]) - # omit the long line wrapping code # prefix with 'usage:' return '%s%s\n\n' % (prefix, usage) @@ -187,7 +185,7 @@ def parse_cmd(self,cmd_string): else: switches[0] += element switches[0] += ' ' - switches[0] = switches[0][: len(switches[0]) - 1] + switches[0] = switches[0][:-1] try: self._cmd_switches = 
self._cmd_parser.parse_args(switches) return self._cmd_switches @@ -198,24 +196,23 @@ def json(self,**kwargs): return json.dumps(self.to_dict(**kwargs)) def to_dict(self,**kwargs): - a = vars(self) - a.update(kwargs) - return a + return vars(self) | kwargs # Isn't there a more automated way of doing this? # Ideally we get the switch strings out of the argparse objects, # but I don't see a documented API for this. def dream_prompt_str(self,**kwargs): """Normalized dream_prompt.""" - a = vars(self) - a.update(kwargs) - switches = list() - switches.append(f'"{a["prompt"]}"') - switches.append(f'-s {a["steps"]}') - switches.append(f'-S {a["seed"]}') - switches.append(f'-W {a["width"]}') - switches.append(f'-H {a["height"]}') - switches.append(f'-C {a["cfg_scale"]}') + a = vars(self) | kwargs + switches = [ + f'"{a["prompt"]}"', + f'-s {a["steps"]}', + f'-S {a["seed"]}', + f'-W {a["width"]}', + f'-H {a["height"]}', + f'-C {a["cfg_scale"]}', + ] + if a['perlin'] > 0: switches.append(f'--perlin {a["perlin"]}') if a['threshold'] > 0: @@ -229,10 +226,9 @@ def dream_prompt_str(self,**kwargs): # img2img generations have parameters relevant only to them and have special handling if a['init_img'] and len(a['init_img'])>0: - switches.append(f'-I {a["init_img"]}') - switches.append(f'-A {a["sampler_name"]}') + switches.extend((f'-I {a["init_img"]}', f'-A {a["sampler_name"]}')) if a['fit']: - switches.append(f'--fit') + switches.append('--fit') if a['init_mask'] and len(a['init_mask'])>0: switches.append(f'-M {a["init_mask"]}') if a['init_color'] and len(a['init_color'])>0: @@ -298,7 +294,7 @@ def __getattribute__(self,name): if not hasattr(cmd_switches,name) and not hasattr(arg_switches,name): raise AttributeError - + value_arg,value_cmd = (None,None) try: value_cmd = getattr(cmd_switches,name) @@ -314,10 +310,7 @@ def __getattribute__(self,name): # the arg value. For example, the --grid and --individual options are a little # funny because of their push/pull relationship. This is how to handle it. if name=='grid': - if cmd_switches.individual: - return False - else: - return value_cmd or value_arg + return False if cmd_switches.individual else value_cmd or value_arg return value_cmd if value_cmd is not None else value_arg def __setattr__(self,name,value): @@ -756,7 +749,7 @@ def _create_dream_cmd_parser(self): return parser def format_metadata(**kwargs): - print(f'format_metadata() is deprecated. Please use metadata_dumps()') + print('format_metadata() is deprecated. Please use metadata_dumps()') return metadata_dumps(kwargs) def metadata_dumps(opt, @@ -887,27 +880,25 @@ def metadata_loads(metadata) -> list: def calculate_init_img_hash(image_string): prefix = 'data:image/png;base64,' hash = None - if image_string.startswith(prefix): - imagebase64 = image_string[len(prefix):] - imagedata = base64.b64decode(imagebase64) - with open('outputs/test.png','wb') as file: - file.write(imagedata) - sha = hashlib.sha256() - sha.update(imagedata) - hash = sha.hexdigest() - else: - hash = sha256(image_string) - return hash + if not image_string.startswith(prefix): + return sha256(image_string) + imagebase64 = image_string[len(prefix):] + imagedata = base64.b64decode(imagebase64) + with open('outputs/test.png','wb') as file: + file.write(imagedata) + sha = hashlib.sha256() + sha.update(imagedata) + return sha.hexdigest() # Bah. This should be moved somewhere else... 
def sha256(path): sha = hashlib.sha256() with open(path,'rb') as f: while True: - data = f.read(65536) - if not data: + if data := f.read(65536): + sha.update(data) + else: break - sha.update(data) return sha.hexdigest() def legacy_metadata_load(meta,pathname) -> Args: @@ -916,9 +907,8 @@ def legacy_metadata_load(meta,pathname) -> Args: opt = Args() opt.parse_cmd(dream_prompt) return opt - else: # if nothing else, we can get the seed - match = re.search('\d+\.(\d+)',pathname) - if match: + else: # if nothing else, we can get the seed + if match := re.search('\d+\.(\d+)', pathname): seed = match.groups()[0] opt = Args() opt.seed = seed diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index fedd965a2c7..39abdd17769 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -94,15 +94,15 @@ def log_tokenization(text, model, log=False, weight=1): discarded = "" usedTokens = 0 totalTokens = len(tokens) - for i in range(0, totalTokens): + for i in range(totalTokens): token = tokens[i].replace('', ' ') # alternate color s = (usedTokens % 6) + 1 if i < model.cond_stage_model.max_length: - tokenized = tokenized + f"\x1b[0;3{s};40m{token}" + tokenized = f"{tokenized}\x1b[0;3{s};40m{token}" usedTokens += 1 else: # over max token length - discarded = discarded + f"\x1b[0;3{s};40m{token}" + discarded = f"{discarded}\x1b[0;3{s};40m{token}" print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m") if discarded != "": print( diff --git a/ldm/invoke/devices.py b/ldm/invoke/devices.py index 424ae5a6d30..fe93fcbb5d0 100644 --- a/ldm/invoke/devices.py +++ b/ldm/invoke/devices.py @@ -14,7 +14,10 @@ def choose_precision(device) -> str: '''Returns an appropriate precision for the given torch device''' if device.type == 'cuda': device_name = torch.cuda.get_device_name(device) - if not ('GeForce GTX 1660' in device_name or 'GeForce GTX 1650' in device_name): + if ( + 'GeForce GTX 1660' not in device_name + and 'GeForce GTX 1650' not in device_name + ): return 'float16' return 'float32' @@ -22,6 +25,4 @@ def choose_autocast(precision): '''Returns an autocast context or nullcontext for the given precision string''' # float16 currently requires autocast to avoid errors like: # 'expected scalar type Half but found Float' - if precision == 'autocast' or precision == 'float16': - return autocast - return nullcontext + return autocast if precision in ['autocast', 'float16'] else nullcontext diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py index 2aa0caf5f95..3605f5ab4a1 100644 --- a/ldm/invoke/generator/base.py +++ b/ldm/invoke/generator/base.py @@ -59,7 +59,7 @@ def generate(self,prompt,init_image,width,height,iterations=1,seed=None, first_seed = seed seed, initial_noise = self.generate_initial_noise(seed, width, height) with scope(self.model.device.type), self.model.ema_scope(): - for n in trange(iterations, desc='Generating'): + for _ in trange(iterations, desc='Generating'): x_T = None if self.variation_amount > 0: seed_everything(seed) diff --git a/ldm/invoke/generator/embiggen.py b/ldm/invoke/generator/embiggen.py index 53fbde68cf2..874765c5932 100644 --- a/ldm/invoke/generator/embiggen.py +++ b/ldm/invoke/generator/embiggen.py @@ -28,17 +28,17 @@ def generate(self,prompt,iterations=1,seed=None, **kwargs ) results = [] - seed = seed if seed else self.new_seed() + seed = seed or self.new_seed() # Noise will be generated by the Img2Img generator when called with scope(self.model.device.type), self.model.ema_scope(): - for n in 
trange(iterations, desc='Generating'): + for _ in trange(iterations, desc='Generating'): # make_image will call Img2Img which will do the equivalent of get_noise itself image = make_image() results.append([image, seed]) if image_callback is not None: image_callback(image, seed) - seed = self.new_seed() + seed = self.new_seed() return results @torch.no_grad() @@ -64,7 +64,7 @@ def get_make_image( Return value depends on the seed at the time you call it """ # Construct embiggen arg array, and sanity check arguments - if embiggen == None: # embiggen can also be called with just embiggen_tiles + if embiggen is None: # embiggen can also be called with just embiggen_tiles embiggen = [1.0] # If not specified, assume no scaling elif embiggen[0] < 0: embiggen[0] = 1.0 @@ -84,9 +84,7 @@ def get_make_image( # Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math # and then sort them, because... people. if embiggen_tiles: - embiggen_tiles = list(map(lambda n: n-1, embiggen_tiles)) - embiggen_tiles.sort() - + embiggen_tiles = sorted(map(lambda n: n-1, embiggen_tiles)) if strength >= 0.5: print(f'* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45.') @@ -134,8 +132,7 @@ def get_make_image( # Use width and height as tile widths and height # Determine buffer size in pixels if embiggen[2] < 1: - if embiggen[2] < 0: - embiggen[2] = 0 + embiggen[2] = max(embiggen[2], 0) overlap_size_x = round(embiggen[2] * width) overlap_size_y = round(embiggen[2] * height) else: @@ -174,11 +171,10 @@ def ceildiv(a, b): # Find distance to lower right corner (numpy takes arrays) distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0] # Clamp values to max 255 - if distanceToLR > 255: - distanceToLR = 255 + distanceToLR = min(distanceToLR, 255) #Place the pixel as invert of distance agradientC.putpixel((x, y), round(255 - distanceToLR)) - + # Create alternative asymmetric diagonal corner to use on "tailing" intersections to prevent hard edges # Fits for a left-fading gradient on the bottom side and full opacity on the right side. agradientAsymC = Image.new('L', (256, 256)) @@ -307,7 +303,7 @@ def make_image(): seed = 0 # Determine if this is a re-run and replace - if embiggen_tiles and not tile in embiggen_tiles: + if embiggen_tiles and tile not in embiggen_tiles: continue # Get row and column entries emb_row_i = tile // emb_tiles_x @@ -490,9 +486,13 @@ def make_image(): # Layer tile onto final image outputsuperimage.alpha_composite(intileimage, (left, top)) else: - print(f'Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation.') + print( + 'Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation.' + ) + # after internal loops and patching up return Embiggen image return outputsuperimage + # end of function declaration return make_image diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py index 3ab14565644..3c13892f14b 100644 --- a/ldm/invoke/generator/inpaint.py +++ b/ldm/invoke/generator/inpaint.py @@ -26,11 +26,9 @@ def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, """ # klms samplers not supported yet, so ignore previous sampler if isinstance(sampler,KSampler): - print( - f">> Using recommended DDIM sampler for inpainting." 
- ) + print(">> Using recommended DDIM sampler for inpainting.") sampler = DDIMSampler(self.model, device=self.model.device) - + sampler.make_schedule( ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False ) diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index 945ebadd90b..6ac8ed43c0d 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -116,7 +116,7 @@ def get_noise(self,width,height,scale = True): else: scaled_width = width scaled_height = height - + device = self.model.device if device.type == 'mps': return torch.randn([1, diff --git a/ldm/invoke/image_util.py b/ldm/invoke/image_util.py index 2ec7b55834c..aa40f756ea4 100644 --- a/ldm/invoke/image_util.py +++ b/ldm/invoke/image_util.py @@ -19,7 +19,7 @@ def resize(self,width=None,height=None) -> Image: that it can be passed to img2img() """ im = self.image - + ar = im.width/float(im.height) # Infer missing values from aspect ratio @@ -44,14 +44,12 @@ def resize(self,width=None,height=None) -> Image: # no resize necessary, but return a copy if im.width == width and im.height == height: return im.copy() - - # otherwise resize the original image so that it fits inside the bounding box - resized_image = self.image.resize((rw,rh),resample=Image.Resampling.LANCZOS) - return resized_image + + return self.image.resize((rw,rh),resample=Image.Resampling.LANCZOS) def make_grid(image_list, rows=None, cols=None): - image_cnt = len(image_list) if None in (rows, cols): + image_cnt = len(image_list) rows = floor(sqrt(image_cnt)) # try to make it square cols = ceil(image_cnt / rows) width = image_list[0].width @@ -59,8 +57,8 @@ def make_grid(image_list, rows=None, cols=None): grid_img = Image.new('RGB', (width * cols, height * rows)) i = 0 - for r in range(0, rows): - for c in range(0, cols): + for r in range(rows): + for c in range(cols): if i >= len(image_list): break grid_img.paste(image_list[i], (c * width, r * height)) diff --git a/ldm/invoke/log.py b/ldm/invoke/log.py index 8aebe626713..77acbb99fe5 100644 --- a/ldm/invoke/log.py +++ b/ldm/invoke/log.py @@ -26,19 +26,17 @@ def write_log_message(results, output_cntr): return output_cntr log_lines = [f"{path}: {prompt}\n" for path, prompt in results] if len(log_lines)>1: - subcntr = 1 - for l in log_lines: - print(f"[{output_cntr}.{subcntr}] {l}", end="") - subcntr += 1 + for subcntr, l in enumerate(log_lines, start=1): + print(f"[{output_cntr}.{subcntr}] {l}", end="") else: - print(f"[{output_cntr}] {log_lines[0]}", end="") + print(f"[{output_cntr}] {log_lines[0]}", end="") return output_cntr+1 def write_log_files(results, log_path, file_types): for file_type in file_types: if file_type == "txt": write_log_txt(log_path, results) - elif file_type == "md" or file_type == "markdown": + elif file_type in ["md", "markdown"]: write_log_markdown(log_path, results) else: print(f"'{file_type}' format is not supported, so write in plain text") @@ -47,13 +45,13 @@ def write_log_files(results, log_path, file_types): def write_log_default(log_path, results, file_type): plain_txt_lines = [f"{path}: {prompt}\n" for path, prompt in results] - with open(log_path + "." 
+ file_type, "a", encoding="utf-8") as file: + with open(f"{log_path}.{file_type}", "a", encoding="utf-8") as file: file.writelines(plain_txt_lines) def write_log_txt(log_path, results): txt_lines = [f"{path}: {prompt}\n" for path, prompt in results] - with open(log_path + ".txt", "a", encoding="utf-8") as file: + with open(f"{log_path}.txt", "a", encoding="utf-8") as file: file.writelines(txt_lines) @@ -62,5 +60,5 @@ def write_log_markdown(log_path, results): for path, prompt in results: file_name = os.path.basename(path) md_lines.append(f"## {file_name}\n![]({file_name})\n\n{prompt}\n") - with open(log_path + ".md", "a", encoding="utf-8") as file: + with open(f"{log_path}.md", "a", encoding="utf-8") as file: file.writelines(md_lines) diff --git a/ldm/invoke/readline.py b/ldm/invoke/readline.py index 73664ef82ce..8d8116ec7bb 100644 --- a/ldm/invoke/readline.py +++ b/ldm/invoke/readline.py @@ -8,6 +8,7 @@ completer.add_seed(18247566) completer.add_seed(9281839) """ + import os import re import atexit @@ -17,7 +18,7 @@ try: import readline readline_available = True -except (ImportError,ModuleNotFoundError): +except ImportError: readline_available = False IMG_EXTENSIONS = ('.png','.jpg','.jpeg','.PNG','.JPG','.JPEG','.gif','.GIF') @@ -66,7 +67,7 @@ class Completer(object): def __init__(self, options): self.options = sorted(options) self.seeds = set() - self.matches = list() + self.matches = [] self.default_dir = None self.linebuffer = None self.auto_history_active = True @@ -155,13 +156,13 @@ def show_history(self,match=None): Print the session history using the pydoc pager ''' import pydoc - lines = list() + lines = [] h_len = self.get_current_history_length() if h_len < 1: print('') return - - for i in range(0,h_len): + + for i in range(h_len): line = self.get_history_item(i+1) if match and match not in line: continue @@ -173,18 +174,14 @@ def set_line(self,line)->None: readline.redisplay() def _seed_completions(self, text, state): - m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text) - if m: + if m := re.search('(-S\s?|--seed[=\s]?)(\d*)', text): switch = m.groups()[0] partial = m.groups()[1] else: switch = '' partial = text - matches = list() - for s in self.seeds: - if s.startswith(partial): - matches.append(switch+s) + matches = [switch+s for s in self.seeds if s.startswith(partial)] matches.sort() return matches @@ -204,7 +201,7 @@ def _path_completions(self, text, state, extensions, shortcut_ok=True): switch,partial_path = match.groups() partial_path = partial_path.lstrip() - matches = list() + matches = [] path = os.path.expanduser(partial_path) if os.path.isdir(path): @@ -232,7 +229,7 @@ def _path_completions(self, text, state, extensions, shortcut_ok=True): if switch is None: match_path = os.path.join(dir,node) - matches.append(match_path+'/' if os.path.isdir(full_path) else match_path) + matches.append(f'{match_path}/' if os.path.isdir(full_path) else match_path) elif os.path.isdir(full_path): matches.append( switch+os.path.join(os.path.dirname(full_path), node) + '/' @@ -246,13 +243,13 @@ def _path_completions(self, text, state, extensions, shortcut_ok=True): class DummyCompleter(Completer): def __init__(self,options): super().__init__(options) - self.history = list() + self.history = [] def add_history(self,line): self.history.append(line) def clear_history(self): - self.history = list() + self.history = [] def get_current_history_length(self): return len(self.history) diff --git a/ldm/invoke/restoration/codeformer.py b/ldm/invoke/restoration/codeformer.py index 
0d13ae0a36b..e66f89181c7 100644 --- a/ldm/invoke/restoration/codeformer.py +++ b/ldm/invoke/restoration/codeformer.py @@ -14,7 +14,7 @@ def __init__(self, self.codeformer_model_exists = os.path.isfile(self.model_path) if not self.codeformer_model_exists: - print('## NOT FOUND: CodeFormer model not found at ' + self.model_path) + print(f'## NOT FOUND: CodeFormer model not found at {self.model_path}') sys.path.append(os.path.abspath(codeformer_dir)) def process(self, image, strength, device, seed=None, fidelity=0.75): diff --git a/ldm/invoke/restoration/codeformer_arch.py b/ldm/invoke/restoration/codeformer_arch.py index b23872b18fb..78ead9c8da3 100644 --- a/ldm/invoke/restoration/codeformer_arch.py +++ b/ldm/invoke/restoration/codeformer_arch.py @@ -82,8 +82,7 @@ def forward(self, x, mask=None): pos_y = torch.stack( (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 ).flatten(3) - pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) - return pos + return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) def _get_activation_fn(activation): """Return an activation function given a string""" @@ -153,8 +152,7 @@ def forward(self, enc_feat, dec_feat, w=1): scale = self.scale(enc_feat) shift = self.shift(enc_feat) residual = w * (dec_feat * scale + shift) - out = dec_feat + residual - return out + return dec_feat + residual @ARCH_REGISTRY.register() @@ -186,7 +184,7 @@ def __init__(self, dim_embd=512, n_head=8, n_layers=9, self.idx_pred_layer = nn.Sequential( nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)) - + self.channels = { '16': 512, '32': 256, @@ -266,11 +264,10 @@ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] for i, block in enumerate(self.generator.blocks): - x = block(x) - if i in fuse_list: # fuse after i-th block + x = block(x) + if i in fuse_list and w > 0: f_size = str(x.shape[-1]) - if w>0: - x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) + x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) out = x # logits doesn't need softmax before cross_entropy loss return out, logits, lq_feat diff --git a/ldm/invoke/restoration/gfpgan.py b/ldm/invoke/restoration/gfpgan.py index 473d708961b..dd554f8e367 100644 --- a/ldm/invoke/restoration/gfpgan.py +++ b/ldm/invoke/restoration/gfpgan.py @@ -17,7 +17,7 @@ def __init__( self.gfpgan_model_exists = os.path.isfile(self.model_path) if not self.gfpgan_model_exists: - print('## NOT FOUND: GFPGAN model not found at ' + self.model_path) + print(f'## NOT FOUND: GFPGAN model not found at {self.model_path}') return None sys.path.append(os.path.abspath(gfpgan_dir)) @@ -46,9 +46,7 @@ def process(self, image, strength: float, seed: str = None): print(traceback.format_exc(), file=sys.stderr) if self.gfpgan is None: - print( - f'>> WARNING: GFPGAN not initialized.' - ) + print('>> WARNING: GFPGAN not initialized.') print( f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}, \nor change GFPGAN directory with --gfpgan_dir.' ) diff --git a/ldm/invoke/restoration/outcrop.py b/ldm/invoke/restoration/outcrop.py index 017d9de7e19..c5d788fd4c6 100644 --- a/ldm/invoke/restoration/outcrop.py +++ b/ldm/invoke/restoration/outcrop.py @@ -59,9 +59,8 @@ def _extend_all( adjacent image. 
''' image = self.image - for direction in extents: + for direction, pixels in extents.items(): assert direction in ['top', 'left', 'bottom', 'right'],'Direction must be one of "top", "left", "bottom", "right"' - pixels = extents[direction] # round pixels up to the nearest 64 pixels = math.ceil(pixels/64) * 64 print(f'>> extending image {direction}ward by {pixels} pixels') diff --git a/ldm/invoke/restoration/outpaint.py b/ldm/invoke/restoration/outpaint.py index e75b48221ff..e76ee421109 100644 --- a/ldm/invoke/restoration/outpaint.py +++ b/ldm/invoke/restoration/outpaint.py @@ -33,7 +33,11 @@ def wrapped_callback(img,seed,**kwargs): ) def _create_outpaint_image(self, image, direction_args): - assert len(direction_args) in [1, 2], 'Direction (-D) must have exactly one or two arguments.' + assert len(direction_args) in { + 1, + 2, + }, 'Direction (-D) must have exactly one or two arguments.' + if len(direction_args) == 1: direction = direction_args[0] @@ -46,10 +50,10 @@ def _create_outpaint_image(self, image, direction_args): image = image.convert("RGBA") # we always extend top, but rotate to extend along the requested side - if direction == 'left': - image = image.transpose(Image.Transpose.ROTATE_270) - elif direction == 'bottom': + if direction == 'bottom': image = image.transpose(Image.Transpose.ROTATE_180) + elif direction == 'left': + image = image.transpose(Image.Transpose.ROTATE_270) elif direction == 'right': image = image.transpose(Image.Transpose.ROTATE_90) @@ -81,10 +85,10 @@ def _create_outpaint_image(self, image, direction_args): new_img.putpixel((x, y), (r, g, b, 0)) # let's rotate back again - if direction == 'left': - new_img = new_img.transpose(Image.Transpose.ROTATE_90) - elif direction == 'bottom': + if direction == 'bottom': new_img = new_img.transpose(Image.Transpose.ROTATE_180) + elif direction == 'left': + new_img = new_img.transpose(Image.Transpose.ROTATE_90) elif direction == 'right': new_img = new_img.transpose(Image.Transpose.ROTATE_270) diff --git a/ldm/invoke/restoration/realesrgan.py b/ldm/invoke/restoration/realesrgan.py index dc3eebd9123..652853afdc2 100644 --- a/ldm/invoke/restoration/realesrgan.py +++ b/ldm/invoke/restoration/realesrgan.py @@ -9,17 +9,10 @@ class ESRGAN(): def __init__(self, bg_tile_size=400) -> None: self.bg_tile_size = bg_tile_size - if not torch.cuda.is_available(): # CPU or MPS on M1 - use_half_precision = False - else: - use_half_precision = True + use_half_precision = bool(torch.cuda.is_available()) def load_esrgan_bg_upsampler(self): - if not torch.cuda.is_available(): # CPU or MPS on M1 - use_half_precision = False - else: - use_half_precision = True - + use_half_precision = bool(torch.cuda.is_available()) from realesrgan.archs.srvgg_arch import SRVGGNetCompact from realesrgan import RealESRGANer @@ -27,7 +20,7 @@ def load_esrgan_bg_upsampler(self): model_path = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth' scale = 4 - bg_upsampler = RealESRGANer( + return RealESRGANer( scale=scale, model_path=model_path, model=model, @@ -37,8 +30,6 @@ def load_esrgan_bg_upsampler(self): half=use_half_precision, ) - return bg_upsampler - def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2): with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=DeprecationWarning) diff --git a/ldm/invoke/restoration/vqgan_arch.py b/ldm/invoke/restoration/vqgan_arch.py index f6dfcf4c998..8a624c9be3b 100644 --- a/ldm/invoke/restoration/vqgan_arch.py +++ 
b/ldm/invoke/restoration/vqgan_arch.py @@ -239,10 +239,7 @@ def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution curr_res = self.resolution in_ch_mult = (1,)+tuple(ch_mult) - blocks = [] - # initial convultion - blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)) - + blocks = [nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)] # residual and downsampling blocks, with attention on smaller res (16x16) for i in range(self.num_resolutions): block_in_ch = nf * in_ch_mult[i] @@ -277,20 +274,22 @@ def forward(self, x): class Generator(nn.Module): def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): super().__init__() - self.nf = nf - self.ch_mult = ch_mult + self.nf = nf + self.ch_mult = ch_mult self.num_resolutions = len(self.ch_mult) self.num_res_blocks = res_blocks - self.resolution = img_size + self.resolution = img_size self.attn_resolutions = attn_resolutions self.in_channels = emb_dim self.out_channels = 3 block_in_ch = self.nf * self.ch_mult[-1] curr_res = self.resolution // 2 ** (self.num_resolutions-1) - blocks = [] - # initial conv - blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)) + blocks = [ + nn.Conv2d( + self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1 + ) + ] # non-local attention block blocks.append(ResBlock(block_in_ch, block_in_ch)) @@ -330,9 +329,9 @@ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, att beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): super().__init__() logger = get_root_logger() - self.in_channels = 3 - self.nf = nf - self.n_blocks = res_blocks + self.in_channels = 3 + self.nf = nf + self.n_blocks = res_blocks self.codebook_size = codebook_size self.embed_dim = emb_dim self.ch_mult = ch_mult @@ -380,7 +379,7 @@ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, att self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) logger.info(f'vqgan is loaded from: {model_path} [params]') else: - raise ValueError(f'Wrong params!') + raise ValueError('Wrong params!') def forward(self, x): @@ -429,7 +428,7 @@ def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None): elif 'params' in chkpt: self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) else: - raise ValueError(f'Wrong params!') + raise ValueError('Wrong params!') def forward(self, x): return self.main(x) \ No newline at end of file diff --git a/ldm/invoke/server.py b/ldm/invoke/server.py index 4eef1ddd562..3e12433caea 100644 --- a/ldm/invoke/server.py +++ b/ldm/invoke/server.py @@ -90,13 +90,13 @@ def do_GET(self): config = { 'gfpgan_model_exists': self.gfpgan_model_exists } - self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8")) + self.wfile.write(bytes(f"let config = {json.dumps(config)}" + ";\n", "utf-8")) elif self.path == "/run_log.json": self.send_response(200) self.send_header("Content-type", "application/json") self.end_headers() output = [] - + log_file = os.path.join(self.outdir, "legacy_web_log.txt") if os.path.exists(log_file): with open(log_file, "r") as log: @@ -118,7 +118,7 @@ def do_GET(self): path_dir = os.path.dirname(self.path) out_dir = os.path.realpath(self.outdir.rstrip('/')) if self.path.startswith('/static/legacy_web/'): - path = '.' 
+ self.path + path = f'.{self.path}' elif out_dir.replace('\\', '/').endswith(path_dir): file = os.path.basename(self.path) path = os.path.join(self.outdir,file) @@ -256,7 +256,7 @@ def image_progress(sample, step): # Remove the temp file os.remove("./img2img-tmp.png") except CanceledException: - print(f"Canceled.") + print("Canceled.") return except Exception as e: print("Error happened") diff --git a/ldm/lr_scheduler.py b/ldm/lr_scheduler.py index 79c1d1978e7..c1bdb278550 100644 --- a/ldm/lr_scheduler.py +++ b/ldm/lr_scheduler.py @@ -24,17 +24,14 @@ def __init__( self.verbosity_interval = verbosity_interval def schedule(self, n, **kwargs): - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: - print( - f'current step: {n}, recent lr-multiplier: {self.last_lr}' - ) + if self.verbosity_interval > 0 and n % self.verbosity_interval == 0: + print( + f'current step: {n}, recent lr-multiplier: {self.last_lr}' + ) if n < self.lr_warm_up_steps: lr = ( self.lr_max - self.lr_start ) / self.lr_warm_up_steps * n + self.lr_start - self.last_lr = lr - return lr else: t = (n - self.lr_warm_up_steps) / ( self.lr_max_decay_steps - self.lr_warm_up_steps @@ -43,8 +40,9 @@ def schedule(self, n, **kwargs): lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 1 + np.cos(t * np.pi) ) - self.last_lr = lr - return lr + + self.last_lr = lr + return lr def __call__(self, n, **kwargs): return self.schedule(n, **kwargs) @@ -82,27 +80,22 @@ def __init__( self.verbosity_interval = verbosity_interval def find_in_interval(self, n): - interval = 0 - for cl in self.cum_cycles[1:]: + for interval, cl in enumerate(self.cum_cycles[1:]): if n <= cl: return interval - interval += 1 def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: - print( - f'current step: {n}, recent lr-multiplier: {self.last_f}, ' - f'current cycle {cycle}' - ) + if self.verbosity_interval > 0 and n % self.verbosity_interval == 0: + print( + f'current step: {n}, recent lr-multiplier: {self.last_f}, ' + f'current cycle {cycle}' + ) if n < self.lr_warm_up_steps[cycle]: f = ( self.f_max[cycle] - self.f_start[cycle] ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] - self.last_f = f - return f else: t = (n - self.lr_warm_up_steps[cycle]) / ( self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] @@ -111,8 +104,9 @@ def schedule(self, n, **kwargs): f = self.f_min[cycle] + 0.5 * ( self.f_max[cycle] - self.f_min[cycle] ) * (1 + np.cos(t * np.pi)) - self.last_f = f - return f + + self.last_f = f + return f def __call__(self, n, **kwargs): return self.schedule(n, **kwargs) @@ -122,22 +116,20 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: - print( - f'current step: {n}, recent lr-multiplier: {self.last_f}, ' - f'current cycle {cycle}' - ) + if self.verbosity_interval > 0 and n % self.verbosity_interval == 0: + print( + f'current step: {n}, recent lr-multiplier: {self.last_f}, ' + f'current cycle {cycle}' + ) if n < self.lr_warm_up_steps[cycle]: f = ( self.f_max[cycle] - self.f_start[cycle] ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] - self.last_f = f - return f else: f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( self.cycle_lengths[cycle] - n ) / (self.cycle_lengths[cycle]) - self.last_f = f - 
return f + + self.last_f = f + return f diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py index 359f5688d15..cc71d85a0c7 100644 --- a/ldm/models/autoencoder.py +++ b/ldm/models/autoencoder.py @@ -94,7 +94,7 @@ def init_from_ckpt(self, path, ignore_keys=list()): for k in keys: for ik in ignore_keys: if k.startswith(ik): - print('Deleting key {} from state_dict.'.format(k)) + print(f'Deleting key {k} from state_dict.') del sd[k] missing, unexpected = self.load_state_dict(sd, strict=False) print( @@ -121,20 +121,16 @@ def encode_to_prequant(self, x): def decode(self, quant): quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec + return self.decoder(quant) def decode_code(self, code_b): quant_b = self.quantize.embed_code(code_b) - dec = self.decode(quant_b) - return dec + return self.decode(quant_b) def forward(self, input, return_pred_indices=False): quant, diff, (_, _, ind) = self.encode(input) dec = self.decode(quant) - if return_pred_indices: - return dec, diff, ind - return dec, diff + return (dec, diff, ind) if return_pred_indices else (dec, diff) def get_input(self, batch, k): x = batch[k] @@ -226,10 +222,11 @@ def _validation_step(self, batch, batch_idx, suffix=''): 0, self.global_step, last_layer=self.get_last_layer(), - split='val' + suffix, + split=f'val{suffix}', predicted_indices=ind, ) + discloss, log_dict_disc = self.loss( qloss, x, @@ -237,9 +234,10 @@ def _validation_step(self, batch, batch_idx, suffix=''): 1, self.global_step, last_layer=self.get_last_layer(), - split='val' + suffix, + split=f'val{suffix}', predicted_indices=ind, ) + rec_loss = log_dict_ae[f'val{suffix}/rec_loss'] self.log( f'val{suffix}/rec_loss', @@ -310,7 +308,7 @@ def get_last_layer(self): return self.decoder.conv_out.weight def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.image_key) x = x.to(self.device) if only_inputs: @@ -360,8 +358,7 @@ def decode(self, h, force_not_quantize=False): else: quant = h quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec + return self.decoder(quant) class AutoencoderKL(pl.LightningModule): @@ -405,7 +402,7 @@ def init_from_ckpt(self, path, ignore_keys=list()): for k in keys: for ik in ignore_keys: if k.startswith(ik): - print('Deleting key {} from state_dict.'.format(k)) + print(f'Deleting key {k} from state_dict.') del sd[k] self.load_state_dict(sd, strict=False) print(f'Restored from {path}') @@ -413,20 +410,15 @@ def init_from_ckpt(self, path, ignore_keys=list()): def encode(self, x): h = self.encoder(x) moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior + return DiagonalGaussianDistribution(moments) def decode(self, z): z = self.post_quant_conv(z) - dec = self.decoder(z) - return dec + return self.decoder(z) def forward(self, input, sample_posterior=True): posterior = self.encode(input) - if sample_posterior: - z = posterior.sample() - else: - z = posterior.mode() + z = posterior.sample() if sample_posterior else posterior.mode() dec = self.decode(z) return dec, posterior @@ -550,7 +542,7 @@ def get_last_layer(self): @torch.no_grad() def log_images(self, batch, only_inputs=False, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.image_key) x = x.to(self.device) if not only_inputs: @@ -588,9 +580,7 @@ def decode(self, x, *args, **kwargs): return x def quantize(self, x, *args, **kwargs): - if self.vq_interface: - return x, None, [None, None, None] - return x + 
return (x, None, [None, None, None]) if self.vq_interface else x def forward(self, x, *args, **kwargs): return x diff --git a/ldm/models/diffusion/classifier.py b/ldm/models/diffusion/classifier.py index be0d8c19198..fc873b836d1 100644 --- a/ldm/models/diffusion/classifier.py +++ b/ldm/models/diffusion/classifier.py @@ -61,11 +61,12 @@ def __init__( self.log_steps = log_steps self.label_key = ( - label_key - if not hasattr(self.diffusion_model, 'cond_stage_key') - else self.diffusion_model.cond_stage_key + self.diffusion_model.cond_stage_key + if hasattr(self.diffusion_model, 'cond_stage_key') + else label_key ) + assert ( self.label_key is not None ), 'label_key neither in diffusion model nor in model.params' @@ -87,13 +88,14 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): for k in keys: for ik in ignore_keys: if k.startswith(ik): - print('Deleting key {} from state_dict.'.format(k)) + print(f'Deleting key {k} from state_dict.') del sd[k] missing, unexpected = ( - self.load_state_dict(sd, strict=False) - if not only_model - else self.model.load_state_dict(sd, strict=False) + self.model.load_state_dict(sd, strict=False) + if only_model + else self.load_state_dict(sd, strict=False) ) + print( f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys' ) @@ -172,13 +174,13 @@ def get_conditioning(self, batch, k=None): if self.label_key == 'segmentation': targets = rearrange(targets, 'b h w c -> b c h w') - for down in range(self.numd): + for _ in range(self.numd): h, w = targets.shape[-2:] targets = F.interpolate( targets, size=(h // 2, w // 2), mode='nearest' ) - # targets = rearrange(targets,'b c h w -> b h w c') + # targets = rearrange(targets,'b c h w -> b h w c') return targets @@ -198,8 +200,7 @@ def on_train_epoch_start(self): @torch.no_grad() def write_logs(self, loss, logits, targets): log_prefix = 'train' if self.training else 'val' - log = {} - log[f'{log_prefix}/loss'] = loss.mean() + log = {f'{log_prefix}/loss': loss.mean()} log[f'{log_prefix}/acc@1'] = self.compute_top_k( logits, targets, k=1, reduction='mean' ) @@ -320,10 +321,8 @@ def configure_optimizers(self): @torch.no_grad() def log_images(self, batch, N=8, *args, **kwargs): - log = dict() x = self.get_input(batch, self.diffusion_model.first_stage_key) - log['inputs'] = x - + log = {'inputs': x} y = self.get_conditioning(batch) if self.label_key == 'class_label': diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index 3f103da767b..251cad9913b 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -287,13 +287,14 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): for k in keys: for ik in ignore_keys: if k.startswith(ik): - print('Deleting key {} from state_dict.'.format(k)) + print(f'Deleting key {k} from state_dict.') del sd[k] missing, unexpected = ( - self.load_state_dict(sd, strict=False) - if not only_model - else self.model.load_state_dict(sd, strict=False) + self.model.load_state_dict(sd, strict=False) + if only_model + else self.load_state_dict(sd, strict=False) ) + print( f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys' ) @@ -388,12 +389,7 @@ def p_sample_loop(self, shape, return_intermediates=False): b = shape[0] img = torch.randn(shape, device=device) intermediates = [img] - for i in tqdm( - reversed(range(0, self.num_timesteps)), - desc='Sampling t', - total=self.num_timesteps, - dynamic_ncols=True, - ): + for i in 
tqdm(reversed(range(self.num_timesteps)), desc='Sampling t', total=self.num_timesteps, dynamic_ncols=True): img = self.p_sample( img, torch.full((b,), i, device=device, dtype=torch.long), @@ -401,9 +397,7 @@ def p_sample_loop(self, shape, return_intermediates=False): ) if i % self.log_every_t == 0 or i == self.num_timesteps - 1: intermediates.append(img) - if return_intermediates: - return img, intermediates - return img + return (img, intermediates) if return_intermediates else img @torch.no_grad() def sample(self, batch_size=16, return_intermediates=False): @@ -447,7 +441,6 @@ def p_losses(self, x_start, t, noise=None): x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) model_out = self.model(x_noisy, t) - loss_dict = {} if self.parameterization == 'eps': target = noise elif self.parameterization == 'x0': @@ -461,15 +454,15 @@ def p_losses(self, x_start, t, noise=None): log_prefix = 'train' if self.training else 'val' - loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_dict = {f'{log_prefix}/loss_simple': loss.mean()} loss_simple = loss.mean() * self.l_simple_weight loss_vlb = (self.lvlb_weights[t] * loss).mean() - loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + loss_dict[f'{log_prefix}/loss_vlb'] = loss_vlb loss = loss_simple + self.original_elbo_weight * loss_vlb - loss_dict.update({f'{log_prefix}/loss': loss}) + loss_dict[f'{log_prefix}/loss'] = loss return loss, loss_dict @@ -528,9 +521,7 @@ def validation_step(self, batch, batch_idx): _, loss_dict_no_ema = self.shared_step(batch) with self.ema_scope(): _, loss_dict_ema = self.shared_step(batch) - loss_dict_ema = { - key + '_ema': loss_dict_ema[key] for key in loss_dict_ema - } + loss_dict_ema = {f'{key}_ema': loss_dict_ema[key] for key in loss_dict_ema} self.log_dict( loss_dict_no_ema, prog_bar=False, @@ -561,15 +552,13 @@ def _get_rows_from_list(self, samples): def log_images( self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs ): - log = dict() x = self.get_input(batch, self.first_stage_key) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) x = x.to(self.device)[:N] - log['inputs'] = x - + log = {'inputs': x} # get diffusion row - diffusion_row = list() + diffusion_row = [] x_start = x[:n_row] for t in range(self.num_timesteps): @@ -603,9 +592,8 @@ def configure_optimizers(self): lr = self.learning_rate params = list(self.model.parameters()) if self.learn_logvar: - params = params + [self.logvar] - opt = torch.optim.AdamW(params, lr=lr) - return opt + params += [self.logvar] + return torch.optim.AdamW(params, lr=lr) class LatentDiffusion(DDPM): @@ -754,23 +742,7 @@ def instantiate_first_stage(self, config): param.requires_grad = False def instantiate_cond_stage(self, config): - if not self.cond_stage_trainable: - if config == '__is_first_stage__': - print('Using first stage also as cond stage.') - self.cond_stage_model = self.first_stage_model - elif config == '__is_unconditional__': - print( - f'Training {self.__class__.__name__} as an unconditional model.' 
- ) - self.cond_stage_model = None - # self.be_unconditional = True - else: - model = instantiate_from_config(config) - self.cond_stage_model = model.eval() - self.cond_stage_model.train = disabled_train - for param in self.cond_stage_model.parameters(): - param.requires_grad = False - else: + if self.cond_stage_trainable: assert config != '__is_first_stage__' assert config != '__is_unconditional__' try: @@ -781,6 +753,22 @@ def instantiate_cond_stage(self, config): ) self.cond_stage_model = model + elif config == '__is_first_stage__': + print('Using first stage also as cond stage.') + self.cond_stage_model = self.first_stage_model + elif config == '__is_unconditional__': + print( + f'Training {self.__class__.__name__} as an unconditional model.' + ) + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + def instantiate_embedding_manager(self, config, embedder): model = instantiate_from_config(config, embedder=embedder) @@ -794,14 +782,14 @@ def instantiate_embedding_manager(self, config, embedder): def _get_denoise_row_from_list( self, samples, desc='', force_no_decoder_quantization=False ): - denoise_row = [] - for zd in tqdm(samples, desc=desc): - denoise_row.append( - self.decode_first_stage( - zd.to(self.device), - force_not_quantize=force_no_decoder_quantization, - ) + denoise_row = [ + self.decode_first_stage( + zd.to(self.device), + force_not_quantize=force_no_decoder_quantization, ) + for zd in tqdm(samples, desc=desc) + ] + n_imgs_per_row = len(denoise_row) denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') @@ -841,8 +829,7 @@ def meshgrid(self, h, w): y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) - arr = torch.cat([y, x], dim=-1) - return arr + return torch.cat([y, x], dim=-1) def delta_border(self, h, w): """ @@ -855,10 +842,9 @@ def delta_border(self, h, w): arr = self.meshgrid(h, w) / lower_right_corner dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] - edge_dist = torch.min( + return torch.min( torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1 )[0] - return edge_dist def get_weighting(self, h, w, Ly, Lx, device): weighting = self.delta_border(h, w) @@ -993,17 +979,16 @@ def get_input( if self.model.conditioning_key is not None: if cond_key is None: cond_key = self.cond_stage_key - if cond_key != self.first_stage_key: - if cond_key in ['caption', 'coordinates_bbox']: - xc = batch[cond_key] - elif cond_key == 'class_label': - xc = batch - else: - xc = super().get_input(batch, cond_key).to(self.device) - else: + if cond_key == self.first_stage_key: xc = x + elif cond_key in ['caption', 'coordinates_bbox']: + xc = batch[cond_key] + elif cond_key == 'class_label': + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) if not self.cond_stage_trainable or force_c_encode: - if isinstance(xc, dict) or isinstance(xc, list): + if isinstance(xc, (dict, list)): # import pudb; pudb.set_trace() c = self.get_learned_conditioning(xc) else: @@ -1046,75 +1031,76 @@ def decode_first_stage( z = 1.0 / self.scale_factor * z - if hasattr(self, 'split_input_params'): - if self.split_input_params['patch_distributed_vq']: - ks = 
self.split_input_params['ks'] # eg. (128, 128) - stride = self.split_input_params['stride'] # eg. (64, 64) - uf = self.split_input_params['vqf'] - bs, nc, h, w = z.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print('reducing Kernel') - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print('reducing stride') - - fold, unfold, normalization, weighting = self.get_fold_unfold( - z, ks, stride, uf=uf - ) + if ( + hasattr(self, 'split_input_params') + and not self.split_input_params['patch_distributed_vq'] + and isinstance(self.first_stage_model, VQModelInterface) + or not hasattr(self, 'split_input_params') + and isinstance(self.first_stage_model, VQModelInterface) + ): + return self.first_stage_model.decode( + z, + force_not_quantize=predict_cids or force_not_quantize, + ) + elif ( + hasattr(self, 'split_input_params') + and not self.split_input_params['patch_distributed_vq'] + or not hasattr(self, 'split_input_params') + ): + return self.first_stage_model.decode(z) - z = unfold(z) # (bn, nc * prod(**ks), L) - # 1. Reshape to img shape - z = z.view( - (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) - ) # (bn, nc, ks[0], ks[1], L ) + else: + ks = self.split_input_params['ks'] # eg. (128, 128) + stride = self.split_input_params['stride'] # eg. (64, 64) + uf = self.split_input_params['vqf'] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print('reducing Kernel') - # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): - output_list = [ - self.first_stage_model.decode( - z[:, :, :, :, i], - force_not_quantize=predict_cids - or force_not_quantize, - ) - for i in range(z.shape[-1]) - ] - else: + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print('reducing stride') - output_list = [ - self.first_stage_model.decode(z[:, :, :, :, i]) - for i in range(z.shape[-1]) - ] + fold, unfold, normalization, weighting = self.get_fold_unfold( + z, ks, stride, uf=uf + ) - o = torch.stack( - output_list, axis=-1 - ) # # (bn, nc, ks[0], ks[1], L) - o = o * weighting - # Reverse 1. reshape to img shape - o = o.view( - (o.shape[0], -1, o.shape[-1]) - ) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization # norm is shape (1, 1, h, w) - return decoded - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode( - z, - force_not_quantize=predict_cids or force_not_quantize, - ) - else: - return self.first_stage_model.decode(z) + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view( + (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) + ) # (bn, nc, ks[0], ks[1], L ) - else: + # 2. apply model loop over last dim if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode( - z, force_not_quantize=predict_cids or force_not_quantize - ) + output_list = [ + self.first_stage_model.decode( + z[:, :, :, :, i], + force_not_quantize=predict_cids + or force_not_quantize, + ) + for i in range(z.shape[-1]) + ] else: - return self.first_stage_model.decode(z) + + output_list = [ + self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1]) + ] + + o = torch.stack( + output_list, axis=-1 + ) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. 
reshape to img shape + o = o.view( + (o.shape[0], -1, o.shape[-1]) + ) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded # same as above but without decorator def differentiable_decode_first_stage( @@ -1130,128 +1116,126 @@ def differentiable_decode_first_stage( z = 1.0 / self.scale_factor * z - if hasattr(self, 'split_input_params'): - if self.split_input_params['patch_distributed_vq']: - ks = self.split_input_params['ks'] # eg. (128, 128) - stride = self.split_input_params['stride'] # eg. (64, 64) - uf = self.split_input_params['vqf'] - bs, nc, h, w = z.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print('reducing Kernel') - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print('reducing stride') - - fold, unfold, normalization, weighting = self.get_fold_unfold( - z, ks, stride, uf=uf - ) + if ( + hasattr(self, 'split_input_params') + and not self.split_input_params['patch_distributed_vq'] + and isinstance(self.first_stage_model, VQModelInterface) + or not hasattr(self, 'split_input_params') + and isinstance(self.first_stage_model, VQModelInterface) + ): + return self.first_stage_model.decode( + z, + force_not_quantize=predict_cids or force_not_quantize, + ) + elif ( + hasattr(self, 'split_input_params') + and not self.split_input_params['patch_distributed_vq'] + or not hasattr(self, 'split_input_params') + ): + return self.first_stage_model.decode(z) - z = unfold(z) # (bn, nc * prod(**ks), L) - # 1. Reshape to img shape - z = z.view( - (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) - ) # (bn, nc, ks[0], ks[1], L ) + else: + ks = self.split_input_params['ks'] # eg. (128, 128) + stride = self.split_input_params['stride'] # eg. (64, 64) + uf = self.split_input_params['vqf'] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print('reducing Kernel') - # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): - output_list = [ - self.first_stage_model.decode( - z[:, :, :, :, i], - force_not_quantize=predict_cids - or force_not_quantize, - ) - for i in range(z.shape[-1]) - ] - else: + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print('reducing stride') - output_list = [ - self.first_stage_model.decode(z[:, :, :, :, i]) - for i in range(z.shape[-1]) - ] + fold, unfold, normalization, weighting = self.get_fold_unfold( + z, ks, stride, uf=uf + ) - o = torch.stack( - output_list, axis=-1 - ) # # (bn, nc, ks[0], ks[1], L) - o = o * weighting - # Reverse 1. reshape to img shape - o = o.view( - (o.shape[0], -1, o.shape[-1]) - ) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization # norm is shape (1, 1, h, w) - return decoded - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode( - z, - force_not_quantize=predict_cids or force_not_quantize, - ) - else: - return self.first_stage_model.decode(z) + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view( + (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) + ) # (bn, nc, ks[0], ks[1], L ) - else: + # 2. 
apply model loop over last dim if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode( - z, force_not_quantize=predict_cids or force_not_quantize - ) + output_list = [ + self.first_stage_model.decode( + z[:, :, :, :, i], + force_not_quantize=predict_cids + or force_not_quantize, + ) + for i in range(z.shape[-1]) + ] else: - return self.first_stage_model.decode(z) - - @torch.no_grad() - def encode_first_stage(self, x): - if hasattr(self, 'split_input_params'): - if self.split_input_params['patch_distributed_vq']: - ks = self.split_input_params['ks'] # eg. (128, 128) - stride = self.split_input_params['stride'] # eg. (64, 64) - df = self.split_input_params['vqf'] - self.split_input_params['original_image_size'] = x.shape[-2:] - bs, nc, h, w = x.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print('reducing Kernel') - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print('reducing stride') - - fold, unfold, normalization, weighting = self.get_fold_unfold( - x, ks, stride, df=df - ) - z = unfold(x) # (bn, nc * prod(**ks), L) - # Reshape to img shape - z = z.view( - (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) - ) # (bn, nc, ks[0], ks[1], L ) output_list = [ - self.first_stage_model.encode(z[:, :, :, :, i]) + self.first_stage_model.decode(z[:, :, :, :, i]) for i in range(z.shape[-1]) ] - o = torch.stack(output_list, axis=-1) - o = o * weighting - - # Reverse reshape to img shape - o = o.view( - (o.shape[0], -1, o.shape[-1]) - ) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization - return decoded + o = torch.stack( + output_list, axis=-1 + ) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view( + (o.shape[0], -1, o.shape[-1]) + ) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded - else: - return self.first_stage_model.encode(x) - else: + @torch.no_grad() + def encode_first_stage(self, x): + if ( + not hasattr(self, 'split_input_params') + or not self.split_input_params['patch_distributed_vq'] + ): return self.first_stage_model.encode(x) + ks = self.split_input_params['ks'] # eg. (128, 128) + stride = self.split_input_params['stride'] # eg. 
(64, 64) + df = self.split_input_params['vqf'] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print('reducing Kernel') + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print('reducing stride') + + fold, unfold, normalization, weighting = self.get_fold_unfold( + x, ks, stride, df=df + ) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view( + (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) + ) # (bn, nc, ks[0], ks[1], L ) + + output_list = [ + self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1]) + ] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view( + (o.shape[0], -1, o.shape[-1]) + ) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded def shared_step(self, batch, **kwargs): x, c = self.get_input(batch, self.first_stage_key) - loss = self(x, c) - return loss + return self(x, c) def forward(self, x, c, *args, **kwargs): t = torch.randint( @@ -1283,10 +1267,7 @@ def rescale_bbox(bbox): def apply_model(self, x_noisy, t, cond, return_ids=False): - if isinstance(cond, dict): - # hybrid case, cond is exptected to be a dict - pass - else: + if not isinstance(cond, dict): if not isinstance(cond, list): cond = [cond] key = ( @@ -1439,10 +1420,7 @@ def apply_model(self, x_noisy, t, cond, return_ids=False): else: x_recon = self.model(x_noisy, t, **cond) - if isinstance(x_recon, tuple) and not return_ids: - return x_recon[0] - else: - return x_recon + return x_recon[0] if isinstance(x_recon, tuple) and not return_ids else x_recon def _predict_eps_from_xstart(self, x_t, t, pred_xstart): return ( @@ -1474,7 +1452,6 @@ def p_losses(self, x_start, cond, t, noise=None): x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) model_output = self.apply_model(x_noisy, t, cond) - loss_dict = {} prefix = 'train' if self.training else 'val' if self.parameterization == 'x0': @@ -1487,14 +1464,13 @@ def p_losses(self, x_start, cond, t, noise=None): loss_simple = self.get_loss(model_output, target, mean=False).mean( [1, 2, 3] ) - loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) - + loss_dict = {f'{prefix}/loss_simple': loss_simple.mean()} logvar_t = self.logvar[t].to(self.device) loss = loss_simple / torch.exp(logvar_t) + logvar_t # loss = loss_simple / torch.exp(self.logvar) + self.logvar if self.learn_logvar: - loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) - loss_dict.update({'logvar': self.logvar.data.mean()}) + loss_dict[f'{prefix}/loss_gamma'] = loss.mean() + loss_dict['logvar'] = self.logvar.data.mean() loss = self.l_simple_weight * loss.mean() @@ -1502,19 +1478,19 @@ def p_losses(self, x_start, cond, t, noise=None): dim=(1, 2, 3) ) loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() - loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss_dict[f'{prefix}/loss_vlb'] = loss_vlb loss += self.original_elbo_weight * loss_vlb - loss_dict.update({f'{prefix}/loss': loss}) + loss_dict[f'{prefix}/loss'] = loss if self.embedding_reg_weight > 0: loss_embedding_reg = ( self.embedding_manager.embedding_to_coarse_loss().mean() ) - loss_dict.update({f'{prefix}/loss_emb_reg': loss_embedding_reg}) + loss_dict[f'{prefix}/loss_emb_reg'] = loss_embedding_reg loss += self.embedding_reg_weight * loss_embedding_reg - loss_dict.update({f'{prefix}/loss': loss}) + 
loss_dict[f'{prefix}/loss'] = loss return loss, loss_dict @@ -1667,19 +1643,17 @@ def progressive_denoising( shape = [batch_size] + list(shape) else: b = batch_size = shape[0] - if x_T is None: - img = torch.randn(shape, device=self.device) - else: - img = x_T + img = torch.randn(shape, device=self.device) if x_T is None else x_T intermediates = [] if cond is not None: if isinstance(cond, dict): cond = { - key: cond[key][:batch_size] - if not isinstance(cond[key], list) - else list(map(lambda x: x[:batch_size], cond[key])) + key: list(map(lambda x: x[:batch_size], cond[key])) + if isinstance(cond[key], list) + else cond[key][:batch_size] for key in cond } + else: cond = ( [c[:batch_size] for c in cond] @@ -1691,13 +1665,14 @@ def progressive_denoising( timesteps = min(timesteps, start_T) iterator = ( tqdm( - reversed(range(0, timesteps)), + reversed(range(timesteps)), desc='Progressive Generation', total=timesteps, ) if verbose - else reversed(range(0, timesteps)) + else reversed(range(timesteps)) ) + if type(temperature) == float: temperature = [temperature] * timesteps @@ -1757,11 +1732,7 @@ def p_sample_loop( log_every_t = self.log_every_t device = self.betas.device b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - + img = torch.randn(shape, device=device) if x_T is None else x_T intermediates = [img] if timesteps is None: timesteps = self.num_timesteps @@ -1769,15 +1740,12 @@ def p_sample_loop( if start_T is not None: timesteps = min(timesteps, start_T) iterator = ( - tqdm( - reversed(range(0, timesteps)), - desc='Sampling t', - total=timesteps, - ) + tqdm(reversed(range(timesteps)), desc='Sampling t', total=timesteps) if verbose - else reversed(range(0, timesteps)) + else reversed(range(timesteps)) ) + if mask is not None: assert x0 is not None assert ( @@ -1811,9 +1779,7 @@ def p_sample_loop( if img_callback: img_callback(img, i) - if return_intermediates: - return img, intermediates - return img + return (img, intermediates) if return_intermediates else img @torch.no_grad() def sample( @@ -1840,11 +1806,12 @@ def sample( if cond is not None: if isinstance(cond, dict): cond = { - key: cond[key][:batch_size] - if not isinstance(cond[key], list) - else list(map(lambda x: x[:batch_size], cond[key])) + key: list(map(lambda x: x[:batch_size], cond[key])) + if isinstance(cond[key], list) + else cond[key][:batch_size] for key in cond } + else: cond = ( [c[:batch_size] for c in cond] diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 0bc6ccd2968..55accde32c8 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -261,12 +261,9 @@ def p_sample( # are at an intermediate step in img2img. See similar in # sample() which does work. 
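# Illustrative sketch (assumption-laden, not taken from the patch): the
# get_initial_image hunk just below folds the `if x_T is None` branch into a
# single conditional expression over sigma-scaled noise. A plain `sigma0` float
# stands in for self.sigmas[0], and the helper name `initial_latent` is
# hypothetical.
import torch

def initial_latent(x_T, shape, sigma0, device='cpu'):
    # Draw sigma-scaled noise; add the caller-supplied latent when one is given.
    x = torch.randn(shape, device=device) * sigma0
    return x_T + x if x_T is not None else x

print(initial_latent(None, (1, 4, 8, 8), 14.6).shape)                     # txt2img-style: pure noise
print(initial_latent(torch.zeros(1, 4, 8, 8), (1, 4, 8, 8), 14.6).shape)  # img2img-style: noise on top of a latent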
def get_initial_image(self,x_T,shape,steps): - print(f'WARNING: ksampler.get_initial_image(): get_initial_image needs testing') + print('WARNING: ksampler.get_initial_image(): get_initial_image needs testing') x = (torch.randn(shape, device=self.device) * self.sigmas[0]) - if x_T is not None: - return x_T + x - else: - return x + return x_T + x if x_T is not None else x def prepare_to_sample(self,t_enc): self.t_enc = t_enc diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index 88cdc019740..9948de616c8 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -26,9 +26,8 @@ def __init__(self, model, schedule='linear', steps=None, device=None, **kwargs): self.device = device or choose_torch_device() def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != torch.device(self.device): - attr = attr.to(torch.float32).to(torch.device(self.device)) + if type(attr) == torch.Tensor and attr.device != torch.device(self.device): + attr = attr.to(torch.float32).to(torch.device(self.device)) setattr(self, name, attr) # This method was copied over from ddim.py and probably does stuff that is @@ -217,11 +216,12 @@ def do_sampling( ): b = shape[0] time_range = ( - list(reversed(range(0, timesteps))) + list(reversed(range(timesteps))) if ddim_use_original_steps else np.flip(timesteps) ) + total_steps=steps iterator = tqdm( @@ -362,10 +362,7 @@ def decode( return x_dec def get_initial_image(self,x_T,shape,timesteps=None): - if x_T is None: - return torch.randn(shape, device=self.device) - else: - return x_T + return torch.randn(shape, device=self.device) if x_T is None else x_T def p_sample( self, diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index ef9c2d3e653..4b75863318d 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -50,10 +50,12 @@ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) + project_in = ( + GEGLU(dim, inner_dim) + if glu + else nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + ) + self.net = nn.Sequential( project_in, @@ -190,11 +192,10 @@ def einsum_op_slice_1(self, q, k, v, slice_size): return r def einsum_op_mps_v1(self, q, k, v): - if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 + if q.shape[1] <= 4096: return self.einsum_op_compvis(q, k, v) - else: - slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) - return self.einsum_op_slice_1(q, k, v, slice_size) + slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) + return self.einsum_op_slice_1(q, k, v, slice_size) def einsum_op_mps_v2(self, q, k, v): if self.mem_total_gb > 8 and q.shape[1] <= 4096: @@ -293,10 +294,19 @@ def __init__(self, in_channels, n_heads, d_head, padding=0) self.transformer_blocks = nn.ModuleList( - [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) - for d in range(depth)] + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim, + ) + for _ in range(depth) + ] ) + self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 78876a0919c..f8e4a9a5bb7 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -130,11 
+130,7 @@ def forward(self, x, temb): h = self.conv2(h) if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - + x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x) return x + h class LinAttnBlock(LinearAttention): @@ -202,19 +198,19 @@ def forward(self, x): tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4 mem_required = tensor_size * 2.5 - steps = 1 + steps = ( + 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) + if mem_required > mem_free_total + else 1 + ) - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + elif psutil.virtual_memory().available / (1024**3) < 12: + slice_size = 1 else: - if psutil.virtual_memory().available / (1024**3) < 12: - slice_size = 1 - else: - slice_size = min(q.shape[1], math.floor(2**30 / (q.shape[0] * q.shape[1]))) - + slice_size = min(q.shape[1], math.floor(2**30 / (q.shape[0] * q.shape[1]))) + for i in range(0, q.shape[1], slice_size): end = i + slice_size @@ -293,7 +289,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn = nn.ModuleList() block_in = ch*in_ch_mult[i_level] block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks): + for _ in range(self.num_res_blocks): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, @@ -436,7 +432,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn = nn.ModuleList() block_in = ch*in_ch_mult[i_level] block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks): + for _ in range(self.num_res_blocks): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, @@ -521,8 +517,10 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, block_in = ch*ch_mult[self.num_resolutions-1] curr_res = resolution // 2**(self.num_resolutions-1) self.z_shape = (1,z_channels,curr_res,curr_res) - print("Working with z of shape {} = {} dimensions.".format( - self.z_shape, np.prod(self.z_shape))) + print( + f"Working with z of shape {self.z_shape} = {np.prod(self.z_shape)} dimensions." 
+ ) + # z to block_in self.conv_in = torch.nn.Conv2d(z_channels, @@ -549,7 +547,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, block = nn.ModuleList() attn = nn.ModuleList() block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): + for _ in range(self.num_res_blocks+1): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, @@ -639,11 +637,7 @@ def __init__(self, in_channels, out_channels, *args, **kwargs): def forward(self, x): for i, layer in enumerate(self.model): - if i in [1,2,3]: - x = layer(x, None) - else: - x = layer(x) - + x = layer(x, None) if i in [1,2,3] else layer(x) h = self.norm_out(x) h = silu(h) x = self.conv_out(h) @@ -665,7 +659,7 @@ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, for i_level in range(self.num_resolutions): res_block = [] block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): + for _ in range(self.num_res_blocks + 1): res_block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, @@ -798,13 +792,6 @@ def __init__(self, in_channels=None, learned=False, mode="bilinear"): if self.with_conv: print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") raise NotImplementedError() - assert in_channels is not None - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=4, - stride=2, - padding=1) def forward(self, x, scale_factor=1.0): if scale_factor==1.0: diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py index d6baa76a1c1..7d0d6f6d90b 100644 --- a/ldm/modules/diffusionmodules/openaimodel.py +++ b/ldm/modules/diffusionmodules/openaimodel.py @@ -590,22 +590,23 @@ def __init__( else num_head_channels ) layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( + SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, ) + if use_spatial_transformer + else AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) @@ -655,20 +656,20 @@ def __init__( use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( + SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + ) + if use_spatial_transformer + else AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, ), ResBlock( ch, @@ -679,6 +680,7 @@ def __init__( use_scale_shift_norm=use_scale_shift_norm, ), ) + self._feature_size += ch self.output_blocks = nn.ModuleList([]) @@ -711,22 +713,23 @@ def __init__( else num_head_channels ) layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads_upsample, - 
num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( + SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, ) + if use_spatial_transformer + else AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) ) + if level and i == num_res_blocks: out_ch = ch layers.append( @@ -810,10 +813,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): h = th.cat([h, hs.pop()], dim=1) h = module(h, emb, context) h = h.type(x.dtype) - if self.predict_codebook_ids: - return self.id_predictor(h) - else: - return self.out(h) + return self.id_predictor(h) if self.predict_codebook_ids else self.out(h) class EncoderUNetModel(nn.Module): @@ -1030,7 +1030,7 @@ def forward(self, x, timesteps): if self.pool.startswith('spatial'): results.append(h.type(x.dtype).mean(dim=(2, 3))) h = th.cat(results, axis=-1) - return self.out(h) else: h = h.type(x.dtype) - return self.out(h) + + return self.out(h) diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py index 60b4d8a0280..eacb14b49a9 100644 --- a/ldm/modules/diffusionmodules/util.py +++ b/ldm/modules/diffusionmodules/util.py @@ -148,13 +148,7 @@ def checkpoint(func, inputs, params, flag): explicitly take as arguments. :param flag: if False, disable gradient checkpointing. """ - if ( - False - ): # disabled checkpointing to allow requires_grad = False for main model - args = tuple(inputs) + tuple(params) - return CheckpointFunction.apply(func, len(inputs), *args) - else: - return func(*inputs) + return func(*inputs) class CheckpointFunction(torch.autograd.Function): diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py index 67ed535791f..a8500f00a79 100644 --- a/ldm/modules/distributions/distributions.py +++ b/ldm/modules/distributions/distributions.py @@ -35,29 +35,27 @@ def __init__(self, parameters, deterministic=False): ) def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape).to( + return self.mean + self.std * torch.randn(self.mean.shape).to( device=self.parameters.device ) - return x def kl(self, other=None): if self.deterministic: return torch.Tensor([0.0]) + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) else: - if other is None: - return 0.5 * torch.sum( - torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, - dim=[1, 2, 3], - ) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, - dim=[1, 2, 3], - ) + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) def nll(self, sample, dims=[1, 2, 3]): if self.deterministic: @@ -81,11 +79,15 @@ def normal_kl(mean1, logvar1, mean2, logvar2): Shapes are automatically broadcasted, so batches can be compared to scalars, among other use cases. 
""" - tensor = None - for obj in (mean1, logvar1, mean2, logvar2): - if isinstance(obj, torch.Tensor): - tensor = obj - break + tensor = next( + ( + obj + for obj in (mean1, logvar1, mean2, logvar2) + if isinstance(obj, torch.Tensor) + ), + None, + ) + assert tensor is not None, 'at least one argument must be a Tensor' # Force variances to be Tensors. Broadcasting helps convert scalars to diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py index 2ceec5f0e79..71d6c35b7cb 100644 --- a/ldm/modules/ema.py +++ b/ldm/modules/ema.py @@ -51,7 +51,7 @@ def forward(self, model): one_minus_decay * (shadow_params[sname] - m_param[key]) ) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def copy_to(self, model): m_param = dict(model.named_parameters()) @@ -62,7 +62,7 @@ def copy_to(self, model): shadow_params[self.m_name2s_name[key]].data ) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def store(self, parameters): """ diff --git a/ldm/modules/embedding_manager.py b/ldm/modules/embedding_manager.py index 18688708f9e..8b83694fd5d 100644 --- a/ldm/modules/embedding_manager.py +++ b/ldm/modules/embedding_manager.py @@ -251,11 +251,7 @@ def get_embedding_norms_squared(self): all_params = torch.cat( list(self.string_to_param_dict.values()), axis=0 ) # num_placeholders x embedding_dim - param_norm_squared = (all_params * all_params).sum( - axis=-1 - ) # num_placeholders - - return param_norm_squared + return (all_params * all_params).sum(axis=-1) def embedding_parameters(self): return self.string_to_param_dict.parameters() diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 426fccced31..26833aca1f4 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -85,8 +85,7 @@ def __init__( def forward(self, tokens): tokens = tokens.to(self.device) # meh - z = self.transformer(tokens, return_embeddings=True) - return z + return self.transformer(tokens, return_embeddings=True) def encode(self, x): return self(x) @@ -130,15 +129,12 @@ def forward(self, text): padding='max_length', return_tensors='pt', ) - tokens = batch_encoding['input_ids'].to(self.device) - return tokens + return batch_encoding['input_ids'].to(self.device) @torch.no_grad() def encode(self, text): tokens = self(text) - if not self.vq_interface: - return tokens - return None, None, [None, None, tokens] + return (None, None, [None, None, tokens]) if self.vq_interface else tokens def decode(self, text): return text @@ -172,14 +168,10 @@ def __init__( ) def forward(self, text, embedding_manager=None): - if self.use_tknz_fn: - tokens = self.tknz_fn(text) # .to(self.device) - else: - tokens = text - z = self.transformer( + tokens = self.tknz_fn(text) if self.use_tknz_fn else text + return self.transformer( tokens, return_embeddings=True, embedding_manager=embedding_manager ) - return z def encode(self, text, **kwargs): # output of length 77 @@ -221,7 +213,7 @@ def __init__( ) def forward(self, x): - for stage in range(self.n_stages): + for _ in range(self.n_stages): x = self.interpolator(x, scale_factor=self.multiplier) if self.remap_output: @@ -447,9 +439,7 @@ def forward(self, text, **kwargs): return_tensors='pt', ) tokens = batch_encoding['input_ids'].to(self.device) - z = self.transformer(input_ids=tokens, **kwargs) - - return z + return self.transformer(input_ids=tokens, **kwargs) def encode(self, text, **kwargs): return self(text, **kwargs) diff --git a/ldm/modules/image_degradation/bsrgan.py 
b/ldm/modules/image_degradation/bsrgan.py index b51217bd48e..814567318fa 100644 --- a/ldm/modules/image_degradation/bsrgan.py +++ b/ldm/modules/image_degradation/bsrgan.py @@ -85,9 +85,7 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): V = np.array([[v[0], v[1]], [v[1], -v[0]]]) D = np.array([[l1, 0], [0, l2]]) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) - k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) - - return k + return gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) def gm_blur_kernel(mean, cov, size=15): @@ -193,13 +191,7 @@ def gen_kernel( ZZ_t = ZZ.transpose(0, 1, 3, 2) raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) - # shift the kernel so it will be centered - # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) - - # Normalize the kernel and return - # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) - kernel = raw_kernel / np.sum(raw_kernel) - return kernel + return raw_kernel / np.sum(raw_kernel) def fspecial_gaussian(hsize, sigma): @@ -539,10 +531,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): for i in shuffle_order: - if i == 0: - img = add_blur(img, sf=sf) - - elif i == 1: + if i in [0, 1]: img = add_blur(img, sf=sf) elif i == 2: @@ -645,10 +634,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): for i in shuffle_order: - if i == 0: - image = add_blur(image, sf=sf) - - elif i == 1: + if i in [0, 1]: image = add_blur(image, sf=sf) elif i == 2: @@ -694,17 +680,16 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): if random.random() < jpeg_prob: image = add_JPEG_noise(image) - # elif i == 6: - # # add processed camera sensor noise - # if random.random() < isp_prob and isp_model is not None: - # with torch.no_grad(): - # img, hq = isp_model.forward(img.copy(), hq) + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) # add final JPEG compression noise image = add_JPEG_noise(image) image = util.single2uint(image) - example = {'image': image} - return example + return {'image': image} # TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... 
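# Illustrative sketch (hypothetical names, not taken from the patch): the
# shuffle_order hunks above and in the next hunk fold duplicate
# `if i == 0:` / `elif i == 1:` arms with identical bodies into a single
# membership test, `if i in [0, 1]:`. Toy standalone version of that pattern:
import random

def degrade(img):
    # Positions 0 and 1 apply the same op, so they share one branch.
    for i in random.sample(range(3), 3):  # a shuffled visit order, like shuffle_order
        if i in [0, 1]:
            img = f'blur({img})'
        elif i == 2:
            img = f'jpeg({img})'
    return img

print(degrade('img'))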
@@ -756,40 +741,24 @@ def degradation_bsrgan_plus( poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 for i in shuffle_order: - if i == 0: + if i in [0, 7]: img = add_blur(img, sf=sf) - elif i == 1: + elif i in [1, 8]: img = add_resize(img, sf=sf) - elif i == 2: + elif i in [2, 9]: img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) - elif i == 3: + elif i in [3, 10]: if random.random() < poisson_prob: img = add_Poisson_noise(img) - elif i == 4: + elif i in [4, 11]: if random.random() < speckle_prob: img = add_speckle_noise(img) - elif i == 5: + elif i in [5, 12]: if random.random() < isp_prob and isp_model is not None: with torch.no_grad(): img, hq = isp_model.forward(img.copy(), hq) elif i == 6: img = add_JPEG_noise(img) - elif i == 7: - img = add_blur(img, sf=sf) - elif i == 8: - img = add_resize(img, sf=sf) - elif i == 9: - img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) - elif i == 10: - if random.random() < poisson_prob: - img = add_Poisson_noise(img) - elif i == 11: - if random.random() < speckle_prob: - img = add_speckle_noise(img) - elif i == 12: - if random.random() < isp_prob and isp_model is not None: - with torch.no_grad(): - img, hq = isp_model.forward(img.copy(), hq) else: print('check the shuffle!') @@ -843,4 +812,4 @@ def degradation_bsrgan_plus( img_concat = np.concatenate( [lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1 ) - util.imsave(img_concat, str(i) + '.png') + util.imsave(img_concat, f'{str(i)}.png') diff --git a/ldm/modules/image_degradation/bsrgan_light.py b/ldm/modules/image_degradation/bsrgan_light.py index 3500ef7316f..8bb37ae58b9 100644 --- a/ldm/modules/image_degradation/bsrgan_light.py +++ b/ldm/modules/image_degradation/bsrgan_light.py @@ -85,9 +85,7 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): V = np.array([[v[0], v[1]], [v[1], -v[0]]]) D = np.array([[l1, 0], [0, l2]]) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) - k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) - - return k + return gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) def gm_blur_kernel(mean, cov, size=15): @@ -193,13 +191,7 @@ def gen_kernel( ZZ_t = ZZ.transpose(0, 1, 3, 2) raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) - # shift the kernel so it will be centered - # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) - - # Normalize the kernel and return - # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) - kernel = raw_kernel / np.sum(raw_kernel) - return kernel + return raw_kernel / np.sum(raw_kernel) def fspecial_gaussian(hsize, sigma): @@ -543,10 +535,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): for i in shuffle_order: - if i == 0: - img = add_blur(img, sf=sf) - - elif i == 1: + if i in [0, 1]: img = add_blur(img, sf=sf) elif i == 2: @@ -711,8 +700,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): # add final JPEG compression noise image = add_JPEG_noise(image) image = util.single2uint(image) - example = {'image': image} - return example + return {'image': image} if __name__ == '__main__': @@ -748,4 +736,4 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): img_concat = np.concatenate( [lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1 ) - util.imsave(img_concat, str(i) + '.png') + util.imsave(img_concat, f'{str(i)}.png') diff --git a/ldm/modules/image_degradation/utils_image.py b/ldm/modules/image_degradation/utils_image.py index 4b6e64658af..7dad0951cd5 100644 --- 
a/ldm/modules/image_degradation/utils_image.py +++ b/ldm/modules/image_degradation/utils_image.py @@ -78,10 +78,11 @@ def surf(Z, cmap='rainbow', figsize=None): def get_image_paths(dataroot): - paths = None # return None if dataroot is None - if dataroot is not None: - paths = sorted(_get_paths_from_images(dataroot)) - return paths + return ( + sorted(_get_paths_from_images(dataroot)) + if dataroot is not None + else None + ) def _get_paths_from_images(path): @@ -114,8 +115,7 @@ def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): # print(w1) # print(h1) for i in w1: - for j in h1: - patches.append(img[i : i + p_size, j : j + p_size, :]) + patches.extend(img[i : i + p_size, j : j + p_size, :] for j in h1) else: patches.append(img) @@ -132,9 +132,9 @@ def imssave(imgs, img_path): if img.ndim == 3: img = img[:, :, [2, 1, 0]] new_path = os.path.join( - os.path.dirname(img_path), - img_name + str('_s{:04d}'.format(i)) + '.png', + os.path.dirname(img_path), img_name + '_s{:04d}'.format(i) + '.png' ) + cv2.imwrite(new_path, img) @@ -191,7 +191,7 @@ def mkdirs(paths): def mkdir_and_rename(path): if os.path.exists(path): - new_name = path + '_archived_' + get_timestamp() + new_name = f'{path}_archived_{get_timestamp()}' print('Path already exists. Rename it to [{:s}]'.format(new_name)) os.rename(path, new_name) os.makedirs(path) @@ -701,7 +701,7 @@ def calculate_psnr(img1, img2, border=0): # img1 and img2 have range [0, 255] # img1 = img1.squeeze() # img2 = img2.squeeze() - if not img1.shape == img2.shape: + if img1.shape != img2.shape: raise ValueError('Input images must have the same dimensions.') h, w = img1.shape[:2] img1 = img1[border : h - border, border : w - border] @@ -710,9 +710,7 @@ def calculate_psnr(img1, img2, border=0): img1 = img1.astype(np.float64) img2 = img2.astype(np.float64) mse = np.mean((img1 - img2) ** 2) - if mse == 0: - return float('inf') - return 20 * math.log10(255.0 / math.sqrt(mse)) + return float('inf') if mse == 0 else 20 * math.log10(255.0 / math.sqrt(mse)) # -------------------------------------------- @@ -725,7 +723,7 @@ def calculate_ssim(img1, img2, border=0): """ # img1 = img1.squeeze() # img2 = img2.squeeze() - if not img1.shape == img2.shape: + if img1.shape != img2.shape: raise ValueError('Input images must have the same dimensions.') h, w = img1.shape[:2] img1 = img1[border : h - border, border : w - border] @@ -735,9 +733,7 @@ def calculate_ssim(img1, img2, border=0): return ssim(img1, img2) elif img1.ndim == 3: if img1.shape[2] == 3: - ssims = [] - for i in range(3): - ssims.append(ssim(img1[:, :, i], img2[:, :, i])) + ssims = [ssim(img1[:, :, i], img2[:, :, i]) for i in range(3)] return np.array(ssims).mean() elif img1.shape[2] == 1: return ssim(np.squeeze(img1), np.squeeze(img2)) @@ -851,7 +847,7 @@ def imresize(img, scale, antialiasing=True): # Now the scale should be the same for H and W # input: img: pytorch tensor, CHW or HW [0,1] # output: CHW or HW [0,1] w/o round - need_squeeze = True if img.dim() == 2 else False + need_squeeze = img.dim() == 2 if need_squeeze: img.unsqueeze_(0) in_C, in_H, in_W = img.size() @@ -937,7 +933,7 @@ def imresize_np(img, scale, antialiasing=True): # input: img: Numpy, HWC or HW [0,1] # output: HWC or HW [0,1] w/o round img = torch.from_numpy(img) - need_squeeze = True if img.dim() == 2 else False + need_squeeze = img.dim() == 2 if need_squeeze: img.unsqueeze_(2) diff --git a/ldm/modules/losses/contperceptual.py b/ldm/modules/losses/contperceptual.py index 7fa41243462..fc44c8eaedc 100644 --- 
a/ldm/modules/losses/contperceptual.py +++ b/ldm/modules/losses/contperceptual.py @@ -132,15 +132,16 @@ def forward( ) log = { - '{}/total_loss'.format(split): loss.clone().detach().mean(), - '{}/logvar'.format(split): self.logvar.detach(), - '{}/kl_loss'.format(split): kl_loss.detach().mean(), - '{}/nll_loss'.format(split): nll_loss.detach().mean(), - '{}/rec_loss'.format(split): rec_loss.detach().mean(), - '{}/d_weight'.format(split): d_weight.detach(), - '{}/disc_factor'.format(split): torch.tensor(disc_factor), - '{}/g_loss'.format(split): g_loss.detach().mean(), + f'{split}/total_loss': loss.clone().detach().mean(), + f'{split}/logvar': self.logvar.detach(), + f'{split}/kl_loss': kl_loss.detach().mean(), + f'{split}/nll_loss': nll_loss.detach().mean(), + f'{split}/rec_loss': rec_loss.detach().mean(), + f'{split}/d_weight': d_weight.detach(), + f'{split}/disc_factor': torch.tensor(disc_factor), + f'{split}/g_loss': g_loss.detach().mean(), } + return loss, log if optimizer_idx == 1: @@ -168,8 +169,9 @@ def forward( d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) log = { - '{}/disc_loss'.format(split): d_loss.clone().detach().mean(), - '{}/logits_real'.format(split): logits_real.detach().mean(), - '{}/logits_fake'.format(split): logits_fake.detach().mean(), + f'{split}/disc_loss': d_loss.clone().detach().mean(), + f'{split}/logits_real': logits_real.detach().mean(), + f'{split}/logits_fake': logits_fake.detach().mean(), } + return d_loss, log diff --git a/ldm/modules/losses/vqperceptual.py b/ldm/modules/losses/vqperceptual.py index 2f94bf5281a..2c711b43cca 100644 --- a/ldm/modules/losses/vqperceptual.py +++ b/ldm/modules/losses/vqperceptual.py @@ -17,8 +17,7 @@ def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3]) loss_real = (weights * loss_real).sum() / weights.sum() loss_fake = (weights * loss_fake).sum() / weights.sum() - d_loss = 0.5 * (loss_real + loss_fake) - return d_loss + return 0.5 * (loss_real + loss_fake) def adopt_weight(weight, global_step, threshold=0, value=0.0): @@ -72,20 +71,15 @@ def __init__( assert pixel_loss in ['l1', 'l2'] self.codebook_weight = codebook_weight self.pixel_weight = pixelloss_weight - if perceptual_loss == 'lpips': - print(f'{self.__class__.__name__}: Running with LPIPS.') - self.perceptual_loss = LPIPS().eval() - else: + if perceptual_loss != 'lpips': raise ValueError( f'Unknown perceptual loss: >> {perceptual_loss} <<' ) + print(f'{self.__class__.__name__}: Running with LPIPS.') + self.perceptual_loss = LPIPS().eval() self.perceptual_weight = perceptual_weight - if pixel_loss == 'l1': - self.pixel_loss = l1 - else: - self.pixel_loss = l2 - + self.pixel_loss = l1 if pixel_loss == 'l1' else l2 self.discriminator = NLayerDiscriminator( input_nc=disc_in_channels, n_layers=disc_num_layers, @@ -189,15 +183,16 @@ def forward( ) log = { - '{}/total_loss'.format(split): loss.clone().detach().mean(), - '{}/quant_loss'.format(split): codebook_loss.detach().mean(), - '{}/nll_loss'.format(split): nll_loss.detach().mean(), - '{}/rec_loss'.format(split): rec_loss.detach().mean(), - '{}/p_loss'.format(split): p_loss.detach().mean(), - '{}/d_weight'.format(split): d_weight.detach(), - '{}/disc_factor'.format(split): torch.tensor(disc_factor), - '{}/g_loss'.format(split): g_loss.detach().mean(), + f'{split}/total_loss': loss.clone().detach().mean(), + f'{split}/quant_loss': codebook_loss.detach().mean(), + f'{split}/nll_loss': nll_loss.detach().mean(), + 
f'{split}/rec_loss': rec_loss.detach().mean(), + f'{split}/p_loss': p_loss.detach().mean(), + f'{split}/d_weight': d_weight.detach(), + f'{split}/disc_factor': torch.tensor(disc_factor), + f'{split}/g_loss': g_loss.detach().mean(), } + if predicted_indices is not None: assert self.n_classes is not None with torch.no_grad(): @@ -233,8 +228,9 @@ def forward( d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) log = { - '{}/disc_loss'.format(split): d_loss.clone().detach().mean(), - '{}/logits_real'.format(split): logits_real.detach().mean(), - '{}/logits_fake'.format(split): logits_fake.detach().mean(), + f'{split}/disc_loss': d_loss.clone().detach().mean(), + f'{split}/logits_real': logits_real.detach().mean(), + f'{split}/logits_fake': logits_fake.detach().mean(), } + return d_loss, log diff --git a/ldm/modules/x_transformer.py b/ldm/modules/x_transformer.py index d6c4cc68819..9dfedd2e613 100644 --- a/ldm/modules/x_transformer.py +++ b/ldm/modules/x_transformer.py @@ -99,7 +99,7 @@ def pick_and_pop(keys, d): def group_dict_by_key(cond, d): - return_val = [dict(), dict()] + return_val = [{}, {}] for key in d.keys(): match = bool(cond(key)) ind = int(not match) @@ -213,11 +213,12 @@ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): inner_dim = int(dim * mult) dim_out = default(dim_out, dim) project_in = ( - nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) - if not glu - else GEGLU(dim, inner_dim) + GEGLU(dim, inner_dim) + if glu + else nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) ) + self.net = nn.Sequential( project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) ) @@ -332,7 +333,7 @@ def forward( q_mask = default( mask, lambda: torch.ones((b, n), device=device).bool() ) - k_mask = q_mask if not exists(context) else context_mask + k_mask = context_mask if exists(context) else q_mask k_mask = default( k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool(), @@ -471,7 +472,7 @@ def __init__( if cross_attend and not only_cross: default_block = ('a', 'c', 'f') - elif cross_attend and only_cross: + elif cross_attend: default_block = ('c', 'f') else: default_block = ('a', 'f') @@ -522,18 +523,14 @@ def __init__( layer = Attention(dim, heads=heads, **attn_kwargs) elif layer_type == 'f': layer = FeedForward(dim, **ff_kwargs) - layer = layer if not macaron else Scale(0.5, layer) + layer = Scale(0.5, layer) if macaron else layer else: raise Exception(f'invalid layer type {layer_type}') if isinstance(layer, Attention) and exists(branch_fn): layer = branch_fn(layer) - if gate_residual: - residual_fn = GRUGating(dim) - else: - residual_fn = Residual() - + residual_fn = GRUGating(dim) if gate_residual else Residual() self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn])) def forward( @@ -659,11 +656,12 @@ def __init__( self.init_() self.to_logits = ( - nn.Linear(dim, num_tokens) - if not tie_embedding - else lambda t: t @ self.token_emb.weight.t() + (lambda t: t @ self.token_emb.weight.t()) + if tie_embedding + else nn.Linear(dim, num_tokens) ) + # memory tokens (like [cls]) from Memory Transformers paper num_memory_tokens = default(num_memory_tokens, 0) self.num_memory_tokens = num_memory_tokens diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index 548c44fa492..7a5cdf1ca6e 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -9,5 +9,8 @@ class T2I(Generate): def __init__(self,**kwargs): - print(f'>> The ldm.simplet2i module is deprecated. Use ldm.generate instead. 
It is a drop-in replacement.') + print( + '>> The ldm.simplet2i module is deprecated. Use ldm.generate instead. It is a drop-in replacement.' + ) + super().__init__(kwargs) diff --git a/ldm/util.py b/ldm/util.py index 95cad79523c..3ec18e687ba 100644 --- a/ldm/util.py +++ b/ldm/util.py @@ -19,7 +19,7 @@ def log_txt_as_img(wh, xc, size=10): # wh a tuple of (width, height) # xc a list of captions to plot b = len(xc) - txts = list() + txts = [] for bi in range(b): txt = Image.new('RGB', wh, color='white') draw = ImageDraw.Draw(txt) @@ -42,15 +42,19 @@ def log_txt_as_img(wh, xc, size=10): def ismap(x): - if not isinstance(x, torch.Tensor): - return False - return (len(x.shape) == 4) and (x.shape[1] > 3) + return ( + (len(x.shape) == 4) and (x.shape[1] > 3) + if isinstance(x, torch.Tensor) + else False + ) def isimage(x): - if not isinstance(x, torch.Tensor): - return False - return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + return ( + len(x.shape) == 4 and x.shape[1] in [3, 1] + if isinstance(x, torch.Tensor) + else False + ) def exists(x): @@ -104,10 +108,7 @@ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): # create dummy dataset instance # run prefetching - if idx_to_fn: - res = func(data, worker_id=idx) - else: - res = func(data) + res = func(data, worker_id=idx) if idx_to_fn else func(data) Q.put([idx, res]) Q.put('Done') @@ -129,13 +130,11 @@ def parallel_data_prefetch( elif isinstance(data, abc.Iterable): if isinstance(data, dict): print( - f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + 'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' ) + data = list(data.values()) - if target_data_type == 'ndarray': - data = np.asarray(data) - else: - data = list(data) + data = np.asarray(data) if target_data_type == 'ndarray' else list(data) else: raise TypeError( f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.' @@ -162,16 +161,17 @@ def parallel_data_prefetch( arguments = [ [func, Q, part, i, use_worker_id] for i, part in enumerate( - [data[i : i + step] for i in range(0, len(data), step)] + data[i : i + step] for i in range(0, len(data), step) ) ] + processes = [] for i in range(n_proc): p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) processes += [p] # start processes - print(f'Start prefetching...') + print('Start prefetching...') import time start = time.time() @@ -201,11 +201,12 @@ def parallel_data_prefetch( print(f'Prefetching complete. [{time.time() - start} sec.]') if target_data_type == 'ndarray': - if not isinstance(gather_res[0], np.ndarray): - return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + return ( + np.concatenate(gather_res, axis=0) + if isinstance(gather_res[0], np.ndarray) + else np.concatenate([np.asarray(r) for r in gather_res], axis=0) + ) - # order outputs - return np.concatenate(gather_res, axis=0) elif target_data_type == 'list': out = [] for r in gather_res: diff --git a/main.py b/main.py index 436b7251ba7..e45ee8e04cd 100644 --- a/main.py +++ b/main.py @@ -99,8 +99,9 @@ def str2bool(v): metavar='base_config.yaml', help='paths to base configs. Loaded from left-to-right. 
' 'Parameters can be overwritten or added with command-line options of the form `--key value`.', - default=list(), + default=[], ) + parser.add_argument( '-t', '--train', @@ -255,7 +256,7 @@ def __init__( ): super().__init__() self.batch_size = batch_size - self.dataset_configs = dict() + self.dataset_configs = {} self.num_workers = ( num_workers if num_workers is not None else batch_size * 2 ) @@ -283,10 +284,11 @@ def prepare_data(self): instantiate_from_config(data_cfg) def setup(self, stage=None): - self.datasets = dict( - (k, instantiate_from_config(self.dataset_configs[k])) + self.datasets = { + k: instantiate_from_config(self.dataset_configs[k]) for k in self.dataset_configs - ) + } + if self.wrap: for k in self.datasets: self.datasets[k] = WrappedDataset(self.datasets[k]) @@ -303,7 +305,7 @@ def _train_dataloader(self): self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, - shuffle=False if is_iterable_dataset else True, + shuffle=not is_iterable_dataset, worker_init_fn=init_fn, ) @@ -385,41 +387,37 @@ def on_pretrain_routine_start(self, trainer, pl_module): os.makedirs(self.ckptdir, exist_ok=True) os.makedirs(self.cfgdir, exist_ok=True) - if 'callbacks' in self.lightning_config: - if ( - 'metrics_over_trainsteps_checkpoint' - in self.lightning_config['callbacks'] - ): - os.makedirs( - os.path.join(self.ckptdir, 'trainstep_checkpoints'), - exist_ok=True, - ) + if 'callbacks' in self.lightning_config and ( + 'metrics_over_trainsteps_checkpoint' + in self.lightning_config['callbacks'] + ): + os.makedirs( + os.path.join(self.ckptdir, 'trainstep_checkpoints'), + exist_ok=True, + ) print('Project config') print(OmegaConf.to_yaml(self.config)) OmegaConf.save( - self.config, - os.path.join(self.cfgdir, '{}-project.yaml'.format(self.now)), + self.config, os.path.join(self.cfgdir, f'{self.now}-project.yaml') ) + print('Lightning config') print(OmegaConf.to_yaml(self.lightning_config)) OmegaConf.save( OmegaConf.create({'lightning': self.lightning_config}), - os.path.join( - self.cfgdir, '{}-lightning.yaml'.format(self.now) - ), + os.path.join(self.cfgdir, f'{self.now}-lightning.yaml'), ) - else: - # ModelCheckpoint callback created log directory --- remove it - if not self.resume and os.path.exists(self.logdir): - dst, name = os.path.split(self.logdir) - dst = os.path.join(dst, 'child_runs', name) - os.makedirs(os.path.split(dst)[0], exist_ok=True) - try: - os.rename(self.logdir, dst) - except FileNotFoundError: - pass + + elif not self.resume and os.path.exists(self.logdir): + dst, name = os.path.split(self.logdir) + dst = os.path.join(dst, 'child_runs', name) + os.makedirs(os.path.split(dst)[0], exist_ok=True) + try: + os.rename(self.logdir, dst) + except FileNotFoundError: + pass class ImageLogger(Callback): @@ -448,7 +446,7 @@ def __init__( self.clamp = clamp self.disabled = disabled self.log_on_batch_idx = log_on_batch_idx - self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} + self.log_images_kwargs = log_images_kwargs or {} self.log_first_step = log_first_step @rank_zero_only @@ -537,7 +535,6 @@ def check_frequency(self, check_idx): self.log_steps.pop(0) except IndexError as e: print(e) - pass return True return False @@ -554,11 +551,12 @@ def on_validation_batch_end( ): if not self.disabled and pl_module.global_step > 0: self.log_img(pl_module, batch, batch_idx, split='val') - if hasattr(pl_module, 'calibrate_grad_norm'): - if ( - pl_module.calibrate_grad_norm and batch_idx % 25 == 0 - ) and batch_idx > 0: - self.log_gradients(trainer, 
pl_module, batch_idx=batch_idx) + if ( + hasattr(pl_module, 'calibrate_grad_norm') + and (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) + and batch_idx > 0 + ): + self.log_gradients(trainer, pl_module, batch_idx=batch_idx) class CUDACallback(Callback): diff --git a/notebooks/notebook_helpers.py b/notebooks/notebook_helpers.py index 663b212ac5f..1adcb73f9f8 100644 --- a/notebooks/notebook_helpers.py +++ b/notebooks/notebook_helpers.py @@ -18,23 +18,21 @@ def download_models(mode): - if mode == "superresolution": - # this is the small bsr light model - url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1' - url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1' - - path_conf = 'logs/diffusion/superresolution_bsr/configs/project.yaml' - path_ckpt = 'logs/diffusion/superresolution_bsr/checkpoints/last.ckpt' + if mode != "superresolution": + raise NotImplementedError + # this is the small bsr light model + url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1' + url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1' - download_url(url_conf, path_conf) - download_url(url_ckpt, path_ckpt) + path_conf = 'logs/diffusion/superresolution_bsr/configs/project.yaml' + path_ckpt = 'logs/diffusion/superresolution_bsr/checkpoints/last.ckpt' - path_conf = path_conf + '/?dl=1' # fix it - path_ckpt = path_ckpt + '/?dl=1' # fix it - return path_conf, path_ckpt + download_url(url_conf, path_conf) + download_url(url_ckpt, path_ckpt) - else: - raise NotImplementedError + path_conf += '/?dl=1' + path_ckpt += '/?dl=1' + return path_conf, path_ckpt def load_model_from_config(config, ckpt): @@ -85,14 +83,14 @@ def get_custom_cond(mode): def get_cond_options(mode): path = "data/example_conditioning" path = os.path.join(path, mode) - onlyfiles = [f for f in sorted(os.listdir(path))] + onlyfiles = list(sorted(os.listdir(path))) return path, onlyfiles def select_cond_path(mode): path = "data/example_conditioning" # todo path = os.path.join(path, mode) - onlyfiles = [f for f in sorted(os.listdir(path))] + onlyfiles = list(sorted(os.listdir(path))) selected = widgets.RadioButtons( options=onlyfiles, @@ -100,12 +98,11 @@ def select_cond_path(mode): disabled=False ) display(selected) - selected_path = os.path.join(path, selected.value) - return selected_path + return os.path.join(path, selected.value) def get_cond(mode, selected_path): - example = dict() + example = {} if mode == "superresolution": up_f = 4 visualize_cond_img(selected_path) @@ -149,9 +146,9 @@ def run(model, selected_path, task, custom_steps, resize_enabled=False, classifi split_input = height >= 128 and width >= 128 if split_input: - ks = 128 stride = 64 - vqf = 4 # + vqf = 4 + ks = 128 model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride), "vqf": vqf, "patch_distributed_vq": True, @@ -160,14 +157,13 @@ def run(model, selected_path, task, custom_steps, resize_enabled=False, classifi "clip_min_weight": 0.01, "clip_max_tie_weight": 0.5, "clip_min_tie_weight": 0.01} - else: - if hasattr(model, "split_input_params"): - delattr(model, "split_input_params") + elif hasattr(model, "split_input_params"): + delattr(model, "split_input_params") invert_mask = False x_T = None - for n in range(n_runs): + for _ in range(n_runs): if custom_shape is not None: x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0]) @@ -210,13 +206,15 @@ def make_convolutional_sample(batch, 
model, mode="vanilla", custom_steps=None, e invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000, resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False): - log = dict() + z, c, x, xrec, xc = model.get_input( + batch, + model.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=not hasattr(model, 'split_input_params') + or model.cond_stage_key != 'coordinates_bbox', + return_original_cond=True, + ) - z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=not (hasattr(model, 'split_input_params') - and model.cond_stage_key == 'coordinates_bbox'), - return_original_cond=True) log_every_t = 1 if save_intermediate_vid else None @@ -226,9 +224,7 @@ def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, e z0 = None - log["input"] = x - log["reconstruction"] = xrec - + log = {"input": x, "reconstruction": xrec} if ismap(xc): log["original_conditioning"] = model.to_rgb(xc) if hasattr(model, 'cond_stage_key'): diff --git a/scripts/invoke.py b/scripts/invoke.py index 100ab2413b5..f0e4c851186 100644 --- a/scripts/invoke.py +++ b/scripts/invoke.py @@ -57,7 +57,7 @@ def main(): print('>> Upscaling disabled') else: print('>> Face restoration and upscaling disabled') - except (ModuleNotFoundError, ImportError): + except ImportError: print(traceback.format_exc(), file=sys.stderr) print('>> You may need to install the ESRGAN and/or GFPGAN modules') @@ -123,7 +123,7 @@ def main_loop(gen, opt, infile): """prompt/read/execute loop""" done = False path_filter = re.compile(r'[<>:"/\\|?*]') - last_results = list() + last_results = [] model_config = OmegaConf.load(opt.conf)[opt.model] # The readline completer reads history from the .dream_history file located in the @@ -145,7 +145,7 @@ def main_loop(gen, opt, infile): if completer: completer.set_default_dir(opt.outdir) - + try: command = get_next_command(infile) except EOFError: @@ -196,7 +196,7 @@ def main_loop(gen, opt, infile): command = completer.get_line(int(command_no)) completer.set_line(command) continue - + else: # not a recognized subcommand, so give the --help text command = '-h' @@ -286,7 +286,7 @@ def main_loop(gen, opt, infile): try: file_writer = PngWriter(current_outdir) results = [] # list of filename, prompt pairs - grid_images = dict() # seed -> Image, only used if `opt.grid` + grid_images = {} prior_variations = opt.with_variations or [] prefix = file_writer.unique_prefix() @@ -303,7 +303,7 @@ def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None): if opt.grid: grid_images[seed] = image else: - postprocessed = upscaled if upscaled else operation=='postprocess' + postprocessed = upscaled or operation=='postprocess' filename, formatted_dream_prompt = prepare_image_metadata( opt, prefix, @@ -334,7 +334,7 @@ def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None): tool, formatted_dream_prompt, ) - + if (not postprocessed) or opt.save_original: # only append to results if we didn't overwrite an earlier output results.append([path, formatted_dream_prompt]) @@ -357,7 +357,7 @@ def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None): print(f'>> fixing {opt.prompt}') opt.last_operation = do_postprocess(gen,opt,image_writer) - if opt.grid and len(grid_images) > 0: + if opt.grid and grid_images: grid_img = 
make_grid(list(grid_images.values())) grid_seeds = list(grid_images.keys()) first_seed = last_results[0][1] @@ -377,11 +377,7 @@ def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None): ) results = [[path, formatted_dream_prompt]] - except AssertionError as e: - print(e) - continue - - except OSError as e: + except (AssertionError, OSError) as e: print(e) continue @@ -470,15 +466,14 @@ def prepare_image_metadata( elif len(prior_variations) > 0: formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed) elif operation == 'postprocess': - formatted_dream_prompt = '!fix '+opt.dream_prompt_str(seed=seed) + formatted_dream_prompt = f'!fix {opt.dream_prompt_str(seed=seed)}' else: formatted_dream_prompt = opt.dream_prompt_str(seed=seed) return filename,formatted_dream_prompt def choose_postprocess_name(opt,prefix,seed) -> str: - match = re.search('postprocess:(\w+)',opt.last_operation) - if match: - modifier = match.group(1) # will look like "gfpgan", "upscale", "outpaint" or "embiggen" + if match := re.search('postprocess:(\w+)', opt.last_operation): + modifier = match[1] else: modifier = 'postprocessed' @@ -540,12 +535,7 @@ def split_variations(variations_string) -> list: broken = True break parts.append([seed, weight]) - if broken: - return None - elif len(parts) == 0: - return None - else: - return parts + return None if broken or not parts else parts def retrieve_dream_command(opt,file_path,completer): ''' @@ -555,10 +545,7 @@ def retrieve_dream_command(opt,file_path,completer): for cut-and-paste (windows) ''' dir,basename = os.path.split(file_path) - if len(dir) == 0: - path = os.path.join(opt.outdir,basename) - else: - path = file_path + path = os.path.join(opt.outdir,basename) if len(dir) == 0 else file_path try: cmd = dream_cmd_from_png(path) except OSError: diff --git a/scripts/merge_embeddings.py b/scripts/merge_embeddings.py index 452b27faf4f..e73456fb3a8 100644 --- a/scripts/merge_embeddings.py +++ b/scripts/merge_embeddings.py @@ -7,9 +7,9 @@ import torch def get_placeholder_loop(placeholder_string, embedder, use_bert): - + new_placeholder = None - + while True: if new_placeholder is None: new_placeholder = input(f"Placeholder string {placeholder_string} was already used. 
Please enter a replacement string: ") @@ -34,17 +34,11 @@ def get_clip_token_for_string(tokenizer, string): tokens = batch_encoding["input_ids"] - if torch.count_nonzero(tokens - 49407) == 2: - return tokens[0, 1] - - return None + return tokens[0, 1] if torch.count_nonzero(tokens - 49407) == 2 else None def get_bert_token_for_string(tokenizer, string): token = tokenizer(string) - if torch.count_nonzero(token) == 3: - return token[0, 1] - - return None + return token[0, 1] if torch.count_nonzero(token) == 3 else None if __name__ == "__main__": @@ -81,7 +75,7 @@ def get_bert_token_for_string(tokenizer, string): EmbeddingManager = partial(EmbeddingManager, embedder, ["*"]) - string_to_token_dict = {} + string_to_token_dict = {} string_to_param_dict = torch.nn.ParameterDict() placeholder_to_src = {} @@ -93,7 +87,7 @@ def get_bert_token_for_string(tokenizer, string): manager.load(manager_ckpt) for placeholder_string in manager.string_to_token_dict: - if not placeholder_string in string_to_token_dict: + if placeholder_string not in string_to_token_dict: string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string] string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string] diff --git a/scripts/orig_scripts/img2img.py b/scripts/orig_scripts/img2img.py index 9f74f25bf29..b90ca297903 100644 --- a/scripts/orig_scripts/img2img.py +++ b/scripts/orig_scripts/img2img.py @@ -205,7 +205,6 @@ def main(): if opt.plms: raise NotImplementedError("PLMS sampler not (yet) supported") - sampler = PLMSSampler(model) else: sampler = DDIMSampler(model) @@ -248,8 +247,8 @@ def main(): with precision_scope(device.type): with model.ema_scope(): tic = time.time() - all_samples = list() - for n in trange(opt.n_iter, desc="Sampling"): + all_samples = [] + for _ in trange(opt.n_iter, desc="Sampling"): for prompts in tqdm(data, desc="data"): uc = None if opt.scale != 1.0: diff --git a/scripts/orig_scripts/knn2img.py b/scripts/orig_scripts/knn2img.py index e6eaaecab53..f57752b4b0c 100644 --- a/scripts/orig_scripts/knn2img.py +++ b/scripts/orig_scripts/knn2img.py @@ -150,15 +150,16 @@ def search(self, x, k): out_img_ids = self.database['img_id'][nns] out_pc = self.database['patch_coords'][nns] - out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], - 'img_ids': out_img_ids, - 'patch_coords': out_pc, - 'queries': x, - 'exec_time': end - start, - 'nns': nns, - 'q_embeddings': query_embeddings} - - return out + return { + 'nn_embeddings': out_embeddings + / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], + 'img_ids': out_img_ids, + 'patch_coords': out_pc, + 'queries': x, + 'exec_time': end - start, + 'nns': nns, + 'q_embeddings': query_embeddings, + } def __call__(self, x, n): return self.search(x, n) @@ -314,11 +315,7 @@ def __call__(self, x, n): clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device) - if opt.plms: - sampler = PLMSSampler(model) - else: - sampler = DDIMSampler(model) - + sampler = PLMSSampler(model) if opt.plms else DDIMSampler(model) os.makedirs(opt.outdir, exist_ok=True) outpath = opt.outdir @@ -342,14 +339,11 @@ def __call__(self, x, n): print(f"sampling scale for cfg is {opt.scale:.2f}") - searcher = None - if opt.use_neighbors: - searcher = Searcher(opt.database) - + searcher = Searcher(opt.database) if opt.use_neighbors else None with torch.no_grad(): with model.ema_scope(): - for n in trange(opt.n_iter, desc="Sampling"): - all_samples = list() + for _ in trange(opt.n_iter, 
desc="Sampling"): + all_samples = [] for prompts in tqdm(data, desc="data"): print("sampling prompts:", prompts) if isinstance(prompts, tuple): diff --git a/scripts/orig_scripts/sample_diffusion.py b/scripts/orig_scripts/sample_diffusion.py index 876fe3c3642..301e37a24e3 100644 --- a/scripts/orig_scripts/sample_diffusion.py +++ b/scripts/orig_scripts/sample_diffusion.py @@ -19,7 +19,7 @@ def custom_to_pil(x): x = x.permute(1, 2, 0).numpy() x = (255 * x).astype(np.uint8) x = Image.fromarray(x) - if not x.mode == "RGB": + if x.mode != "RGB": x = x.convert("RGB") return x @@ -34,7 +34,7 @@ def custom_to_np(x): def logs2pil(logs, keys=["sample"]): - imgs = dict() + imgs = {} for k in logs: try: if len(logs[k].shape) == 4: @@ -56,13 +56,16 @@ def convsample(model, shape, return_intermediates=True, make_prog_row=False): - if not make_prog_row: - return model.p_sample_loop(None, shape, - return_intermediates=return_intermediates, verbose=verbose) - else: - return model.progressive_denoising( - None, shape, verbose=True + return ( + model.progressive_denoising(None, shape, verbose=True) + if make_prog_row + else model.p_sample_loop( + None, + shape, + return_intermediates=return_intermediates, + verbose=verbose, ) + ) @torch.no_grad() @@ -79,8 +82,6 @@ def convsample_ddim(model, steps, shape, eta=1.0 def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,): - log = dict() - shape = [batch_size, model.model.diffusion_model.in_channels, model.model.diffusion_model.image_size, @@ -99,9 +100,12 @@ def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=Non x_sample = model.decode_first_stage(sample) - log["sample"] = x_sample - log["time"] = t1 - t0 - log['throughput'] = sample.shape[0] / (t1 - t0) + log = { + "sample": x_sample, + "time": t1 - t0, + 'throughput': sample.shape[0] / (t1 - t0), + } + print(f'Throughput for this batch: {log["throughput"]}') return log @@ -249,7 +253,7 @@ def load_model(config, ckpt, gpu, eval_mode): ckpt = None if not os.path.exists(opt.resume): - raise ValueError("Cannot find {}".format(opt.resume)) + raise ValueError(f"Cannot find {opt.resume}") if os.path.isfile(opt.resume): # paths = opt.resume.split("/") try: diff --git a/scripts/orig_scripts/train_searcher.py b/scripts/orig_scripts/train_searcher.py index 1e7904889c0..06d88ac2365 100644 --- a/scripts/orig_scripts/train_searcher.py +++ b/scripts/orig_scripts/train_searcher.py @@ -96,7 +96,7 @@ def train_searcher(opt, if pool_size < 2e4: print('Using brute force search.') searcher = search_bruteforce(searcher) - elif 2e4 <= pool_size and pool_size < 1e5: + elif pool_size < 1e5: print('Using asymmetric hashing search and reordering.') searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k) else: diff --git a/scripts/orig_scripts/txt2img.py b/scripts/orig_scripts/txt2img.py index 0d350d2c73c..059607354f5 100644 --- a/scripts/orig_scripts/txt2img.py +++ b/scripts/orig_scripts/txt2img.py @@ -213,11 +213,7 @@ def forward(self, x, sigma, uncond, cond, cond_scale): uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) return uncond + (cond - uncond) * cond_scale - if opt.plms: - sampler = PLMSSampler(model) - else: - sampler = DDIMSampler(model) - + sampler = PLMSSampler(model) if opt.plms else DDIMSampler(model) os.makedirs(opt.outdir, exist_ok=True) outpath = opt.outdir @@ -259,8 +255,8 @@ def forward(self, x, sigma, uncond, cond, cond_scale): with precision_scope(device.type): with model.ema_scope(): tic = time.time() - all_samples = 
list() - for n in trange(opt.n_iter, desc="Sampling"): + all_samples = [] + for _ in trange(opt.n_iter, desc="Sampling"): for prompts in tqdm(data, desc="data"): uc = None if opt.scale != 1.0: @@ -282,10 +278,14 @@ def forward(self, x, sigma, uncond, cond, cond_scale): x_T=start_code) else: sigmas = model_wrap.get_sigmas(opt.ddim_steps) - if start_code: - x = start_code - else: - x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw + x = ( + start_code + or torch.randn( + [opt.n_samples, *shape], device=device + ) + * sigmas[0] + ) + model_wrap_cfg = CFGDenoiser(model_wrap) extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale} samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args) diff --git a/server/models.py b/server/models.py index 1a574aa1376..eb1080a2ca3 100644 --- a/server/models.py +++ b/server/models.py @@ -135,8 +135,7 @@ def clone_without_img(self): def to_json(self): copy = deepcopy(self) copy.initimg = None - j = json.dumps(copy.__dict__) - return j + return json.dumps(copy.__dict__) @staticmethod def from_json(j, newTime: bool = False): diff --git a/server/services.py b/server/services.py index 19b2360a372..5db888cd392 100644 --- a/server/services.py +++ b/server/services.py @@ -114,13 +114,12 @@ def __getName(self, dreamId: str, postfix: str = '') -> str: def save(self, image, dreamResult: DreamResult, postfix: str = '') -> str: name = self.__getName(dreamResult.id, postfix) meta = dreamResult.to_json() # TODO: make all methods consistent with writing metadata. Standardize metadata. - path = self.__pngWriter.save_image_and_prompt_to_png(image, dream_prompt=meta, metadata=None, name=name) - return path + return self.__pngWriter.save_image_and_prompt_to_png( + image, dream_prompt=meta, metadata=None, name=name) def path(self, dreamId: str, postfix: str = '') -> str: name = self.__getName(dreamId, postfix) - path = os.path.join(self.__location, name) - return path + return os.path.join(self.__location, name) # Returns true if found, false if not found or error def delete(self, dreamId: str, postfix: str = '') -> bool: @@ -135,31 +134,30 @@ def getMetadata(self, dreamId: str, postfix: str = '') -> DreamResult: path = self.path(dreamId, postfix) image = Image.open(path) text = image.text - if text.__contains__('Dream'): - dreamMeta = text.get('Dream') - try: - j = json.loads(dreamMeta) - return DreamResult.from_json(j) - except ValueError: + if not text.__contains__('Dream'): + return None + dreamMeta = text.get('Dream') + try: + j = json.loads(dreamMeta) + return DreamResult.from_json(j) + except ValueError: # Try to parse command-line format (legacy metadata format) - try: - opt = self.__parseLegacyMetadata(dreamMeta) - optd = opt.__dict__ - if (not 'width' in optd) or (optd.get('width') is None): - optd['width'] = image.width - if (not 'height' in optd) or (optd.get('height') is None): - optd['height'] = image.height - if (not 'steps' in optd) or (optd.get('steps') is None): - optd['steps'] = 10 # No way around this unfortunately - seems like it wasn't storing this previously + try: + opt = self.__parseLegacyMetadata(dreamMeta) + optd = opt.__dict__ + if 'width' not in optd or optd.get('width') is None: + optd['width'] = image.width + if 'height' not in optd or optd.get('height') is None: + optd['height'] = image.height + if 'steps' not in optd or optd.get('steps') is None: + optd['steps'] = 10 # No way around this unfortunately - seems like it wasn't storing this previously - optd['time'] = 
os.path.getmtime(path) # Set timestamp manually (won't be exactly correct though) + optd['time'] = os.path.getmtime(path) # Set timestamp manually (won't be exactly correct though) - return DreamResult.from_json(optd) + return DreamResult.from_json(optd) - except: - return None - else: - return None + except: + return None def __parseLegacyMetadata(self, command: str) -> DreamResult: # before splitting, escape single quotes so as not to mess @@ -183,11 +181,10 @@ def __parseLegacyMetadata(self, command: str) -> DreamResult: else: switches[0] += el switches[0] += ' ' - switches[0] = switches[0][: len(switches[0]) - 1] + switches[0] = switches[0][:-1] try: - opt = self.__legacyParser.parse_cmd(switches) - return opt + return self.__legacyParser.parse_cmd(switches) except SystemExit: return None @@ -196,7 +193,7 @@ def list_files(self, page: int, perPage: int) -> PaginatedItems: count = len(files) startId = page * perPage - pageCount = int(count / perPage) + 1 + pageCount = count // perPage + 1 endId = min(startId + perPage, count) items = [] if startId >= count else files[startId:endId] @@ -241,7 +238,7 @@ def __process(self): print('Preloading model') tic = time.time() self.__model.load_model() - print(f'>> model loaded in', '%4.2fs' % (time.time() - tic)) + print('>> model loaded in', '%4.2fs' % (time.time() - tic)) print('Started generation queue processor') try: @@ -266,7 +263,7 @@ def __on_image_result(self, jobRequest: JobRequest, image, seed, upscaled=False) # TODO: Separate status of GFPGAN? self.__imageStorage.save(image, dreamResult) - + # TODO: handle upscaling logic better (this is appending data to log, but only on first generation) if not upscaled: self.__log.log(dreamResult) @@ -275,7 +272,7 @@ def __on_image_result(self, jobRequest: JobRequest, image, seed, upscaled=False) self.__signal_service.emit(Signal.image_result(jobRequest.id, dreamResult.id, dreamResult)) upscaling_requested = dreamResult.enable_upscale or dreamResult.enable_gfpgan - + # Report upscaling status # TODO: this is very coupled to logic inside the generator. Fix that. if upscaling_requested and any(result.has_upscaled for result in jobRequest.results): @@ -308,12 +305,11 @@ def __generate(self, jobRequest: JobRequest): try: # TODO: handle this file a file service for init images initimgfile = None # TODO: support this on the model directly? 
-            if (jobRequest.enable_init_image):
-                if jobRequest.initimg is not None:
-                    with open("./img2img-tmp.png", "wb") as f:
-                        initimg = jobRequest.initimg.split(",")[1] # Ignore mime type
-                        f.write(base64.b64decode(initimg))
-                    initimgfile = "./img2img-tmp.png"
+            if jobRequest.enable_init_image and jobRequest.initimg is not None:
+                with open("./img2img-tmp.png", "wb") as f:
+                    initimg = jobRequest.initimg.split(",")[1] # Ignore mime type
+                    f.write(base64.b64decode(initimg))
+                initimgfile = "./img2img-tmp.png"
 
             # Use previous seed if set to -1
             initSeed = jobRequest.seed
@@ -333,8 +329,8 @@ def __generate(self, jobRequest: JobRequest):
 
             # TODO: Split job generation requests instead of fitting all parameters here
             # TODO: Support no generation (just upscaling/gfpgan)
-            upscale = None if not jobRequest.enable_upscale else jobRequest.upscale
-            gfpgan_strength = 0 if not jobRequest.enable_gfpgan else jobRequest.gfpgan_strength
+            upscale = jobRequest.upscale if jobRequest.enable_upscale else None
+            gfpgan_strength = jobRequest.gfpgan_strength if jobRequest.enable_gfpgan else 0
 
             if not jobRequest.enable_generate:
                 # If not generating, check if we're upscaling or running gfpgan
diff --git a/server/views.py b/server/views.py
index db4857d14f5..f1afaef074c 100644
--- a/server/views.py
+++ b/server/views.py
@@ -92,8 +92,7 @@ def __init__(self, pathBase, storage: ImageStorageService = Provide[Container.im
 
     def get(self, dreamId):
         meta = self.__storage.getMetadata(dreamId)
-        j = {} if meta is None else meta.__dict__
-        return j
+        return {} if meta is None else meta.__dict__
 
 
 class ApiIntermediates(MethodView):