-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathsl_main.py
236 lines (191 loc) · 10.2 KB
/
sl_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import streamlit as st
import pandas as pd
import numpy as np
from pb import create_population, init_run, run_for_n
from pb.mutation_prompts import mutation_prompts
from pb.thinking_styles import thinking_styles
import os
import logging
import argparse
import asyncio
import decimal
import time
from dotenv import load_dotenv
from rich import print
import cohere
load_dotenv()  # load environment variables
st.set_page_config(layout="wide")

# Init state
# Each entry is the value a session-state slot receives when it does not exist
# yet; slots that are already present are left exactly as they are.
_SESSION_DEFAULTS = {
    'population': None,
    'size': 0,
    'evals': 0,
    'calls': 0,
    'generations': 0,
    'fitness_history': [],
    'elite_fitness_history': [],
    'current_generation': 0,
    'running': False,
    'histogram_data': {},
    # Use the key from the environment when one is set, otherwise start blank.
    'COHERE_API_KEY': os.environ.get('COHERE_API_KEY', ""),
}
for _slot, _initial in _SESSION_DEFAULTS.items():
    if _slot not in st.session_state:
        st.session_state[_slot] = _initial
# Wrap the two imported collections in DataFrames so they can be handed to the
# table widgets further down.
ts_df = pd.DataFrame(thinking_styles)   # thinking styles (T)
mp_df = pd.DataFrame(mutation_prompts)  # mutation prompts (M)
st.title('PromptBreeder + Cohere')

# Introductory blurb rendered under the title. The text has no placeholders,
# so it is kept as a plain (non-f) string and passed to st.markdown verbatim.
_INTRO_MARKDOWN = """"PROMPTBREEDER, a general-purpose self-referential self-improvement mechanism that evolves and adapts prompts for a given domain.
Driven by an LLM, Promptbreeder mutates a population of task-prompts, evaluates them for fitness on a training set, and repeats this process
over multiple generations to evolve task-prompts. Crucially, the mutation of these task-prompts is governed by mutation-prompts that the LLM
generates and improves throughout evolution in a self-referential way." - https://arxiv.org/pdf/2309.16797.pdf
We start by picking mutation prompts (M) and thinking styles (T). From that, a initial population of task-prompts (P) is generated according
to \n
P = LLM("[T] [M] INSTRUCTION: [problem_description] INSTRUCTION MUTANT = ")
Then, the fitness level of these task-prompts is evaluated against a random sample of N questions from the gsm8k dataset (any dataset can be used).
N is determined below, in the "number of examples to evaluate for fitness calculation" input.
Each generation is run through a standard binary tournament genetic algorithm. Basically, randomly sample 2 units from the population, and
compare fitness levels. Whoever has the lower fitness loses, and has a random mutation applied to them.
There are 12 mutations outlined by the promptbreeder paper. Only 9 are implemented here, as a few are ambiguous. Will add a "experimental mode"
where you can use those extra 3 mutations, but I can't promise they will be exactly the same as DeepMind's implementations.
"""
st.markdown(_INTRO_MARKDOWN)
# Task statement that is later passed to create_population as problem_description.
problem_description = st.text_input("problem description", value="Solve the math word problem, giving your answer as an arabic numeral.", key="pd")

# The password box starts empty. Assigning its value unconditionally would
# replace a key loaded from the environment with "", so an empty box falls
# back to the COHERE_API_KEY environment variable instead.
_typed_api_key = st.text_input("Cohere PROD key", key="ch", type='password')
st.session_state.COHERE_API_KEY = _typed_api_key or os.environ.get('COHERE_API_KEY', "")

col1, col2, = st.columns(2)
with col1:
    # At least one example is required: this value is also used as the Cohere
    # client's worker count and as the divisor for the histogram buckets.
    st.session_state.evals = st.number_input("Number of examples to evaluate for fitness calculation", value=4, min_value=1)
with col2:
    # A negative generation count would make the call/cost estimates negative.
    st.session_state.generations = st.number_input("Number of generations to run for", value=5, min_value=0)
def dataframe_with_selections(mp_df, ts_df):
    """Show both tables with a tick-box column and return the ticked rows.

    A boolean "Select" column (initially False) is prepended to a copy of each
    frame, the copies are shown side by side in ``st.data_editor`` widgets with
    every original column disabled for editing, and each edited frame is
    filtered down to the rows whose "Select" value is set.

    Args:
        mp_df: DataFrame of mutation prompts; not modified.
        ts_df: DataFrame of thinking styles; not modified.

    Returns:
        A ``(mp_selected_rows, ts_selected_rows)`` tuple. The returned frames
        still contain the "Select" column.
    """

    def _pick_rows(slot, heading, frame):
        # Work on a copy so the caller's frame never gains the extra column.
        selectable = frame.copy()
        selectable.insert(0, "Select", False)
        with slot:
            st.header(heading)
            edited = st.data_editor(
                selectable,
                hide_index=True,
                column_config={"Select": st.column_config.CheckboxColumn(required=True)},
                disabled=frame.columns,
            )
        # Boolean-mask on the tick-box column keeps only the chosen rows.
        return edited[edited.Select]

    left, right = st.columns(2)
    chosen_mutations = _pick_rows(left, "mutation prompts (M)", mp_df)
    chosen_styles = _pick_rows(right, "thinking styles (T)", ts_df)
    return chosen_mutations, chosen_styles
# Render the two selection tables and collect whatever the user ticked.
mp_selected_rows, ts_selected_rows = dataframe_with_selections(mp_df, ts_df)

# Population size is the product of the two selection counts.
_pair_count = len(mp_selected_rows) * len(ts_selected_rows)
st.session_state.size = _pair_count

# Call estimate: (size * evals + size // 2) per generation, times generations.
_calls_per_generation = _pair_count * st.session_state.evals + _pair_count // 2
st.session_state.calls = _calls_per_generation * st.session_state.generations

# The run button stays disabled until at least one row is ticked in each table.
_run_label = f"run for {st.session_state.generations} generations"
second_button = st.button(_run_label, disabled=_pair_count <= 0)
def _summarise_fitness(units, evals, size):
    """Summarise the fitness scores of one generation.

    Args:
        units: iterable of population members; each must provide
            ``model_dump()`` returning a mapping with a ``'fitness'`` entry
            (assumed numeric -- it is fed to ``decimal.Decimal``).
        evals: number of fitness examples; one histogram bucket is pre-seeded
            for every ``i / evals`` with ``i`` in ``range(evals)``.
        size: divisor used for the average.

    Returns:
        ``(histogram, average, elite)`` where ``histogram`` maps each score,
        rounded to two decimals and rendered as a string, to the number of
        units with that score; ``average`` is the sum of the rounded scores
        divided by ``size``; ``elite`` is the highest rounded score above 0,
        or 0 when there is none.
    """
    two_places = decimal.Decimal('0.00')
    # Pre-seed the buckets so scores nobody achieved still appear with count 0.
    histogram = {
        str(decimal.Decimal(i / evals).quantize(two_places)): 0
        for i in range(evals)
    }
    total = decimal.Decimal(0)
    elite = 0
    for unit in units:
        score = decimal.Decimal(unit.model_dump()['fitness']).quantize(two_places)
        total += score
        if score > elite:
            elite = float(score)
        bucket = str(score)
        # Scores outside the pre-seeded buckets (e.g. "1.00") are added on demand.
        histogram[bucket] = histogram.get(bucket, 0) + 1
    return histogram, float(total) / size, elite


def _record_generation_stats():
    """Store the current population's fitness summary in session state."""
    histogram, average, elite = _summarise_fitness(
        st.session_state.population.units,
        st.session_state.evals,
        st.session_state.size,
    )
    st.session_state.histogram_data = histogram
    st.session_state.elite_fitness_history.append(elite)
    st.session_state.fitness_history.append(average)


if second_button:
    # Start every run from a clean slate. Without this, a second click would
    # find current_generation already equal to generations (so the loop below
    # would never execute) and would append to the previous run's histories.
    st.session_state.current_generation = 0
    st.session_state.fitness_history = []
    st.session_state.elite_fitness_history = []
    st.session_state.histogram_data = {}

    st.session_state.population = create_population(tp_set=ts_selected_rows['0'].tolist(), mutator_set=mp_selected_rows['0'].tolist(), problem_description=problem_description)
    st.session_state.size = st.session_state.population.size
    # Same estimate formula as the one used before the button and in the sidebar.
    st.session_state.calls = (st.session_state.size*st.session_state.evals + st.session_state.size // 2)*st.session_state.generations
    st.session_state.start_time = time.time()
    st.session_state.running = True
    try:
        # Override the 2 min timeout with 60s. The API's performance varies heavily.
        co = cohere.Client(api_key=st.session_state.COHERE_API_KEY, num_workers=st.session_state.evals, max_retries=5, timeout=60)

        # NOTE(review): init_run presumably scores the freshly created
        # population -- confirm in pb. Its result is summarised as the first
        # history entry, before any generation has been run.
        st.session_state.population = init_run(st.session_state.population, co, st.session_state.evals)
        _record_generation_stats()

        outputs = st.container()
        with outputs:
            while st.session_state.current_generation < st.session_state.generations:
                # Advance by exactly one generation so the charts can be drawn
                # after each step.
                st.session_state.population = run_for_n(1, st.session_state.population, co, st.session_state.evals)
                st.session_state.current_generation += 1
                _record_generation_stats()

                # Each generation appends a new set of charts below the last.
                st.header(f"Population {st.session_state.current_generation} Histogram")
                st.bar_chart(data=st.session_state.histogram_data)
                col1, col2, = st.columns(2)
                with col1:
                    st.header("Historical fitness average")
                    st.line_chart(data=st.session_state.fitness_history)
                with col2:
                    st.header("Historical elite fitness")
                    st.line_chart(data=st.session_state.elite_fitness_history)
                st.header(f"Population {st.session_state.current_generation}")
                st.dataframe(pd.DataFrame([s.model_dump() for s in st.session_state.population.units]))
    finally:
        # Clear the flag even when the run is interrupted or an API call fails.
        st.session_state.running = False
# Sidebar: summary of the configured run plus rough runtime and cost estimates.
with st.sidebar:
    # Call estimate: (size * evals + size // 2) per generation, times generations.
    _size = st.session_state.size
    _evals = st.session_state.evals
    _generations = st.session_state.generations
    st.session_state.calls = (_size * _evals + _size // 2) * _generations
    _calls = st.session_state.calls

    st.title("Population Information")
    st.header("problem description")
    st.text(problem_description)
    st.metric("Population Size", _size)
    st.metric("Fitness evals", _evals)
    st.metric("Generations", _generations)
    st.metric("Calls", _calls)
    # The estimates scale the call count by fixed per-call factors:
    # 1.17 for the runtime in seconds and 0.00234 for the cost in dollars.
    st.metric("Approximate runtime", f"{round(_calls * 1.17, 2)}s")
    st.metric("Approximate cost", f"${round(_calls * 0.00234, 2)}")

    st.title("Current Information")
    st.metric("Current generation", str(st.session_state.current_generation))