Commit
support vertex format
semio committed Jan 14, 2025
1 parent 2f92f66 commit 0afcccb
Showing 1 changed file with 92 additions and 20 deletions: automation-api/lib/pilot/generate_eval_prompts.py
@@ -1,12 +1,19 @@
import argparse
import json
import os
from enum import Enum
from typing import Dict, List, Tuple

import polars as pl

from lib.app_singleton import AppSingleton


class JsonlFormat(Enum):
OPENAI = "openai"
VERTEX = "vertex"


logger = AppSingleton().get_logger()


@@ -157,7 +164,8 @@ def generate_eval_prompts(
output_path: str,
model: str = "gpt-4",
temperature: float = 0.0,
format: JsonlFormat = JsonlFormat.OPENAI,
) -> List[Tuple[str, str]]:
"""
Generate evaluation prompts for each response and metric.
@@ -169,6 +177,8 @@
model: Model to use for evaluation
temperature: Temperature setting for generation
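        format: Output JSONL format (OpenAI Batch API or Vertex batch prediction)

    Returns:
        List of (custom_id, eval_prompt) tuples for mapping batch results
        back to their prompts.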
"""
    prompt_id_mapping: List[Tuple[str, str]] = []

with open(output_path, "w", encoding="utf-8") as f:
for metric_row in metrics.iter_rows(named=True):
prompt_template = metric_row["prompt"]
@@ -197,26 +207,65 @@
option_c_correctness=question_row["option_c_correctness"],
)

custom_id = f"{prompt_id}-eval-{metric_id}"
prompt_id_mapping.append((custom_id, eval_prompt))

if format == JsonlFormat.OPENAI:
# Create evaluation request object for OpenAI
eval_request = {
"model": model,
"messages": [{"role": "user", "content": eval_prompt}],
"temperature": temperature,
"max_tokens": 2000,
}

# Create the full request object with custom ID
request_obj = {
"custom_id": custom_id,
"method": "POST",
"url": "/v1/chat/completions",
"body": eval_request,
}
                else:  # Vertex AI batch prediction format
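                    # Vertex batch requests carry no custom_id field; the
                    # (custom_id, prompt) mapping collected above is saved
                    # separately so batch outputs can be matched back to
                    # their prompts.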
request_obj = {
"request": {
"contents": [
{
"role": "user",
"parts": [{"text": eval_prompt}],
}
],
"generationConfig": {
"temperature": temperature,
"max_output_tokens": 2000,
},
"safety_settings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_ONLY_HIGH",
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_ONLY_HIGH",
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_ONLY_HIGH",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_ONLY_HIGH",
},
],
}
}

# Write to output file
json_line = json.dumps(request_obj, ensure_ascii=False)
f.write(f"{json_line}\n")

return prompt_id_mapping


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate evaluation prompts")
@@ -244,6 +293,13 @@ def generate_eval_prompts(
default=0.0,
help="Temperature setting for generation",
)
parser.add_argument(
"--format",
type=str,
choices=[f.value for f in JsonlFormat],
default=JsonlFormat.OPENAI.value,
help="Format of JSONL output (openai or vertex)",
)
args = parser.parse_args()

# Construct input paths
@@ -265,20 +321,36 @@
# Read responses
responses = read_responses(args.response_file)

    # Generate output path based on response file and model
response_basename = os.path.splitext(os.path.basename(args.response_file))[0]
    # Clean up model name for use in the filename
    model_suffix = args.model.replace("/", "-").replace(".", "")
    output_path = os.path.join(
        args.base_path, f"{response_basename}-eval-prompts-{model_suffix}.jsonl"
    )

    # Generate evaluation prompts
    prompt_id_mapping = generate_eval_prompts(
combined_questions,
responses,
metrics,
output_path,
model=args.model,
temperature=args.temperature,
format=JsonlFormat(args.format),
)

print(f"Generated evaluation prompts in {output_path}")

# Save prompt ID mapping if using Vertex format
if args.format == JsonlFormat.VERTEX.value:
mapping_df = pl.DataFrame(
{
"prompt_id": [x[0] for x in prompt_id_mapping],
"prompt_text": [x[1] for x in prompt_id_mapping],
}
)
mapping_path = output_path.replace(".jsonl", "-prompt-mapping.csv")
mapping_df.write_csv(mapping_path)
print(f"Generated prompt ID mapping in {mapping_path}")
