Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 46 additions & 24 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,31 @@
from PIL import Image
from io import BytesIO
from torch import autocast
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
from click import secho
from zipfile import ZipFile

# TODO: add command line arguments

secho("Loading Model...", fg="yellow")

# FIXME: more elegant model scope
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", use_auth_token=True,
MODEL = "stabilityai/stable-diffusion-2-1"

params = dict(
use_auth_token=True,
# TODO: detect if there's 8G VRAM before we enable this model
revision="fp16",
torch_dtype=torch.float16,
).to("cuda")
safety_checker=None,
requires_safety_checker=False,
)

# FIXME: more elegant model scope
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(MODEL, **params).to("cuda")

secho("Finished!", fg="green")
text_pipe = StableDiffusionPipeline.from_pretrained(MODEL, **params).to("cuda")

secho("Model Loaded!", fg="green")

app = Flask(__name__)

Expand All @@ -37,45 +45,59 @@ def seed_everything(seed: int):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

def is_one_color(image):
    """Return True if *image* is a single solid color.

    Used to decide whether the client sent a blank canvas (fall back to
    txt2img) or an actual sketch (use img2img).
    """
    # https://stackoverflow.com/a/28915260/2234013
    # Image.getcolors() returns None when the image contains more colors
    # than its maxcolors limit (default 256).  The original code crashed
    # with TypeError on len(None) for any real sketch/photo; many colors
    # certainly means "not one color", so treat None as False.
    colors = image.getcolors()
    return colors is not None and len(colors) == 1

@app.route("/api/img2img", methods=["POST"])
def img2img():
global pipe
global text_pipe

r = request
headers = r.headers
headers = request.headers
data = request.data

data = r.data
buff = BytesIO(data)
img = Image.open(buff).convert("RGB")
print(type(img))

seed = int(headers["seed"])
prompt = headers['prompt']


print(r.headers)
variant_count = int(headers.get('variations', 1) or 1)

zip_stream = BytesIO()
with ZipFile(zip_stream, 'w') as zf:

for index in range(int(headers['variations'])):
variation_seed = seed + index
seed_everything(variation_seed)

with autocast("cuda"):
return_image = pipe(
init_image=img,
# TODO num_images_per_prompt results in memory issues easily
# num_images_per_prompt=variant_count
for index in range(variant_count):
variant_seed = seed + index
seed_everything(variant_seed)

if is_one_color(img):
secho('Image is empty – generating with txt2img')
diffusion_results = text_pipe(
prompt=prompt,
guidance_scale=float(headers["prompt_strength"]),
num_inference_steps=int(headers["steps"]),
width=512,
height=512
)
else:
diffusion_results = pipe(
image=img,
prompt=prompt,
strength=float(headers["sketch_strength"]),
guidance_scale=float(headers["prompt_strength"]),
num_inference_steps=int(headers["steps"]),
)["sample"][0]


)
return_image = diffusion_results.images[0]
return_bytes = BytesIO()
return_image.save(return_bytes, format="JPEG")

return_bytes.seek(0)
zf.writestr(get_name(prompt, variation_seed), return_bytes.read())
zf.writestr(get_name(prompt, variant_seed), return_bytes.read())

zip_stream.seek(0)

Expand Down