diff --git a/dataflow/operators/core_vision/generate/batch_vqa_generator.py b/dataflow/operators/core_vision/generate/batch_vqa_generator.py
index 3194879..398bb1c 100644
--- a/dataflow/operators/core_vision/generate/batch_vqa_generator.py
+++ b/dataflow/operators/core_vision/generate/batch_vqa_generator.py
@@ -1,8 +1,18 @@
+import pandas as pd
+from typing import List
+
 from dataflow.utils.registry import OPERATOR_REGISTRY
-from dataflow.utils.storage import DataFlowStorage
+from dataflow.utils.storage import FileStorage, DataFlowStorage
 from dataflow.core import OperatorABC, LLMServingABC
 from dataflow import get_logger
-from qwen_vl_utils import process_vision_info
+
+from dataflow.serving.local_model_vlm_serving import LocalModelVLMServing_vllm
+from dataflow.serving.api_vlm_serving_openai import APIVLMServing_openai
+
+
+# Helper that checks whether the serving backend is an API serving instance
+def is_api_serving(serving):
+    return isinstance(serving, APIVLMServing_openai)


 @OPERATOR_REGISTRY.register()
@@ -30,7 +40,8 @@ def get_desc(lang: str = "zh"):
             "  - output_key: 生成的答案列表列 (List[str])\n"
             "功能特点:\n"
             "  - 自动进行广播 (Broadcasting),将单图映射到多个问题\n"
-            "  - 适用于由粗到细 (Coarse-to-Fine) 的密集描述生成场景\n"
+            "  - 统一支持 API 和本地 Local 模型部署模式\n"
+            "  - 支持全局批处理加速推理\n"
         )
     else:
         return (
@@ -43,52 +54,119 @@ def get_desc(lang: str = "zh"):
             "  - output_key: Column storing the list of generated answers\n"
             "Features:\n"
             "  - Automatically broadcasts one image to multiple prompts\n"
-            "  - Ideal for coarse-to-fine dense captioning scenarios\n"
+            "  - Unifies support for API and Local model deployment modes\n"
+            "  - Supports global batch processing for faster inference\n"
         )

     def run(self, storage: DataFlowStorage, input_prompts_key: str, input_image_key: str, output_key: str):
         self.logger.info(f"Running BatchVQAGenerator on {input_prompts_key}...")
-        df = storage.read("dataframe")
-
-        all_answers_nested = []
+        df: pd.DataFrame = storage.read("dataframe")
+        use_api_mode = is_api_serving(self.serving)
+        if use_api_mode:
+            self.logger.info("Using API serving mode")
+        else:
+            self.logger.info("Using local serving mode")
+
+        # 1. Flatten stage:
+        # flatten [[q1, q2], [q3]] into [q1, q2, q3] so one batched call to the model gets maximum concurrency
+        flat_conversations = []
+        flat_images = []
+        row_question_counts = []  # number of questions per row, used later to regroup the answers
+
         for idx, row in df.iterrows():
             questions = row.get(input_prompts_key, [])
             image_path = row.get(input_image_key)

-            if not questions or not isinstance(questions, list) or not image_path:
-                all_answers_nested.append([])
-                continue
+            # Normalize the image path into a list
+            if isinstance(image_path, str):
+                image_path = [image_path]
+            elif not image_path:
+                image_path = []
+
+            if not isinstance(questions, list):
+                questions = []
+
+            row_question_counts.append(len(questions))

-            batch_prompts = []
-            batch_images = []
-            for q in questions:
-                raw = [
-                    {"role": "system", "content": self.system_prompt},
-                    {"role": "user", "content": [
-                        {"type": "image", "image": image_path},
-                        {"type": "text", "text": q}
-                    ]}
-                ]
-                image_inputs, _ = process_vision_info(raw)
-                final_p = self.serving.processor.apply_chat_template(raw, tokenize=False, add_generation_prompt=True)
-
-                batch_prompts.append(final_p)
-                batch_images.append(image_inputs)
-
-            if not batch_prompts:
-                all_answers_nested.append([])
-                continue
+            for q in questions:
+                # Build the standard conversation format
+                if use_api_mode:
+                    # API mode only needs plain text; images are passed separately via image_list
+                    conversation = [{"role": "user", "content": q}]
+                else:
+                    # Local mode (e.g. vLLM) needs <image> placeholders manually prepended to the text
+                    img_tokens = "<image>" * len(image_path)
+                    conversation = [{"role": "user", "content": img_tokens + q}]
+
+                flat_conversations.append(conversation)
+                flat_images.append(image_path)

-            # 批量调用
-            row_answers = self.serving.generate_from_input(
+        # 2. Batch inference stage
+        if flat_conversations:
+            flat_outputs = self.serving.generate_from_input_messages(
+                conversations=flat_conversations,
+                image_list=flat_images,
                 system_prompt=self.system_prompt,
-                user_inputs=batch_prompts,
-                image_inputs=batch_images
             )
+        else:
+            flat_outputs = []
+
+        # 3. Regroup stage:
+        # regroup the flat outputs [a1, a2, a3] back into [[a1, a2], [a3]] using row_question_counts
+        all_answers_nested = []
+        current_idx = 0
+        for count in row_question_counts:
+            row_answers = flat_outputs[current_idx : current_idx + count]
             all_answers_nested.append(row_answers)
+            current_idx += count

         df[output_key] = all_answers_nested
-        storage.write(df)
-        return [output_key]
\ No newline at end of file
+        output_file = storage.write(df)
+
+        self.logger.info("Results saved to %s", output_file)
+        return [output_key]
+
+
+# ==========================================
+# Test case (main block)
+# ==========================================
+if __name__ == "__main__":
+    # Test using API serving mode
+    model = APIVLMServing_openai(
+        api_url="http://172.96.141.132:3001/v1",
+        key_name_of_api_key="DF_API_KEY",
+        model_name="gpt-5-nano-2025-08-07",
+        image_io=None,
+        send_request_stream=False,
+        max_workers=10,
+        timeout=1800
+    )
+
+    # To test a local model instead, uncomment:
+    # model = LocalModelVLMServing_vllm(
+    #     hf_model_name_or_path="Qwen/Qwen2.5-VL-3B-Instruct",
+    #     vllm_tensor_parallel_size=1,
+    #     ...
+    # )
+
+    generator = BatchVQAGenerator(
+        serving=model,
+        system_prompt="You are a helpful visual assistant."
+    )
+
+    storage = FileStorage(
+        first_entry_file_name="./dataflow/example/image_to_text_pipeline/sample_data.json",
+        cache_path="./cache_local",
+        file_name_prefix="batch_vqa",
+        cache_type="json",
+    )
+
+    storage.step()
+
+    generator.run(
+        storage=storage,
+        input_prompts_key="questions",  # input column assumed to hold a list of questions
+        input_image_key="image",
+        output_key="answers",           # output column of answer lists
+    )
\ No newline at end of file
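
The flatten/regroup bookkeeping that the new run() relies on can be checked in isolation. The following is a minimal, self-contained sketch of the same broadcast pattern; broadcast_vqa and fake_batch are hypothetical names invented for illustration and are not part of DataFlow:

    from typing import Callable, List, Tuple

    def broadcast_vqa(rows: List[Tuple[List[str], str]],
                      answer_batch_fn: Callable[[List[str], List[str]], List[str]]) -> List[List[str]]:
        # Flatten [[q1, q2], [q3]] into [q1, q2, q3], remembering per-row counts
        flat_questions, flat_images, counts = [], [], []
        for questions, image_path in rows:
            questions = questions if isinstance(questions, list) else []
            counts.append(len(questions))
            for q in questions:
                flat_questions.append(q)        # one entry per (question, image) pair
                flat_images.append(image_path)  # the single image is broadcast to every question
        # One batched call over the flat lists (or none if every row was empty)
        flat_answers = answer_batch_fn(flat_questions, flat_images) if flat_questions else []
        # Regroup [a1, a2, a3] back into [[a1, a2], [a3]] using the saved counts
        nested, idx = [], 0
        for count in counts:
            nested.append(flat_answers[idx:idx + count])
            idx += count
        return nested

    if __name__ == "__main__":
        fake_batch = lambda qs, imgs: [f"answer:{q}" for q in qs]
        rows = [(["q1", "q2"], "a.jpg"), (["q3"], "b.jpg"), ([], "c.jpg")]
        print(broadcast_vqa(rows, fake_batch))
        # [['answer:q1', 'answer:q2'], ['answer:q3'], []]

Because row_question_counts records exactly how many flat entries each row contributed, the prefix sums are enough to reassemble the batched outputs, including empty lists for rows with no questions.
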
diff --git a/dataflow/operators/core_vision/generate/fix_prompted_vqa_generator.py b/dataflow/operators/core_vision/generate/fix_prompted_vqa_generator.py
index 8a74ad1..df4a22a 100644
--- a/dataflow/operators/core_vision/generate/fix_prompted_vqa_generator.py
+++ b/dataflow/operators/core_vision/generate/fix_prompted_vqa_generator.py
@@ -1,18 +1,25 @@
 import pandas as pd
+from typing import List
+
 from dataflow.utils.registry import OPERATOR_REGISTRY
+from dataflow.utils.storage import FileStorage, DataFlowStorage
+from dataflow.core import OperatorABC, LLMServingABC
 from dataflow import get_logger
-from dataflow.utils.storage import FileStorage, DataFlowStorage
-from dataflow.core import OperatorABC
-from dataflow.core import LLMServingABC
 from dataflow.serving.local_model_vlm_serving import LocalModelVLMServing_vllm
-from qwen_vl_utils import process_vision_info
+from dataflow.serving.api_vlm_serving_openai import APIVLMServing_openai
+
+
+# Helper that checks whether the serving backend is an API serving instance
+def is_api_serving(serving):
+    return isinstance(serving, APIVLMServing_openai)
+

 @OPERATOR_REGISTRY.register()
 class FixPromptedVQAGenerator(OperatorABC):
-    '''
+    """
     FixPromptedVQAGenerator generate answers for questions based on provided context.
     The context can be image/video.
-    '''
+    """
     def __init__(self,
                  serving: LLMServingABC,
                  system_prompt: str = "You are a helpful assistant.",
@@ -24,38 +31,17 @@ def __init__(self,

     @staticmethod
     def get_desc(lang: str = "zh"):
-        return "基于给定的 system prompt 和 user prompt,并读取 image/video 生成答案" if lang == "zh" else "Generate answers for questions based on provided context. The context can be image/video."
-
-    def _prepare_batch_inputs(self, input_media_paths, is_image: bool = True):
-        """
-        Construct batched prompts and multimodal inputs from media paths.
-        """
-        prompt_list = []
-        media_paths = []
-        type_media = "image" if is_image else "video"
-
-        for paths in input_media_paths:
-            raw_prompt = [
-                {"role": "system", "content": self.system_prompt},
-                {
-                    "role": "user",
-                    "content": [
-                    ],
-                },
-            ]
-            for path in paths:
-                raw_prompt[1]["content"].append({"type": type_media, type_media: path})
-            raw_prompt[1]["content"].append({"type": "text", "text": self.user_prompt})
-            # Get multimodal inputs
-            media_path, _ = process_vision_info(raw_prompt)
-
-            prompt = self.serving.processor.apply_chat_template(
-                raw_prompt, tokenize=False, add_generation_prompt=True
+        if lang == "zh":
+            return (
+                "固定提示词视觉问答生成算子 (FixPromptedVQAGenerator)。\n"
+                "基于给定的 system prompt 和 user prompt,读取 image/video 生成答案。\n\n"
+                "特点:\n"
+                "  - 支持图像或视频模态\n"
+                "  - 统一支持 API 和本地 Local 模型部署模式\n"
+                "  - 自动管理底层的
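
By analogy with the __main__ block added to batch_vqa_generator.py above, a smoke test for this operator would presumably look like the sketch below. The run() signature, the user_prompt argument, and the column names are assumptions (the visible hunk does not show the operator's new run()), not confirmed API:

    # Hypothetical smoke test; mirrors the BatchVQAGenerator main block above.
    from dataflow.operators.core_vision.generate.fix_prompted_vqa_generator import FixPromptedVQAGenerator
    from dataflow.serving.api_vlm_serving_openai import APIVLMServing_openai
    from dataflow.utils.storage import FileStorage

    if __name__ == "__main__":
        model = APIVLMServing_openai(
            api_url="http://172.96.141.132:3001/v1",  # same test endpoint as above
            key_name_of_api_key="DF_API_KEY",
            model_name="gpt-5-nano-2025-08-07",
            image_io=None,
            send_request_stream=False,
            max_workers=10,
            timeout=1800,
        )
        generator = FixPromptedVQAGenerator(
            serving=model,
            system_prompt="You are a helpful assistant.",
            user_prompt="Describe the image in detail.",  # assumed parameter, suggested by self.user_prompt in the removed code
        )
        storage = FileStorage(
            first_entry_file_name="./dataflow/example/image_to_text_pipeline/sample_data.json",
            cache_path="./cache_local",
            file_name_prefix="fix_prompted_vqa",
            cache_type="json",
        )
        storage.step()
        # Assumed signature, by analogy with BatchVQAGenerator.run
        generator.run(storage=storage, input_image_key="image", output_key="answer")
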