diff --git a/.gitignore b/.gitignore index 0e5ac793..3768286e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ .venv -__pycache__ \ No newline at end of file +__pycache__ +/huggingface +/.vs diff --git a/moondream/moondream.py b/moondream/moondream.py index 8127e83d..a17e69b4 100644 --- a/moondream/moondream.py +++ b/moondream/moondream.py @@ -85,6 +85,7 @@ def answer_question( image_embeds, question, tokenizer, + max_new_tokens, chat_history="", result_queue=None, **kwargs, diff --git a/openai_api_demo.py b/openai_api_demo.py new file mode 100644 index 00000000..54cd358e --- /dev/null +++ b/openai_api_demo.py @@ -0,0 +1,351 @@ +import time +import uvicorn +import argparse + +import torch +from transformers import TextIteratorStreamer, CodeGenTokenizerFast as Tokenizer +from sse_starlette.sse import EventSourceResponse + +from loguru import logger +from typing import List, Literal, Union, Tuple, Optional + +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware + +from pydantic import BaseModel, Field + +import requests +import base64 +from PIL import Image +from io import BytesIO +import re +from threading import Thread + +from moondream import Moondream, detect_device +from contextlib import asynccontextmanager + +# 请求 +class TextContent(BaseModel): + type: Literal["text"] + text: str +class ImageUrl(BaseModel): + url: str +class ImageUrlContent(BaseModel): + type: Literal["image_url"] + image_url: ImageUrl +ContentItem = Union[TextContent, ImageUrlContent] +class ChatMessageInput(BaseModel): + role: Literal["user", "assistant", "system"] + content: Union[str, List[ContentItem]] + name: Optional[str] = None +class ChatCompletionRequest(BaseModel): + model: str + messages: List[ChatMessageInput] + temperature: Optional[float] = 0.8 + top_p: Optional[float] = 0.8 + max_tokens: Optional[int] = None + stream: Optional[bool] = False + # Additional parameters + repetition_penalty: Optional[float] = 1.0 + +# 响应 +class 
# Image input handling
_HTTP_PREFIXES = ("http://", "https://")
_DATA_URI_PREFIX = "data:image/"


def _decode_image_bytes(raw):
    """Decode encoded image bytes (PNG, JPEG, ...) into an RGB PIL image."""
    return Image.open(BytesIO(raw)).convert("RGB")


def process_img(input_data, *, timeout=10.0):
    """Load an image from a URL, a base64 data URI, a file path, or a PIL image.

    NOTE(review): ``process_history_and_images`` passes ``image_url.url``
    values taken directly from API requests into this function. The URL
    branch therefore lets a client make the server fetch arbitrary addresses,
    and the fallback path branch lets a client open arbitrary local files.
    Restrict both before exposing the service to untrusted callers.

    Args:
        input_data: One of
            * a string starting with ``http://`` or ``https://`` (downloaded),
            * a string starting with ``data:image/`` (base64 payload after
              the first comma),
            * any other string (treated as a local file path),
            * a ``PIL.Image.Image`` instance.
        timeout: Seconds to wait for the HTTP download. Without a timeout
            ``requests`` can block forever on an unresponsive host.

    Returns:
        The image as a ``PIL.Image.Image`` in RGB mode.

    Raises:
        ValueError: ``input_data`` is neither a string nor a PIL image, or a
            data URI carries no base64 payload.
        requests.RequestException: The download failed, timed out, or the
            server answered with an HTTP error status.
    """
    if isinstance(input_data, str):
        if input_data.startswith(_HTTP_PREFIXES):
            response = requests.get(input_data, timeout=timeout)
            # Fail on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            return _decode_image_bytes(response.content)

        if input_data.startswith(_DATA_URI_PREFIX):
            _, separator, payload = input_data.partition(",")
            if not separator or not payload:
                raise ValueError("data type error: data URI has no base64 payload")
            return _decode_image_bytes(base64.b64decode(payload))

        # Anything else is treated as a local path. Converting inside the
        # context manager forces the pixel data to load before the file is
        # closed and gives the same RGB mode as the other branches.
        with Image.open(input_data) as opened:
            return opened.convert("RGB")

    if isinstance(input_data, Image.Image):
        if input_data.mode == "RGB":
            return input_data
        return input_data.convert("RGB")

    raise ValueError("data type error")


# Chat history handling
class ChatHistoryError(ValueError, AssertionError):
    """Raised when the message list is not a valid user/assistant exchange.

    These conditions used to be reported with ``assert False``, which is
    stripped under ``python -O``. The class also derives from
    ``AssertionError`` so existing ``except AssertionError`` handlers keep
    working.
    """


def process_history_and_images(
    messages: "List[ChatMessageInput]",
) -> "Tuple[Optional[str], Optional[str], Optional[List[Image.Image]]]":
    """Split chat messages into the current question, prior history and images.

    Args:
        messages: Messages in conversation order. ``content`` is either a
            plain string or a list of ``TextContent`` / ``ImageUrlContent``
            items; text items of one message are joined with a single space.

    Returns:
        A ``(last_user_texts, history, image_list)`` tuple:

        * ``last_user_texts``: text of the final message when it has role
          ``user``, otherwise ``""``.
        * ``history``: earlier turns rendered as
          ``"Question: ...\\n\\nAnswer: ...\\n\\n"`` blocks, in order. A user
          turn with no assistant reply gets an empty answer.
        * ``image_list``: every image referenced by any message, in order of
          appearance, loaded through ``process_img``.

    Raises:
        ChatHistoryError: An assistant message has no preceding user turn,
            the preceding turn already has an answer, or a role other than
            ``user`` / ``assistant`` is used.
    """
    last_user_texts = ""
    turns = []  # (question, answer) pairs; answer stays "" until a reply arrives
    image_list = []
    final_index = len(messages) - 1

    for index, message in enumerate(messages):
        role = message.role
        content = message.content

        if isinstance(content, list):
            text_content = " ".join(
                item.text for item in content if isinstance(item, TextContent)
            )
            image_list.extend(
                process_img(item.image_url.url)
                for item in content
                if isinstance(item, ImageUrlContent)
            )
        else:
            text_content = content

        if role == "user":
            if index == final_index:
                # The trailing user message is the question to answer now.
                last_user_texts = text_content
            else:
                turns.append((text_content, ""))
        elif role == "assistant":
            if not turns:
                raise ChatHistoryError("assistant reply before user")
            question, answer = turns[-1]
            if answer != "":
                raise ChatHistoryError(
                    f"the last texts is answered. answer again. "
                    f"{question}, {answer}, {text_content}"
                )
            turns[-1] = (question, text_content)
        else:
            raise ChatHistoryError(f"unrecognized role: {role}")

    history = "".join(
        f"Question: {question}\n\nAnswer: {answer}\n\n" for question, answer in turns
    )
    return last_user_texts, history, image_list
"repetition_penalty": repetition_penalty, + "do_sample": False, + "top_p": top_p, + "streamer": streamer, + } + if temperature > 1e-5: + gen_kwargs["temperature"] = temperature + + thread = Thread( + target=model.answer_question, + kwargs=gen_kwargs, + ) + + input_echo_len = 0 + total_len = 0 + # 启动推理 + thread.start() + buffer = "" + for new_text in streamer: + clean_text = re.sub("<$|END$", "", new_text) + buffer += clean_text + yield { + "text": buffer.strip("= 8: + torch_type = torch.bfloat16 + else: + torch_type = torch.float16 + + print("========Use torch type as:{} with device:{}========\n\n".format(torch_type, DEVICE)) + + load_mod(MODEL_PATH) + + uvicorn.run(app, host='0.0.0.0', port=8000, workers=1) diff --git a/openapi_requirements.txt b/openapi_requirements.txt new file mode 100644 index 00000000..32da27aa --- /dev/null +++ b/openapi_requirements.txt @@ -0,0 +1,6 @@ +sse-starlette>=1.8.2 +fastapi>=0.105.0 +loguru~=0.7.2 +uvicorn~=0.24.0 +requests +pydantic>=2.5.2 \ No newline at end of file