-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
78 lines (59 loc) · 2.55 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from Processors import ChatProcessor
import csv
# Evaluation driver: replays every query in PRIME_EVAL.csv through a
# ChatProcessor built for the model named on that row, classifies the reply
# as "yes" / "no" / "ambiguous", and appends one row per query to results.csv.

# System prompt installed as the only starting message for every query.
SYSTEM_PROMPT = (
    "You are an expert mathematician and careful reasoner "
    "that produces highly accurate results."
)

# Each model listed here is evaluated on at most MAX_ROWS_PER_MODEL input
# rows; rows naming any other model are always evaluated.
MAX_ROWS_PER_MODEL = 100
CAPPED_MODELS = (
    "gpt-3.5-turbo-0301",
    "gpt-3.5-turbo-0613",
    "gpt-4-0314",
    "gpt-4-0613",
)


def _classify_response(response_text):
    """Classify a model reply as "yes", "no" or "ambiguous".

    The match is case-insensitive.  A reply counts as "yes" when it contains
    "[yes]" or "is a prime number", and as "no" when it contains "[no]" or
    "is not a prime number".  Replies that hit neither marker set, or both,
    are "ambiguous".
    """
    lowered = response_text.lower()
    says_yes = "[yes]" in lowered or "is a prime number" in lowered
    says_no = "[no]" in lowered or "is not a prime number" in lowered
    if says_yes == says_no:
        # Neither marker or both markers present: no clear verdict.
        return "ambiguous"
    return "yes" if says_yes else "no"


# newline='' lets the csv module do its own newline handling, so line breaks
# embedded in quoted fields are read back unchanged.
with open('PRIME_EVAL.csv', 'r', newline='') as file:
    reader = csv.reader(file)

    # The first line holds the column names used to key every later row.
    headers = next(reader, None)
    if headers is None:
        raise SystemExit("PRIME_EVAL.csv is empty: no header row found")

    # The context manager closes results.csv even when a row fails midway,
    # so rows already written are flushed instead of being lost.
    with open('results.csv', 'a', newline='') as outfile:
        writer = csv.writer(outfile)

        # The file is opened in append mode, so only emit the header when it
        # is still empty; otherwise each re-run would insert a header row in
        # the middle of the accumulated data.
        if outfile.tell() == 0:
            writer.writerow(['query', 'model', 'response', 'ref_answer', 'extracted_answer'])

        # Number of input rows seen so far for each capped model.
        rows_seen = dict.fromkeys(CAPPED_MODELS, 0)

        for line in reader:
            # Pair each value with its column name.
            columns = dict(zip(headers, line))

            # Keep only the part after the last '/' of the model column.
            model = columns['model'].split('/')[-1]

            if model in rows_seen:
                rows_seen[model] += 1
                if rows_seen[model] > MAX_ROWS_PER_MODEL:
                    continue

            # A fresh processor per row; each message list is built separately
            # so the two attributes never share a list or dict object.
            p = ChatProcessor(model)
            p.start_messages = [{"role": "system", "content": SYSTEM_PROMPT}]
            p.messages = [{"role": "system", "content": SYSTEM_PROMPT}]
            p.CREATE_TITLES = False
            r = p.generate_response(columns['query'])

            classification = _classify_response(r['last_response'])
            writer.writerow([columns['query'], model, r['last_response'], columns['ref_answer'], classification])