From 52913bbc7eb5c781cebeaea659e59bfb229feba9 Mon Sep 17 00:00:00 2001 From: micimize Date: Tue, 27 Dec 2022 12:47:14 -0800 Subject: [PATCH 1/2] troubleshooting --- server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server.py b/server.py index afe475f..65c6d94 100644 --- a/server.py +++ b/server.py @@ -18,6 +18,8 @@ # TODO: detect if there's 8G VRAM before we enable this model revision="fp16", torch_dtype=torch.float16, + safety_checker=None, + requires_safety_checker=False, ).to("cuda") secho("Finished!", fg="green") @@ -68,7 +70,7 @@ def img2img(): strength=float(headers["sketch_strength"]), guidance_scale=float(headers["prompt_strength"]), num_inference_steps=int(headers["steps"]), - )["sample"][0] + ).images[0] return_bytes = BytesIO() From eda084f88c1a88100f02edef7a17ccddc77092fb Mon Sep 17 00:00:00 2001 From: micimize Date: Tue, 27 Dec 2022 14:05:25 -0800 Subject: [PATCH 2/2] add txt2img for empty canvases --- server.py | 68 +++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 44 insertions(+), 24 deletions(-) diff --git a/server.py b/server.py index 65c6d94..2877019 100644 --- a/server.py +++ b/server.py @@ -4,7 +4,7 @@ from PIL import Image from io import BytesIO from torch import autocast -from diffusers import StableDiffusionImg2ImgPipeline +from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline from click import secho from zipfile import ZipFile @@ -12,17 +12,23 @@ secho("Loading Model...", fg="yellow") -# FIXME: more elegant model scope -pipe = StableDiffusionImg2ImgPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", use_auth_token=True, +MODEL = "stabilityai/stable-diffusion-2-1" + +params = dict( + use_auth_token=True, # TODO: detect if there's 8G VRAM before we enable this model revision="fp16", torch_dtype=torch.float16, safety_checker=None, requires_safety_checker=False, -).to("cuda") +) + +# FIXME: more elegant model scope +pipe = StableDiffusionImg2ImgPipeline.from_pretrained(MODEL, **params).to("cuda") -secho("Finished!", fg="green") +text_pipe = StableDiffusionPipeline.from_pretrained(MODEL, **params).to("cuda") + +secho("Model Loaded!", fg="green") app = Flask(__name__) @@ -39,45 +45,59 @@ def seed_everything(seed: int): torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = True +def is_one_color(image): + # https://stackoverflow.com/a/28915260/2234013 + colors = image.getcolors() + print(colors) + return len(colors) == 1 + @app.route("/api/img2img", methods=["POST"]) def img2img(): global pipe + global text_pipe - r = request - headers = r.headers + headers = request.headers + data = request.data - data = r.data buff = BytesIO(data) img = Image.open(buff).convert("RGB") + print(type(img)) seed = int(headers["seed"]) prompt = headers['prompt'] - - - print(r.headers) + variant_count = int(headers.get('variations', 1) or 1) zip_stream = BytesIO() with ZipFile(zip_stream, 'w') as zf: - - for index in range(int(headers['variations'])): - variation_seed = seed + index - seed_everything(variation_seed) - - with autocast("cuda"): - return_image = pipe( - init_image=img, + # TODO num_images_per_prompt results in memory issues easily + # num_images_per_prompt=variant_count + for index in range(variant_count): + variant_seed = seed + index + seed_everything(variant_seed) + + if is_one_color(img): + secho('Image is empty – generating with txt2img') + diffusion_results = text_pipe( + prompt=prompt, + guidance_scale=float(headers["prompt_strength"]), + num_inference_steps=int(headers["steps"]), + width=512, + height=512 + ) + else: + diffusion_results = pipe( + image=img, prompt=prompt, strength=float(headers["sketch_strength"]), guidance_scale=float(headers["prompt_strength"]), num_inference_steps=int(headers["steps"]), - ).images[0] - - + ) + return_image = diffusion_results.images[0] return_bytes = BytesIO() return_image.save(return_bytes, format="JPEG") return_bytes.seek(0) - zf.writestr(get_name(prompt, variation_seed), return_bytes.read()) + zf.writestr(get_name(prompt, variant_seed), return_bytes.read()) zip_stream.seek(0)