|
from transformers import PreTrainedModel, PretrainedConfig |
|
|
|
from .config import MoondreamConfig |
|
from .moondream import MoondreamModel |
|
|
|
|
|
from .image_crops import * |
|
from .vision import * |
|
from .text import * |
|
from .region import * |
|
from .utils import * |
|
|
|
|
|
def extract_question(text):
    """Extract the question from a legacy moondream prompt template.

    Returns the question substring when *text* has the exact shape
    ``"<image>\\n\\nQuestion: {question}\\n\\nAnswer:"``; otherwise ``None``.
    """
    prefix = "<image>\n\nQuestion: "
    suffix = "\n\nAnswer:"

    # Guard clause: anything not matching the template is not a legacy prompt.
    if not (text.startswith(prefix) and text.endswith(suffix)):
        return None
    return text[len(prefix) : -len(suffix)]
|
|
|
|
|
class HfConfig(PretrainedConfig):
    """Minimal ``PretrainedConfig`` shim for moondream checkpoints.

    Carries the nested moondream configuration as a plain dict on
    ``self.config``; initialized empty here and presumably populated when
    the config is deserialized from a checkpoint — confirm against the
    saved ``config.json``.
    """

    _auto_class = "AutoConfig"
    model_type = "moondream1"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Raw moondream config dict, consumed by HfMoondream.__init__.
        self.config = dict()
|
|
|
|
|
class HfMoondream(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``MoondreamModel``.

    Exposes the modern moondream API (``query``, ``caption``, ``detect``,
    ``point``, ``detect_gaze``, ``encode_image``) by delegating to the
    underlying ``MoondreamModel``, and keeps a handful of legacy entry
    points (``answer_question``, ``batch_answer``, ``generate``) for
    backwards compatibility with older revisions.
    """

    _auto_class = "AutoModelForCausalLM"
    config_class = HfConfig

    def __init__(self, config):
        super().__init__(config)
        # config.config is the raw moondream config dict stored by HfConfig.
        self.model = MoondreamModel(MoondreamConfig.from_dict(config.config))

    # The properties below surface bound methods of the inner model so that
    # callers can use e.g. `hf_model.query(...)` directly.
    @property
    def encode_image(self):
        return self.model.encode_image

    @property
    def query(self):
        return self.model.query

    @property
    def caption(self):
        return self.model.caption

    @property
    def detect(self):
        return self.model.detect

    @property
    def point(self):
        return self.model.point

    @property
    def detect_gaze(self):
        return self.model.detect_gaze

    def answer_question(
        self,
        image_embeds,
        question,
        tokenizer=None,
        chat_history="",
        result_queue=None,
        max_new_tokens=256,
        **kwargs
    ):
        """Legacy QA entry point; delegates to ``self.query``.

        ``tokenizer``, ``chat_history``, ``max_new_tokens`` and ``kwargs``
        are accepted only for signature compatibility and are ignored.
        Returns the stripped answer string; if ``result_queue`` is given,
        the answer is also pushed onto it.
        """
        answer = self.query(image_embeds, question)["answer"].strip()

        if result_queue is not None:
            result_queue.put(answer)
        return answer

    def batch_answer(self, images, prompts, tokenizer=None, **kwargs):
        """Legacy batched QA; answers each (image, prompt) pair in turn.

        ``tokenizer`` and ``kwargs`` are ignored (kept for compatibility).
        Note: pairs are processed sequentially, not as a true batch.
        """
        return [
            self.query(image, prompt)["answer"].strip()
            for image, prompt in zip(images, prompts)
        ]

    def _unsupported_exception(self):
        """Raise for legacy APIs that no longer exist in this revision."""
        raise NotImplementedError(
            "This method is not supported in the latest version of moondream. "
            "Consider upgrading to the updated API spec, or alternately pin "
            "to 'revision=2024-08-26'."
        )

    def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs):
        """Legacy generation entry point, kept for backwards compatibility.

        ``tokenizer`` and ``kwargs`` are ignored. If *prompt* matches the
        legacy "<image>...Question:...Answer:" template, the extracted
        question is answered via ``self.model.query`` (``max_new_tokens``
        is ignored on this path); otherwise the raw prompt is tokenized and
        completed directly, honoring ``max_new_tokens``.

        Returns a single-element list containing the answer string.
        """
        question = extract_question(prompt)
        if question is not None:
            answer = self.model.query(
                image=image_embeds, question=question, stream=False
            )["answer"]
        else:
            # Raw-prompt path: encode the image, tokenize the prompt, and
            # stream tokens from the inner model's text generator.
            encoded = self.encode_image(image_embeds)
            # NOTE(review): `torch` is not imported by name in this file;
            # presumably it enters the namespace via the wildcard imports
            # above — confirm.
            prompt_tokens = torch.tensor(
                [self.model.tokenizer.encode(prompt).ids],
                device=self.device,
            )
            # `_generate_text` yields token strings; join them lazily —
            # no need for an intermediate generator wrapper or list.
            answer = "".join(
                self.model._generate_text(
                    prompt_tokens, encoded.kv_cache, encoded.pos, max_new_tokens
                )
            )

        return [answer]

    def get_input_embeddings(self):
        # Deliberate pass-through to the transformers base implementation.
        return super().get_input_embeddings()

    def input_embeds(self, *args, **kwargs):
        """Removed legacy API — always raises ``NotImplementedError``."""
        self._unsupported_exception()
|
|