Skip to content

Commit

Permalink
add anthropic API
Browse files Browse the repository at this point in the history
  • Loading branch information
semio committed Jan 14, 2025
1 parent a6301de commit 4c97c50
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 90 deletions.
136 changes: 136 additions & 0 deletions automation-api/lib/llm/anthropic_batch_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import json
import logging
from typing import Any, Dict, Optional

import anthropic

from lib.app_singleton import AppSingleton
from lib.config import read_config

# Shared application logger; DEBUG level so batch progress is visible.
logger = AppSingleton().get_logger()
logger.setLevel(logging.DEBUG)

# Initialize Anthropic client
read_config()  # FIXME: don't do read config in lib code
# NOTE(review): presumably read_config() loads the API key into the
# environment for anthropic.Anthropic() to pick up — verify in read_config.
client = anthropic.Anthropic()


def send_batch_file(jsonl_path: str) -> str:
    """
    Send a batch of prompts to the Anthropic Message Batches API.

    The input file uses the OpenAI batch JSONL format: one JSON object per
    line with a "custom_id" and a "body" holding "model", "messages" and
    optional "max_tokens"/"temperature".

    Args:
        jsonl_path: Path to JSONL file containing prompts

    Returns:
        Batch ID for tracking the job

    Raises:
        Exception: Re-raised after logging if file parsing or the API
            call fails.
    """
    try:
        # Read and parse the JSONL file (one request object per line).
        with open(jsonl_path, "r", encoding="utf-8") as f:
            requests = [json.loads(line) for line in f]

        # Convert to Anthropic format.  Plain dicts are used because the
        # SDK's Request / MessageCreateParamsNonStreaming types are
        # TypedDicts; the previous `client.messages.batches.Request`
        # attribute does not exist on the batches resource and raised
        # AttributeError at runtime.
        anthropic_requests = []
        for req in requests:
            body = req["body"]
            # Assumes a single user message per request — TODO confirm
            # against the upstream prompt-file generator.
            content = body["messages"][0]["content"]

            anthropic_requests.append(
                {
                    "custom_id": req["custom_id"],
                    "params": {
                        "model": body["model"],
                        "max_tokens": body.get("max_tokens", 2048),
                        "temperature": body.get("temperature", 0.01),
                        "messages": [{"role": "user", "content": content}],
                    },
                }
            )

        # Send batch
        batch = client.messages.batches.create(requests=anthropic_requests)
        logger.info(f"Created Anthropic batch with ID: {batch.id}")
        return batch.id

    except Exception as e:
        logger.error(f"Error sending batch to Anthropic: {str(e)}")
        raise


def check_batch_job_status(batch_id: str) -> str:
    """
    Look up the current processing status of an Anthropic batch job.

    Args:
        batch_id: The batch ID to check

    Returns:
        Current processing status
    """
    try:
        batch_info = client.messages.batches.retrieve(batch_id)
        return batch_info.processing_status
    except Exception as e:
        logger.error(f"Error checking batch status: {str(e)}")
        raise


def simplify_anthropic_response(response_data: Any) -> Dict[str, Any]:
    """
    Reduce a raw Anthropic batch result entry to a flat dictionary.

    Args:
        response_data: Raw response from Anthropic API

    Returns:
        Dict with keys "custom_id", "status", "content" and "error".
        "content" is filled on success, "error" on failure; any other
        status leaves both as None.
    """
    result = response_data.result
    status = result.type

    content = None
    error = None
    if status == "succeeded":
        # First content block is assumed to hold the text answer —
        # TODO confirm for multi-block responses.
        content = result.message.content[0].text
    elif status == "errored":
        error = {
            "type": result.error.type,
            "message": str(result.error),
        }

    return {
        "custom_id": response_data.custom_id,
        "status": status,
        "content": content,
        "error": error,
    }


def download_batch_job_output(batch_id: str, output_path: str) -> Optional[str]:
    """
    Download batch results, simplify each entry, and write them as JSONL.

    Args:
        batch_id: The batch ID to download
        output_path: Path to save results (one simplified JSON object per line)

    Returns:
        Path to the output file

    Raises:
        Exception: Re-raised after logging if fetching or writing fails.
    """
    try:
        with open(output_path, "w", encoding="utf-8") as out_file:
            # Iterate results entry by entry rather than loading them all
            # at once — presumably paginated/streamed by the SDK; verify.
            for result in client.messages.batches.results(batch_id):
                # Convert to dict and simplify
                simplified = simplify_anthropic_response(result)

                out_file.write(json.dumps(simplified, ensure_ascii=False) + "\n")

        logger.info(f"Saved Anthropic batch results to {output_path}")
        return output_path

    except Exception as e:
        logger.error(f"Error downloading batch results: {str(e)}")
        raise
102 changes: 12 additions & 90 deletions automation-api/lib/pilot/send_batch_prompt_anthropic.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,22 @@
import argparse
import json
import logging
import os
import re
import time
from typing import List

import anthropic
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request

from lib.app_singleton import AppSingleton
from lib.config import read_config
from lib.llm.anthropic_batch_api import (
check_batch_job_status,
download_batch_job_output,
send_batch_file,
)

# Shared application logger; DEBUG level for verbose batch progress.
logger = AppSingleton().get_logger()
logger.setLevel(logging.DEBUG)

# Initialize config before constructing the Anthropic client —
# NOTE(review): presumably read_config() exports the API key to the
# environment for anthropic.Anthropic(); verify in read_config.
read_config()
client = anthropic.Anthropic()


def read_jsonl_requests(jsonl_path: str) -> List[dict]:
    """Parse a JSONL file into a list of request dictionaries (one per line)."""
    with open(jsonl_path, "r") as infile:
        return [json.loads(line) for line in infile]


def convert_to_anthropic_requests(requests: List[dict]) -> List[Request]:
    """Convert OpenAI format requests to Anthropic Request objects."""
    converted = []
    for request in requests:
        body = request["body"]
        # Assuming single user message
        user_content = body["messages"][0]["content"]

        # Create Anthropic Request object
        params = MessageCreateParamsNonStreaming(
            model=body["model"],
            max_tokens=body.get("max_tokens", 2048),
            temperature=body.get("temperature", 0.01),
            messages=[{"role": "user", "content": user_content}],
        )
        converted.append(Request(custom_id=request["custom_id"], params=params))
    return converted


def process_batch(requests: List[Request]) -> str:
    """Submit batch requests to the Anthropic API and return the batch ID."""
    created = client.messages.batches.create(requests=requests)
    logger.info(f"Created batch with ID: {created.id}")
    return created.id


def get_batch_status(batch_id: str) -> str:
    """Return the processing status currently reported for a batch."""
    return client.messages.batches.retrieve(batch_id).processing_status


def save_batch_results(batch_id: str, output_path: str) -> None:
    """Write each batch result to *output_path* as one JSON object per line."""
    with open(output_path, "w") as out:
        for item in client.messages.batches.results(batch_id):
            outcome = item.result
            record = {
                "custom_id": item.custom_id,
                "result_type": outcome.type,
            }

            if outcome.type == "succeeded":
                record["content"] = outcome.message.content[0].text
                record["error"] = None
            elif outcome.type == "errored":
                record["content"] = None
                record["error"] = {
                    "type": outcome.error.type,
                    "message": str(outcome.error),
                }

            out.write(f"{json.dumps(record, ensure_ascii=False)}\n")

    logger.info(f"Saved batch results to {output_path}")


if __name__ == "__main__":
Expand All @@ -116,7 +41,7 @@ def save_batch_results(batch_id: str, output_path: str) -> None:
args = parser.parse_args()

try:
# Extract model_config_id from input filename first
# Extract model_config_id from input filename
base_name = os.path.basename(args.input_jsonl)
match = re.match(r"^(.*?)-question_prompts\.jsonl$", base_name)
if not match:
Expand All @@ -140,12 +65,8 @@ def save_batch_results(batch_id: str, output_path: str) -> None:
batch_id = f.read().strip()
print(f"Existing batch ID: {batch_id}")
else:
# Read and convert requests
openai_requests = read_jsonl_requests(args.input_jsonl)
anthropic_requests = convert_to_anthropic_requests(openai_requests)

# Process batch
batch_id = process_batch(anthropic_requests)
# Send batch
batch_id = send_batch_file(args.input_jsonl)

# Save batch ID to processing file
with open(processing_file, "w") as f:
Expand All @@ -154,11 +75,12 @@ def save_batch_results(batch_id: str, output_path: str) -> None:

if args.wait:
while True:
status = get_batch_status(batch_id)
status = check_batch_job_status(batch_id)
logger.info(f"Batch status: {status}")

if status == "ended":
save_batch_results(batch_id, output_path)
# Download results
download_batch_job_output(batch_id, output_path)
print(f"Results saved to: {output_path}")
break
else:
Expand Down

0 comments on commit 4c97c50

Please sign in to comment.