-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
263 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
import json | ||
import logging | ||
import os | ||
from datetime import datetime | ||
from typing import Any, Dict, Optional | ||
|
||
from openai import OpenAI | ||
|
||
from lib.app_singleton import AppSingleton | ||
from lib.config import read_config | ||
|
||
# Module-level logger obtained from the application singleton.
logger = AppSingleton().get_logger()
logger.setLevel(logging.DEBUG)

read_config()  # FIXME: maybe I should read config in application code, not lib code.
# Alibaba DashScope exposes an OpenAI-compatible endpoint; the OpenAI client
# presumably picks up its API key from the environment / config read above —
# TODO confirm.
client = OpenAI(base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")

# Statuses that indicate the batch is still processing
PROCESSING_STATUSES = {"validating", "in_progress", "finalizing"}
|
||
|
||
def send_batch_file(jsonl_path: str, endpoint: str = "/v1/chat/completions") -> str:
    """
    Send a JSONL file to OpenAI's batch API.

    Args:
        jsonl_path: Path to the JSONL file containing prompts
        endpoint: OpenAI API endpoint to use

    Returns:
        The batch ID (as assigned by the API) for tracking the request
    """
    # Upload the JSONL file. Use a context manager so the handle is closed
    # even if the upload raises (the original leaked the open file object).
    with open(jsonl_path, "rb") as jsonl_file:
        batch_input_file = client.files.create(file=jsonl_file, purpose="batch")
    batch_input_file_id = batch_input_file.id

    # Generate a timestamp-based label for our own bookkeeping.
    # NOTE(review): this local label is only stored in metadata; callers get
    # the API-assigned batch.id, not this value.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    batch_id = f"batch-{timestamp}"

    # Create the batch job referencing the uploaded file.
    batch = client.batches.create(
        input_file_id=batch_input_file_id,
        endpoint=endpoint,
        completion_window="24h",
        metadata={
            "batch_id": batch_id,
            "source_file": os.path.basename(jsonl_path),
        },
    )

    logger.info(f"Created batch with ID: {batch.id}")
    return batch.id
|
||
|
||
def check_batch_job_status(batch_id: str) -> str:
    """
    Look up a batch job and report its current status string.

    Args:
        batch_id: The batch ID to check

    Returns:
        Current status of the batch job
    """
    return client.batches.retrieve(batch_id).status
|
||
|
||
def simplify_openai_response(response_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Simplify OpenAI batch response format to keep only essential information.

    Args:
        response_data: Raw response data from OpenAI batch API

    Returns:
        Simplified response dictionary containing only essential fields
        (custom_id, status_code, content, error)
    """
    # Guard against "response": null, which appears on failed request lines;
    # the original crashed with AttributeError here and aborted the download.
    response = response_data.get("response") or {}

    simplified = {
        "custom_id": response_data.get("custom_id"),
        "status_code": response.get("status_code"),
        "content": None,
        "error": response_data.get("error"),
    }

    # Extract the first choice's message content if the body has one.
    try:
        choices = response.get("body", {}).get("choices", [])
        if choices:
            simplified["content"] = choices[0]["message"]["content"]
    except (KeyError, TypeError, AttributeError):
        # Malformed body shapes are tolerated; content stays None.
        pass

    return simplified
|
||
|
||
def download_batch_job_output(batch_id: str, output_path: str) -> Optional[str]:
    """
    Download and simplify results for a completed batch job.

    Args:
        batch_id: The batch ID to download results for
        output_path: Path to save results file

    Returns:
        Path to the downloaded results file if successful, None if batch not completed
    """
    # Get batch info
    batch = client.batches.retrieve(batch_id)

    if batch.status != "completed":
        logger.error(f"Cannot download results - batch status is {batch.status}")
        return None

    # Download the raw results to a sibling temp file, then write the
    # simplified copy to output_path.
    temp_output = f"{output_path}.temp"
    client.files.content(batch.output_file_id).write_to_file(temp_output)

    try:
        # Process and simplify the results line by line (output is JSONL).
        with open(temp_output, "r", encoding="utf-8") as raw_file, open(
            output_path, "w", encoding="utf-8"
        ) as out_file:
            for line in raw_file:
                try:
                    response_data = json.loads(line)
                    simplified = simplify_openai_response(response_data)
                    out_file.write(json.dumps(simplified, ensure_ascii=False) + "\n")
                except json.JSONDecodeError as e:
                    # Skip malformed lines instead of aborting the whole file.
                    logger.error(f"Error processing line: {e}")
                    continue
    finally:
        # Always clean up the temp file, even if simplification raised
        # (the original left it behind on error).
        os.remove(temp_output)

    logger.info(f"Saved simplified batch results to {output_path}")
    return output_path
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import argparse | ||
import logging | ||
import os | ||
import time | ||
from typing import Optional | ||
|
||
from lib.app_singleton import AppSingleton | ||
from lib.config import read_config | ||
from lib.llm.alibaba_batch_api import ( | ||
PROCESSING_STATUSES, | ||
check_batch_job_status, | ||
download_batch_job_output, | ||
send_batch_file, | ||
) | ||
|
||
# Shared application logger from the singleton.
logger = AppSingleton().get_logger()
logger.setLevel(logging.DEBUG)

# Load configuration at import time.
# NOTE(review): lib.llm.alibaba_batch_api also calls read_config() when it is
# imported above, so this call may be redundant — confirm.
read_config()
|
||
|
||
def get_batch_id_and_output_path(jsonl_path: str) -> str:
    """
    Derive the response-file path from an input JSONL filename.

    Args:
        jsonl_path: Path to input JSONL file

    Returns:
        Path where the response file will be saved: same directory as the
        input, named "<stem>-response.jsonl".
    """
    directory, filename = os.path.split(jsonl_path)
    stem, _ext = os.path.splitext(filename)
    return os.path.join(directory, f"{stem}-response.jsonl")
|
||
|
||
def send_batch_to_openai(jsonl_path: str) -> str:
    """
    Submit a JSONL prompt file to OpenAI's batch API.

    A "<output>.processing" marker file caches the batch ID so that
    re-running the script does not upload the same file twice.

    Args:
        jsonl_path: Path to the JSONL file containing prompts

    Returns:
        The batch ID for tracking the request
    """
    output_path = get_batch_id_and_output_path(jsonl_path)
    processing_file = f"{output_path}.processing"

    # Reuse the cached batch ID if an earlier run already submitted this file.
    if os.path.exists(processing_file):
        logger.info("Batch already being processed.")
        with open(processing_file, "r") as marker:
            return marker.read().strip()

    # First submission: upload and create the batch.
    batch_id = send_batch_file(jsonl_path)

    # Persist the batch ID so subsequent runs skip the upload.
    with open(processing_file, "w") as marker:
        marker.write(batch_id)
    logger.info("Batch created successfully.")

    return batch_id
|
||
|
||
def wait_for_batch_completion(
    batch_id: str, output_path: str, poll_interval: float = 60
) -> Optional[str]:
    """
    Wait for a batch to complete and download results when ready.

    Args:
        batch_id: The batch ID to monitor
        output_path: Path to save results file
        poll_interval: Seconds to sleep between status checks (default 60;
            previously a hard-coded constant)

    Returns:
        Path to results file if successful, None if batch failed/cancelled
    """
    logger.info(f"Waiting for batch {batch_id} to complete...")
    while True:
        status = check_batch_job_status(batch_id)
        logger.info(f"Batch status: {status}")

        if status == "completed":
            # Download and simplify the results.
            return download_batch_job_output(batch_id, output_path)
        elif status in PROCESSING_STATUSES:
            # Still processing; wait before polling again.
            time.sleep(poll_interval)
        else:
            # Terminal non-success state (e.g. failed, cancelled, expired).
            logger.error(f"Batch {batch_id} ended with status: {status}")
            return None
|
||
|
||
if __name__ == "__main__":
    # CLI entry point: submit one JSONL file, optionally poll for results.
    parser = argparse.ArgumentParser(
        description="Send JSONL prompts to OpenAI batch API"
    )
    parser.add_argument(
        "jsonl_file", type=str, help="Path to the JSONL file containing prompts"
    )
    parser.add_argument(
        "--wait",
        action="store_true",
        help="Wait for batch completion and download results",
    )

    args = parser.parse_args()

    try:
        # Submit the batch; returns a cached ID if a .processing marker
        # already exists for this input file.
        batch_id = send_batch_to_openai(args.jsonl_file)
        if args.wait:
            # Block until the batch finishes, then download simplified results.
            output_path = get_batch_id_and_output_path(args.jsonl_file)
            result_path = wait_for_batch_completion(batch_id, output_path)
            if result_path:
                print(f"Results saved to: {result_path}")
        else:
            print(f"Batch ID: {batch_id}")
    except Exception as e:
        # Log, then re-raise so the process exits non-zero on failure.
        logger.error(f"Error sending batch: {str(e)}")
        raise