import os

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq


class VisionLLM:
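    """Minimal wrapper around a Hugging Face vision-language model
    (PaliGemma 2 by default) that generates text descriptions of images."""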
|
    def __init__(self, device="cuda", model_id="google/paligemma2-3b-pt-224", token=None):
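        """Load the processor and model, resolving an auth token for gated checkpoints."""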
|
        self.device = device
        # Fall back to the HF_TOKEN environment variable; PaliGemma 2 is a
        # gated checkpoint, so a token is usually required. (`use_auth_token`
        # is deprecated in recent transformers releases in favor of `token`.)
        if token is None:
            token = os.environ.get("HF_TOKEN")
        self.processor = AutoProcessor.from_pretrained(model_id, token=token)
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_id, torch_dtype=torch.float16, token=token
        ).to(self.device)
|
    def describe_images(self, images, prompt="", max_new_tokens=128):
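        """Generate a caption for each image; returns a list of decoded strings."""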
|
        # Normalize to a list of RGB images so single images and batches
        # take the same path through the processor.
        if not isinstance(images, list):
            images = [images]
        images = [img.convert("RGB") for img in images]
        # Repeat the prompt so every image in the batch is paired with text.
        inputs = self.processor(
            images=images, text=[prompt] * len(images), return_tensors="pt"
        ).to(self.device)
        with torch.no_grad(), torch.autocast("cuda"):
            # max_new_tokens bounds only the generated text; max_length would
            # also count the prompt (including image tokens) and can truncate
            # generation to nothing.
            outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() echoes the prompt tokens, so slice them off before decoding.
        outputs = outputs[:, inputs["input_ids"].shape[-1]:]
        descriptions = self.processor.batch_decode(outputs, skip_special_tokens=True)
        return descriptions
|
|
if __name__ == "__main__":
    vllm = VisionLLM()
    images = [Image.open("multi_view_output.png")]
    prompt = "Describe the objects in the image"
    descriptions = vllm.describe_images(images, prompt)
    print(descriptions)
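    # A batched call would look like this (hypothetical filenames, assuming
    # the images exist on disk):
    # views = [Image.open(f"view_{i}.png") for i in range(4)]
    # print(vllm.describe_images(views, prompt))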