# PaliGemma2 / app.py
# Last updated by breadlicker45 — commit 9174c87 ("Update app.py", verified).
import functools
import os
from io import BytesIO

import gradio as gr
import requests
import spaces  # Import the spaces module
import torch
from PIL import Image
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
from transformers.image_utils import load_image
@functools.lru_cache(maxsize=1)
def load_model():
    """Load the PaliGemma2 processor and model, authenticating with a Hugging Face token.

    The result is memoized, so the large weights are downloaded/read and moved to
    the device once per process rather than on every request. A call that raises
    is not cached, so a later call retries.

    Returns:
        A ``(processor, model)`` tuple. The model is loaded in bfloat16, placed on
        CUDA when available (otherwise CPU), and switched to eval mode.

    Raises:
        ValueError: If the ``HUGGINGFACEHUB_API_TOKEN`` environment variable is
            unset or empty.
    """
    token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        raise ValueError(
            "Hugging Face API token not found. Please set it in the environment variables."
        )

    model_id = "google/paligemma2-10b-pt-448"
    # `token=` is the current name of the deprecated `use_auth_token=` argument.
    processor = PaliGemmaProcessor.from_pretrained(model_id, token=token)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = (
        PaliGemmaForConditionalGeneration.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, token=token
        )
        .to(device)
        .eval()
    )
    return processor, model
@spaces.GPU(duration=120)  # Increased timeout to 120 seconds
def process_image_and_text(image_pil, num_beams, temperature, seed):
    """Generate text for an uploaded image using PaliGemma2.

    Args:
        image_pil: The uploaded image (``gr.Image(type="pil")``), or ``None`` when
            the user submitted without an image.
        num_beams: Number of beams passed to ``model.generate``.
        temperature: Sampling temperature passed to ``model.generate``.
        seed: Seed for ``torch.manual_seed``; ``None`` (cleared field) falls back to 0.

    Returns:
        The decoded text generated after the prompt, with special tokens removed.

    Raises:
        gr.Error: If no image was provided, or if model loading / preprocessing /
            generation fails (the underlying exception is chained as the cause).
    """
    # Checked outside the try so the message is not re-wrapped as a GPU failure.
    if image_pil is None:
        raise gr.Error("Please upload an image.")

    try:
        processor, model = load_model()
        # Use the device the model actually lives on instead of recomputing it.
        device = model.device

        image = load_image(image_pil)

        # NOTE(review): the prompt contains no explicit "<image>" token; the
        # PaliGemma processor is expected to insert the image tokens itself
        # (emitting a warning) — confirm against the installed transformers version.
        text_input = " "

        # Presumably only floating-point tensors (pixel values) are cast to
        # bfloat16 here; integer token ids keep their dtype.
        model_inputs = processor(
            text=text_input, images=image, return_tensors="pt"
        ).to(device, dtype=torch.bfloat16)

        # Prompt length, used to strip the prompt tokens from the generated output.
        input_len = model_inputs["input_ids"].shape[-1]

        # Defensively coerce widget values; a cleared seed field would be None.
        torch.manual_seed(0 if seed is None else int(seed))

        with torch.inference_mode():
            generation = model.generate(
                **model_inputs,
                max_new_tokens=200,
                do_sample=True,
                num_beams=int(num_beams),
                temperature=float(temperature),
            )

        # Keep only the newly generated tokens of the first sequence.
        new_tokens = generation[0][input_len:]
        return processor.decode(new_tokens, skip_special_tokens=True)
    except Exception as e:
        print(f"Error during GPU task: {e}")
        raise gr.Error(f"GPU task failed: {e}") from e
if __name__ == "__main__":
    # Input widgets, in the same order as process_image_and_text's parameters.
    image_input = gr.Image(type="pil", label="Upload an image")
    beams_input = gr.Slider(
        minimum=1, maximum=10, step=1, value=10, label="Number of Beams"
    )
    temperature_input = gr.Slider(
        minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature"
    )
    seed_input = gr.Number(label="Random Seed", value=0, precision=0)

    # Single-function UI: image + generation settings in, generated text out.
    demo = gr.Interface(
        fn=process_image_and_text,
        inputs=[image_input, beams_input, temperature_input, seed_input],
        outputs=gr.Textbox(label="Generated Text"),
        title="PaliGemma2 Image to Text",
        description="Upload an image and the model will generate text.",
    )
    demo.launch()