Let's generalize the easy_gcg code to optimize prompts on a dataset of (x, y) pairs, where each x is the question and y is the answer.
We want to solve u := argmax_u E [P(y | u + x)] where the expectation is taken over the dataset (x, y) ~ D.
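As a concrete reading of this objective, assuming D is a finite dataset of N pairs (the size N and the index i are notation introduced here, not taken from the code), the expectation becomes an empirical average:

```latex
\hat{u} \;=\; \arg\max_{u} \; \frac{1}{N} \sum_{i=1}^{N} P\left(y_i \mid u + x_i\right),
\qquad (x_i, y_i) \in D
```

Here u + x_i denotes the prompt tokens u concatenated in front of the question tokens x_i.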
We can start by simply aggregating gradients for the swaps in GCG over multiple elements of the batch (see `def stochastic_easy_gcg_qa_ids(question_ids: list[torch.Tensor],` at Magic_Words/magic_words/easy_gcg.py, line 178 in 32840cd).
All that remains is to create an efficient `batch_compute_score_dataset()` function to compute the scores of each potential new prompt w.r.t. the dataset (see `alt_scores = batch_compute_score_dataset(alt_prompt_ids,` at Magic_Words/magic_words/easy_gcg.py, line 263 in 32840cd).
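To make the two steps concrete, here is a minimal sketch of one dataset-level GCG step on a toy model. It is not the code in easy_gcg.py: `ToyCausalLM`, `answer_logprobs`, `aggregated_token_grads`, and `score_prompts_on_dataset` are hypothetical names invented for this illustration, and the sketch assumes the mean log-likelihood log P(y | u + x) is used as a surrogate for the objective above.

```python
import torch
import torch.nn.functional as F


class ToyCausalLM(torch.nn.Module):
    """Hypothetical stand-in for a causal LM (not the real model)."""

    def __init__(self, vocab_size=16, dim=8):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, dim)
        self.head = torch.nn.Linear(dim, vocab_size)

    def forward_embeds(self, embeds):
        # Causal running mean over positions: [B, T, dim] -> logits [B, T, V].
        steps = torch.arange(1, embeds.shape[1] + 1).view(1, -1, 1)
        return self.head(embeds.cumsum(dim=1) / steps)


def answer_logprobs(model, prompt_embeds, question_ids, answer_ids):
    """log P(y | u + x) for a batch of prompts given as embeddings [C, k, dim]."""
    rest = model.embed(torch.cat([question_ids, answer_ids]))        # [n, dim]
    rest = rest.unsqueeze(0).expand(prompt_embeds.shape[0], -1, -1)  # [C, n, dim]
    embeds = torch.cat([prompt_embeds, rest], dim=1)                 # [C, T, dim]
    logprobs = F.log_softmax(model.forward_embeds(embeds), dim=-1)   # [C, T, V]
    n_ans = answer_ids.shape[0]
    # Logits at position t predict token t + 1, so shift by one.
    pred = logprobs[:, -n_ans - 1 : -1]                              # [C, n_ans, V]
    idx = answer_ids.view(1, -1, 1).expand(pred.shape[0], -1, -1)
    return pred.gather(2, idx).squeeze(2).sum(dim=1)                 # [C]


def aggregated_token_grads(model, prompt_ids, dataset):
    """Gradient of the mean negative log-likelihood w.r.t. one-hot prompt tokens."""
    vocab_size = model.embed.num_embeddings
    one_hot = F.one_hot(prompt_ids, vocab_size).float().requires_grad_(True)
    prompt_embeds = (one_hot @ model.embed.weight).unsqueeze(0)      # [1, k, dim]
    losses = [-answer_logprobs(model, prompt_embeds, x, y)[0] for x, y in dataset]
    loss = torch.stack(losses).mean()  # averaging aggregates over the dataset
    return torch.autograd.grad(loss, one_hot)[0]                     # [k, V]


@torch.no_grad()
def score_prompts_on_dataset(model, alt_prompt_ids, dataset):
    """Mean log P(y | u + x) over the dataset for each candidate prompt [C, k]."""
    prompt_embeds = model.embed(alt_prompt_ids)                      # [C, k, dim]
    per_pair = [answer_logprobs(model, prompt_embeds, x, y) for x, y in dataset]
    return torch.stack(per_pair).mean(dim=0)                         # [C]


# One GCG-style step on made-up data.
torch.manual_seed(0)
model = ToyCausalLM()
dataset = [
    (torch.tensor([1, 2, 3]), torch.tensor([4, 5])),
    (torch.tensor([6, 7]), torch.tensor([8])),
]
prompt_ids = torch.tensor([0, 0, 0])

grads = aggregated_token_grads(model, prompt_ids, dataset)
top_tokens = (-grads).topk(4, dim=1).indices                         # [k, n_top]

# Build one candidate per (position, token) swap, plus the unchanged prompt.
k, n_top = top_tokens.shape
alt_prompt_ids = prompt_ids.repeat(k * n_top, 1)
positions = torch.arange(k).repeat_interleave(n_top)
alt_prompt_ids[torch.arange(k * n_top), positions] = top_tokens.reshape(-1)
alt_prompt_ids = torch.cat([prompt_ids.unsqueeze(0), alt_prompt_ids])

alt_scores = score_prompts_on_dataset(model, alt_prompt_ids, dataset)
best_prompt_ids = alt_prompt_ids[alt_scores.argmax()]
```

Because the unchanged prompt is kept among the candidates, the selected prompt never scores worse on the dataset than the current one. With a real model, the per-pair loop in `score_prompts_on_dataset` would likely need padding and chunking to fit in memory, which this sketch leaves out.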