Skip to content

Commit

Permalink
fixed gender not in image gen
Browse files Browse the repository at this point in the history
  • Loading branch information
thijsi123 committed Feb 1, 2024
1 parent bbe3cba commit c12fdb4
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 174 deletions.
2 changes: 1 addition & 1 deletion app/main-mistral-webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ def generate_character_avatar(
example_dialogue
+ "\n[INST] create a prompt that lists the appearance "
+ "characteristics of a character whose summary is "
+ "if lack of info, generate something based on available info."
+ f" {character_summary}. Topic: {topic} [/INST]\n"
+ "if lack of info, generate something based on available info."
)
print(sd_prompt)
sd_filter(nsfw_filter)
Expand Down
86 changes: 52 additions & 34 deletions app/oobabooga - webui.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import aichar
import requests
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionPipeline
import torch
import gradio as gr
import re
Expand All @@ -12,9 +12,34 @@
sd = None
safety_checker_sd = None

folder_path = "models"
global_url = ""
folder_path = "models" # Base directory for models

def load_model(model_name, use_safetensors=False, use_local=False):
global sd

# Enable TensorFloat-32 for matrix multiplications
torch.backends.cuda.matmul.allow_tf32 = True

if use_local:
model_path = os.path.join(folder_path, model_name).replace("\\", "/")
if not os.path.exists(model_path):
print(f"Model {model_name} not found at {model_path}.")
return
print(f"Loading local model from: {model_path}")
sd = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=torch.float16)
else:
print(f"Loading {model_name} from Hugging Face with safetensors={use_safetensors}.")
sd = StableDiffusionPipeline.from_pretrained(model_name, use_safetensors=use_safetensors, torch_dtype=torch.float16)

if torch.cuda.is_available():
sd.to("cuda")
print(f"Loaded {model_name} to GPU in half precision (float16).")
else:
print(f"Loaded {model_name} to CPU.")


# For a local .safetensors model
load_model("oof.safetensors", use_safetensors=True, use_local=True)

def process_url(url):
global global_url
Expand Down Expand Up @@ -97,29 +122,9 @@ def send_message(prompt):
return f"Error sending request: {e}"


def load_models():
global sd
sd = DiffusionPipeline.from_pretrained(
"Lykon/dreamshaper-8",
torch_dtype=torch.float16,
variant="fp16",
low_cpu_mem_usage=False,
)
if torch.cuda.is_available():
sd.to("cuda")
print("Loading Stable Diffusion to GPU...")
else:
print("Loading Stable Diffusion to CPU...")
global llm
gpu_layers = 0
if torch.cuda.is_available():
gpu_layers = 110
print("Loading LLM to GPU...")
else:
print("Loading LLM to CPU...")


load_models()
# Example Usage
# For a model hosted on Hugging Face without safetensors
# load_models("dreamshaper-8")


def generate_character_name(topic, gender, name, surname_checkbox):
Expand Down Expand Up @@ -700,7 +705,7 @@ def generate_character_avatar(
topic,
negative_prompt,
avatar_prompt,
nsfw_filter,
nsfw_filter, gender
):
example_dialogue = """
<|system|>
Expand All @@ -723,16 +728,22 @@ def generate_character_avatar(
<|user|> create a prompt that lists the appearance characteristics of a character whose summary is Name: suzie Summary: Topic: none Gender: none</s>
<|assistant|> 1girl, improvised tag, </s>
""" # nopep8
sd_prompt = (
print(gender)
# Detect if "anime" is in the character summary or topic and adjust the prompt
anime_specific_tag = "anime, 2d, " if 'anime' in character_summary.lower() or 'anime' in topic.lower() else ""
raw_sd_prompt = (
input_none(avatar_prompt)
or send_message(
example_dialogue
+ "\n<|user|> create a prompt that lists the appearance "
+ "\n<|user|> create a prompt that lists the appearance " #create a prompt that lists the appearance characteristics of a character whose summary is Gender: male, name=gabe. Topic: anime
+ "characteristics of a character whose summary is "
+ "if lack of info, generate something based on available info."
+ f"Gender: {gender}"
+ f"{character_summary}. Topic: {topic}</s>\n<|assistant|> "
+ "if lack of info, generate something based on available info."
).strip()
)
# Append the anime_specific_tag at the beginning of the raw_sd_prompt
sd_prompt = anime_specific_tag + raw_sd_prompt.strip()
print(sd_prompt)
sd_filter(nsfw_filter)
return image_generate(character_name,
Expand Down Expand Up @@ -777,7 +788,7 @@ def image_generate(character_name, prompt, negative_prompt):
# Call process_uploaded_image
process_uploaded_image(reloaded_image_np)

print("Generated character avatar")
print("Generated character avatar" + prompt)
return generated_image


Expand Down Expand Up @@ -855,7 +866,6 @@ def export_as_json(
# Global variable to store the path of the processed image
processed_image_path = None


def export_character_card(name, summary, personality, scenario, greeting_message, example_messages):
global processed_image_path # Access the global variable

Expand Down Expand Up @@ -890,7 +900,6 @@ def export_character_card(name, summary, personality, scenario, greeting_message
character.export_neutral_card_file(card_path)
return Image.open(card_path)


with gr.Blocks() as webui:
gr.Markdown("# Character Factory WebUI")
gr.Markdown("## OOBABOOGA MODE")
Expand All @@ -911,9 +920,11 @@ def export_character_card(name, summary, personality, scenario, greeting_message
placeholder="Topic: The topic for character generation (e.g., Fantasy, Anime, etc.)", # nopep8
label="topic",
)

gender = gr.Textbox(
placeholder="Gender: Gender of the character", label="gender"
)

with gr.Column():
with gr.Row():
name = gr.Textbox(placeholder="character name", label="name")
Expand Down Expand Up @@ -1017,7 +1028,9 @@ def handle_example_messages_button_click(
inputs=[name, summary, personality, topic, switch_function_checkbox],
outputs=example_messages,
)

'''gender = gr.Textbox(
placeholder="Gender: Gender of the character", label="gender"
)'''
with gr.Row():
with gr.Column():
image_input = gr.Image(interactive=True, label="Character Image", width=512, height=512)
Expand All @@ -1031,6 +1044,7 @@ def handle_example_messages_button_click(
outputs=[image_input] # You can update the same image display with the processed image
)
with gr.Column():
gender = gender
negative_prompt = gr.Textbox(
placeholder="negative prompt for stable diffusion (optional)", # nopep8
label="negative prompt",
Expand Down Expand Up @@ -1058,6 +1072,7 @@ def handle_example_messages_button_click(
negative_prompt,
avatar_prompt,
potential_nsfw_checkbox,
gender,
],
outputs=image_input,
)
Expand Down Expand Up @@ -1143,3 +1158,6 @@ def handle_example_messages_button_click(
safety_checker_sd = sd.safety_checker

webui.launch(debug=True)



2 changes: 1 addition & 1 deletion app/oobabooga.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,8 +625,8 @@ def generate_character_avatar(character_name, character_summary, args):
example_dialogue
+ "\n<|user|> create a prompt that lists the appearance "
+ "characteristics of a character whose summary is "
+ "if lack of info, generate something based on available info."
+ f"{character_summary}. Topic: {topic} </s>\n<|assistant|> "
+ "if lack of info, generate something based on available info."
)
)
print(sd_prompt)
Expand Down
Loading

0 comments on commit c12fdb4

Please sign in to comment.