Skip to content

Commit

Permalink
add alibaba batch api
Browse files Browse the repository at this point in the history
  • Loading branch information
semio committed Jan 20, 2025
1 parent 07f561b commit dc3a233
Show file tree
Hide file tree
Showing 2 changed files with 263 additions and 0 deletions.
139 changes: 139 additions & 0 deletions automation-api/lib/llm/alibaba_batch_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import json
import logging
import os
from datetime import datetime
from typing import Any, Dict, Optional

from openai import OpenAI

from lib.app_singleton import AppSingleton
from lib.config import read_config

# Module-wide logger; DEBUG level so long-running batch jobs emit progress detail.
logger = AppSingleton().get_logger()
logger.setLevel(logging.DEBUG)

read_config() # FIXME: maybe I should read config in application code, not lib code.
# Alibaba DashScope's OpenAI-compatible endpoint; presumably the API key is
# picked up from the environment populated by read_config() — TODO confirm.
client = OpenAI(base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")

# Statuses that indicate the batch is still processing
PROCESSING_STATUSES = {"validating", "in_progress", "finalizing"}


def send_batch_file(jsonl_path: str, endpoint: str = "/v1/chat/completions") -> str:
    """
    Upload a JSONL file and create a batch job on the Alibaba DashScope
    OpenAI-compatible batch API.
    Args:
        jsonl_path: Path to the JSONL file containing prompts
        endpoint: API endpoint the batch should target
    Returns:
        The provider-assigned batch ID for tracking the request
    """
    # Upload the JSONL file. Use a context manager so the file handle is
    # closed even if the upload raises (the previous version leaked it).
    with open(jsonl_path, "rb") as jsonl_file:
        batch_input_file = client.files.create(file=jsonl_file, purpose="batch")

    # Human-readable identifier stored in metadata only; callers must use the
    # provider's batch.id (returned below) to track the job.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    local_batch_id = f"batch-{timestamp}"

    # Create the batch
    batch = client.batches.create(
        input_file_id=batch_input_file.id,
        endpoint=endpoint,
        completion_window="24h",
        metadata={
            "batch_id": local_batch_id,
            "source_file": os.path.basename(jsonl_path),
        },
    )

    logger.info(f"Created batch with ID: {batch.id}")
    return batch.id


def check_batch_job_status(batch_id: str) -> str:
    """
    Look up a batch job and report its current status string.
    Args:
        batch_id: The batch ID to check
    Returns:
        Current status of the batch job
    """
    return client.batches.retrieve(batch_id).status


def simplify_openai_response(response_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Simplify an OpenAI-format batch response record to essential fields.
    Args:
        response_data: Raw response record from the batch output file
    Returns:
        Dict with custom_id, status_code, content (first choice's message
        content, or None when absent) and error.
    """
    # The "response" key may be present but explicitly null for failed
    # requests; `or {}` guards the chained .get() calls against that
    # (a plain .get("response", {}) would return None and raise
    # AttributeError on the next .get()).
    response = response_data.get("response") or {}

    simplified = {
        "custom_id": response_data.get("custom_id"),
        "status_code": response.get("status_code"),
        "content": None,
        "error": response_data.get("error"),
    }

    # Extract the first choice's content when the body has the expected shape;
    # any deviation (missing keys, non-dict nodes) leaves content as None.
    try:
        choices = (response.get("body") or {}).get("choices", [])
        if choices:
            simplified["content"] = choices[0]["message"]["content"]
    except (KeyError, TypeError):
        pass

    return simplified


def download_batch_job_output(batch_id: str, output_path: str) -> Optional[str]:
    """
    Download and simplify results for a completed batch job.
    Args:
        batch_id: The batch ID to download results for
        output_path: Path to save the simplified JSONL results file
    Returns:
        Path to the downloaded results file if successful, None if the batch
        is not in "completed" status
    """
    # Get batch info and refuse to download anything not yet completed.
    batch = client.batches.retrieve(batch_id)

    if batch.status != "completed":
        logger.error(f"Cannot download results - batch status is {batch.status}")
        return None

    # Raw provider output lands in a temporary sibling file first.
    temp_output = f"{output_path}.temp"

    try:
        # Download raw results file
        client.files.content(batch.output_file_id).write_to_file(temp_output)

        # Process and simplify the results line by line (output is JSONL).
        with open(temp_output, "r", encoding="utf-8") as raw_file, open(
            output_path, "w", encoding="utf-8"
        ) as out_file:
            for line in raw_file:
                try:
                    response_data = json.loads(line)
                    simplified = simplify_openai_response(response_data)
                    out_file.write(json.dumps(simplified, ensure_ascii=False) + "\n")
                except json.JSONDecodeError as e:
                    # Skip malformed lines rather than aborting the whole download.
                    logger.error(f"Error processing line: {e}")
                    continue
    finally:
        # Always remove the raw download, even if simplification raised
        # mid-way, so a retry never reads a stale or partial temp file.
        if os.path.exists(temp_output):
            os.remove(temp_output)

    logger.info(f"Saved simplified batch results to {output_path}")
    return output_path
124 changes: 124 additions & 0 deletions automation-api/lib/pilot/send_batch_prompt_alibaba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import argparse
import logging
import os
import time
from typing import Optional

from lib.app_singleton import AppSingleton
from lib.config import read_config
from lib.llm.alibaba_batch_api import (
PROCESSING_STATUSES,
check_batch_job_status,
download_batch_job_output,
send_batch_file,
)

# Shared logger; DEBUG so status polling during long batch runs is visible.
logger = AppSingleton().get_logger()
logger.setLevel(logging.DEBUG)

# Load configuration (presumably API credentials) before any batch call — TODO confirm.
read_config()


def get_batch_id_and_output_path(jsonl_path: str) -> str:
    """
    Derive the response-file path for an input JSONL file.
    Args:
        jsonl_path: Path to input JSONL file
    Returns:
        Path where the response file will be saved: same directory,
        "<stem>-response.jsonl"
    """
    directory, filename = os.path.split(jsonl_path)
    stem, _ext = os.path.splitext(filename)
    return os.path.join(directory, f"{stem}-response.jsonl")


def send_batch_to_openai(jsonl_path: str) -> str:
    """
    Submit a JSONL prompt file as a batch job, reusing a cached batch ID
    from a ".processing" marker file when one already exists.
    Args:
        jsonl_path: Path to the JSONL file containing prompts
    Returns:
        The batch ID for tracking the request
    """
    output_path = get_batch_id_and_output_path(jsonl_path)
    processing_file = f"{output_path}.processing"

    # A marker file means this input was already submitted; return its cached
    # batch ID instead of uploading the same batch twice.
    if os.path.exists(processing_file):
        logger.info("Batch already being processed.")
        with open(processing_file, "r") as marker:
            return marker.read().strip()

    # First submission for this input: send it and persist the returned ID
    # so later invocations skip the upload.
    batch_id = send_batch_file(jsonl_path)

    with open(processing_file, "w") as marker:
        marker.write(batch_id)
    logger.info("Batch created successfully.")

    return batch_id


def wait_for_batch_completion(
    batch_id: str, output_path: str, poll_interval: int = 60
) -> Optional[str]:
    """
    Poll a batch job until it finishes, then download its results.
    Args:
        batch_id: The batch ID to monitor
        output_path: Path to save results file
        poll_interval: Seconds to sleep between status checks (default 60)
    Returns:
        Path to results file if successful, None if batch failed/cancelled
    """
    logger.info(f"Waiting for batch {batch_id} to complete...")
    while True:
        status = check_batch_job_status(batch_id)
        logger.info(f"Batch status: {status}")

        if status == "completed":
            # Download results
            return download_batch_job_output(batch_id, output_path)
        elif status in PROCESSING_STATUSES:
            # Still processing, wait before checking again
            time.sleep(poll_interval)
        else:
            # Any other status (failed, expired, cancelled, ...) is terminal.
            logger.error(f"Batch {batch_id} ended with status: {status}")
            return None


if __name__ == "__main__":
    # CLI entry point: submit a JSONL prompt file as a batch job and
    # optionally block until results are downloaded.
    parser = argparse.ArgumentParser(
        description="Send JSONL prompts to OpenAI batch API"
    )
    parser.add_argument(
        "jsonl_file", type=str, help="Path to the JSONL file containing prompts"
    )
    parser.add_argument(
        "--wait",
        action="store_true",
        help="Wait for batch completion and download results",
    )

    args = parser.parse_args()

    try:
        # Submit (or resume, via the .processing marker) the batch job.
        batch_id = send_batch_to_openai(args.jsonl_file)
        if args.wait:
            # Blocking mode: poll until completion and report the result path.
            output_path = get_batch_id_and_output_path(args.jsonl_file)
            result_path = wait_for_batch_completion(batch_id, output_path)
            if result_path:
                print(f"Results saved to: {result_path}")
        else:
            # Fire-and-forget mode: print the ID for later status checks.
            print(f"Batch ID: {batch_id}")
    except Exception as e:
        # Log then re-raise so the process exits non-zero on failure.
        logger.error(f"Error sending batch: {str(e)}")
        raise

0 comments on commit dc3a233

Please sign in to comment.