# app.py — last commit dcb8512 ("Update app.py") by myyim
import torch
from transformers import (
PaliGemmaProcessor,
PaliGemmaForConditionalGeneration,
)
import streamlit as st
from PIL import Image
import os
# Hugging Face access token, read from the HF_TOKEN environment variable
# (configured as a secret); None when the variable is unset.
token = os.getenv("HF_TOKEN")

# PaliGemma checkpoint that model_setup() loads
model_id = "google/paligemma2-3b-pt-896"
@st.cache_resource
def model_setup(model_id):
    """Load the PaliGemma model and its processor for *model_id*.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the loaded
    objects instead of loading the weights again. Authenticates with the
    module-level ``token``.

    Returns:
        A ``(model, processor)`` tuple: the model loaded in bfloat16 with
        ``device_map="auto"`` and switched to eval mode, plus the matching
        processor.
    """
    loaded_model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=token,
    )
    # eval() switches the module to inference mode in place and returns it
    loaded_model.eval()
    loaded_processor = PaliGemmaProcessor.from_pretrained(model_id, token=token)
    return loaded_model, loaded_processor
def runModel(prompt, image, *, max_new_tokens=1000):
    """Generate a greedy (do_sample=False) completion for *prompt* on *image*.

    Uses the module-level ``model`` and ``processor`` returned by
    ``model_setup``.

    Args:
        prompt: Text prompt handed to the processor together with the image.
        image: Image passed to the processor (the caller in this file passes
            a PIL image converted to RGB).
        max_new_tokens: Maximum number of tokens to generate. Defaults to
            1000, the previously hard-coded value.

    Returns:
        The decoded continuation only. The prompt tokens are sliced off and
        special tokens are skipped.
    """
    # Cast to bfloat16 to match the dtype the model was loaded with, then move
    # the inputs onto the model's device.
    inputs = (
        processor(text=prompt, images=image, return_tensors="pt")
        .to(torch.bfloat16)
        .to(model.device)
    )
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        output = model.generate(
            **inputs, max_new_tokens=max_new_tokens, do_sample=False
        )
        # keep only the tokens produced after the prompt
        new_tokens = output[0][prompt_len:]
    return processor.decode(new_tokens, skip_special_tokens=True)
def initialize():
    """Reset the chat history kept in Streamlit session state.

    This is registered as the file uploader's ``on_change`` callback, so a
    change of the uploaded image starts from an empty conversation.
    """
    st.session_state["messages"] = []
### load model
model, processor = model_setup(model_id)


def _chat_turn(text, image):
    """Run one chat exchange: show *text* as the user, query the VLM, show the reply.

    Both messages are appended to ``st.session_state.messages`` so they are
    replayed from history on later reruns.
    """
    # display user message in chat message container
    with st.chat_message("user"):
        st.markdown(text)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": text})
    # run the VLM
    response = runModel(text, image)
    # display assistant response in chat message container
    with st.chat_message("assistant"):
        st.markdown(response)
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})


### upload a file
uploaded_file = st.file_uploader("Choose an image", on_change=initialize)

# initialize() only runs as the uploader's on_change callback; make sure the
# history exists before it is read below so a session that never triggered the
# callback does not fail with AttributeError.
if "messages" not in st.session_state:
    st.session_state.messages = []

if uploaded_file:
    st.image(uploaded_file)
    image = Image.open(uploaded_file).convert("RGB")

    # tasks
    task = st.radio(
        "Task",
        ('Caption', 'OCR', 'Segment', 'Enter your prompt'),
        horizontal=True,
    )

    # display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if task == 'Enter your prompt':
        # free-form chat: the model runs only when the user submits text
        if prompt := st.chat_input("Type here!", key="question"):
            _chat_turn(prompt, image)
    else:
        # Preset task: the radio label itself is sent as the prompt. This runs
        # on every script rerun while a preset task is selected.
        # NOTE(review): PaliGemma checkpoints are usually prompted with
        # lowercase task prefixes (e.g. "caption en", "ocr", "segment <obj>");
        # confirm the raw labels here are the intended prompts.
        _chat_turn(task, image)