
Commit f04446d

Implementing support for dense rewards
1 parent 5c5abca commit f04446d

3 files changed: +120 -19 lines

3 files changed

+120
-19
lines changed

examples/ppo_redemption.py

+83
@@ -0,0 +1,83 @@
+# Generates positive movie reviews by tuning a pretrained model on IMDB dataset
+# with a sentiment reward function
+import json
+import os
+import sys
+from typing import List
+
+import torch
+from datasets import load_dataset
+from transformers import pipeline, AutoTokenizer
+
+import trlx
+from trlx.data.default_configs import TRLConfig, default_ppo_config
+
+
+def get_positive_score(scores):
+    "Extract value associated with a positive sentiment from pipeline's output"
+    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]
+
+def get_negative_score(scores):
+    return dict(map(lambda x: tuple(x.values()), scores))["NEGATIVE"]
+
+
+def main(hparams={}):
+    # Merge sweep config with default config if given
+    config = TRLConfig.update(default_ppo_config().to_dict(), hparams)
+    config.method.cliprange_reward = False
+    config.method.gen_kwargs["max_new_tokens"] = 70
+    config.method.gen_kwargs["temperature"] = 0.3
+    config.train.total_steps = 20000
+    config.train.checkpoint_interval = 10000000
+    # config.method.init_kl_coef = 0
+
+    if torch.cuda.is_available():
+        device = int(os.environ.get("LOCAL_RANK", 0))
+    else:
+        device = -1
+
+    sentiment_fn = pipeline(
+        "sentiment-analysis",
+        "lvwerra/distilbert-imdb",
+        top_k=2,
+        truncation=True,
+        batch_size=256,
+        device=device,
+    )
+
+    def dense_reward_fn(samples: List[str], prompts: List[str], outputs: List[str], model_tok, **kwargs) -> List[List[float]]:
+        # Reward positively for an initially negative, then positive review
+        # Reward functions should never receive padded text except for a single EOS at the end
+        # The reward function should return per-token rewards for just the response
+        # Note: To get the trajectory length, the reward fn should not tokenize the samples but should instead separately tokenize the prompts and outputs and then combine them
+        # Also note each output has a single EOS at its end
+        first_halves = [".".join(sample.split(".")[: len(sample.split(".")) // 2]) for sample in samples]
+        negative_first_halves = list(map(get_negative_score, sentiment_fn(first_halves)))
+        second_halves = [".".join(sample.split(".")[len(sample.split(".")) // 2 :]) for sample in samples]
+        positive_second_halves = list(map(get_positive_score, sentiment_fn(second_halves)))
+        text_scores = [[f, s] for f, s in zip(negative_first_halves, positive_second_halves)]
+        tok_scores = []
+        for sample, prompt, response, text_score in zip(samples, prompts, outputs, text_scores):
+            toks = model_tok(response).input_ids
+            tok_score = [0] * len(toks)
+            # Hacky way of assigning the intermediate score
+            tok_score[len(tok_score) // 2] = text_score[0]
+            tok_score[-1] = text_score[1]
+            tok_scores.append(tok_score)
+        return tok_scores
+
+    # Take a few words off of movie reviews as prompts
+    imdb = load_dataset("imdb", split="train+test")
+    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]
+
+    trlx.train(
+        reward_fn=dense_reward_fn,
+        prompts=prompts,
+        eval_prompts=["I don't know much about Hungarian underground"] * 256,
+        config=config,
+    )
+
+
+if __name__ == "__main__":
+    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
+    main(hparams)

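For reference, a minimal self-contained sketch (not part of the commit) of the shape contract dense_reward_fn follows: one reward list per sample, with one entry per response token, an intermediate score placed mid-response and the final score on the last token. The toy_tokenize helper below is a hypothetical whitespace stand-in for the model tokenizer, used only to show alignment.

from typing import List


def toy_tokenize(text: str) -> List[str]:
    # Hypothetical stand-in for model_tok(response).input_ids; only the token count matters here
    return text.split()


def toy_dense_reward_fn(outputs: List[str], text_scores: List[List[float]]) -> List[List[float]]:
    tok_scores = []
    for response, (neg_first, pos_second) in zip(outputs, text_scores):
        tok_score = [0.0] * len(toy_tokenize(response))
        tok_score[len(tok_score) // 2] = neg_first  # intermediate reward mid-response
        tok_score[-1] = pos_second                  # final reward on the last response token
        tok_scores.append(tok_score)
    return tok_scores


outputs = ["terrible start . but a wonderful ending", "flat and dull throughout"]
text_scores = [[0.9, 0.8], [0.2, 0.1]]
for response, rewards in zip(outputs, toy_dense_reward_fn(outputs, text_scores)):
    assert len(rewards) == len(toy_tokenize(response))
    print(rewards)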
trlx/trainer/accelerate_base_trainer.py

+3 -7
@@ -232,9 +232,7 @@ def decode(
             # or add one if it was trimmed with `self.stop_sequences`.
             # When a generation ended due to `max_new_tokens` exhaustion,
             # only then <pad> or <eos> token would not be present in the original sample at the end
-            if append_eos_token and (
-                trimmed or sample[-1] == self.tokenizer.eos_token_id or sample[-1] == self.tokenizer.pad_token_id
-            ):
+            if append_eos_token:
                 str_output += self.tokenizer.eos_token

             str_prompts.append(str_prompt)
@@ -427,10 +425,8 @@ def evaluate(self): # noqa: C901
             # in online setting, compute the reward for validation
             if self.reward_fn:
                 logger.info("Computing rewards")
-                rewards = torch.tensor(
-                    self.reward_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata),
-                    dtype=float,
-                )
+                rewards = self.reward_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, model_tok=self.tokenizer, **metadata)
+                rewards = torch.tensor([sum(r) if type(r) is list else r for r in rewards], dtype=float)
                 mean_reward = rewards.mean().item()
                 columns.append("reward")
             if not isinstance(rewards, list):
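For evaluation, the change above reduces a dense reward (a per-token list) back to one scalar per sample before logging, while leaving scalar reward functions untouched. A small sketch of that reduction, using the same expression as the diff:

import torch

# Mixed returns: dense per-token lists and plain scalars
rewards_from_fn = [[0.1, 0.0, 0.5], 0.7, [0.2, 0.2]]
rewards = torch.tensor([sum(r) if type(r) is list else r for r in rewards_from_fn], dtype=float)
print(rewards)               # per-sample scalars: 0.6, 0.7, 0.4
print(rewards.mean().item())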

trlx/trainer/accelerate_ppo_trainer.py

+34 -12
@@ -6,6 +6,7 @@
 
 import torch
 import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
 import transformers
 from torch.utils.data import DataLoader
 from transformers import AutoTokenizer
@@ -297,21 +298,24 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
                 )
 
                 rollout_score_time = time()
-                all_scores = torch.tensor(
-                    self.reward_fn(
-                        samples=all_str_samples, prompts=all_str_prompts, outputs=all_str_outputs, **metadata
-                    ),
-                    dtype=torch.float,
-                    device=device,
-                )
+                # reward_fn should return a list of per-token rewards for each sample
+                # NOTE: all_scores[0][i] is the reward due to token (action) i in prompt + response (b/c of how the KL is computed)
+                all_scores = self.reward_fn(samples=all_str_samples, prompts=all_str_prompts, outputs=all_str_outputs, model_tok=self.tokenizer, **metadata)
+                all_scores = [torch.tensor(score, dtype=torch.float, device=device).view(-1,) for score in all_scores]
+                # Right-pad the per-sample reward sequences with -1 (used as a padding sentinel downstream)
+                all_scores = pad_sequence(all_scores, batch_first=True, padding_value=-1)
+                max_len = torch.tensor(all_scores.shape[1], dtype=torch.long, device=device)
+
                 stats["time/rollout_score"] = time() - rollout_score_time
 
-                all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1).unbind())
+                all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1, max_len).unbind())
             else:
                 all_scores = None
+                max_len = torch.tensor(0, dtype=torch.long, device=device)
 
             if torch.distributed.is_initialized():
-                scores = torch.empty(len(samples), device=device)
+                torch.distributed.broadcast(max_len, 0)
+                scores = torch.empty((len(samples), max_len), device=device)
                 torch.distributed.scatter(scores, all_scores)
             else:
                 scores = all_scores[0].clone().detach()
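A single-process sketch (no torch.distributed) of the padding scheme used above: variable-length per-token score tensors are right-padded with -1 so they stack into one (num_samples, max_len) tensor that can be scattered, and the true lengths are recovered later by counting entries != -1. Note the implicit assumption, carried over from the diff, that a genuine token reward is never exactly -1.

import torch
from torch.nn.utils.rnn import pad_sequence

per_sample_scores = [
    torch.tensor([0.1, 0.5]),
    torch.tensor([0.3]),
    torch.tensor([0.0, 0.2, 0.9]),
]
padded = pad_sequence(per_sample_scores, batch_first=True, padding_value=-1)
max_len = padded.shape[1]                # would be broadcast to all ranks before scattering
lengths = (padded != -1).sum(dim=1)      # recovers tensor([2, 1, 3])
print(padded.shape, max_len, lengths)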
@@ -342,7 +346,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
 
             # store statistics of the initial rollout as reference
             if self.ref_mean is None:
-                self.ref_mean, self.ref_std = scores.mean(), scores.std()
+                self.ref_mean, self.ref_std = scores.sum(dim=1).mean(), scores.sum(dim=1).std()
             all_scores_mean, all_scores_std = self.running_moments.update(scores)
             stats["rollout_scores/mean"] = all_scores_mean.item()
             stats["rollout_scores/std"] = all_scores_std.item()
@@ -415,6 +419,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
                 logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:])
                 ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:])
             else:
+                # NOTE: logprobs[i] is the (log)probability with which all_tokens[i+1] was sampled
                 logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:])
                 ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:])

@@ -425,6 +430,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
                 attention_mask = sample_outputs != self.tokenizer.pad_token_id
                 start = 0
             else:
+                # NOTE: -1 because kl[prompt_tensors.shape[1]] is the KL at the second token of the response
                 start = prompt_tensors.shape[1] - 1
 
             log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1]
@@ -436,12 +442,16 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
             ref_logprobs = ref_logprobs.cpu()
             prompt_tensors = prompt_tensors.cpu()
             sample_outputs = sample_outputs.cpu()
+            # TODO(dahoas): Why [:, :-1]? Redundant with the clipping via start : ends[ix]?
+            # Actually I think it's just wrong?
             values = values.cpu()[:, :-1]
 
             # Get the logprobs and values, for tokens that are not padding,
-            # from the start of the prompt up to the <eos> token, while also including the latter
+            # from the end of the prompt up to the <eos> token, while also including the latter
             # (these are taken from the student model and not the reference model)
             ends = start + attention_mask[:, start:].sum(1) + 1
+            # NOTE: values[i] is the value of the state after response token i
+            # TODO(dahoas): Does it actually make sense to get the rewards one step early?
             all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
             all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)]

@@ -451,8 +461,20 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
             rollout_count = 0
 
             for sample_idx in range(n_samples):
+                # To compute the per-token reward, first add in the KL penalties over the trajectory
+                # NOTE: kl_penalty[i] is the KL difference at token i+1 in the output (w/o EOS)
                 rewards = kl_penalty[sample_idx]
-                rewards[-1] += scores[sample_idx].cpu()
+                # Then add in the rewards
+                if scores.shape[1] == 1:
+                    # NOTE: The final reward is given at the EOS token, following HHH practice
+                    rewards[-1] += scores[sample_idx][0].cpu()
+                else:
+                    score = scores[sample_idx]
+                    score_right_padding = torch.sum(score != -1)
+                    score = score[:score_right_padding].cpu()
+                    p_score = torch.zeros_like(rewards)
+                    p_score[:score.shape[0]] += score
+                    rewards += p_score
 
                 ppo_rl_elements.append(
                     PPORLElement(
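A standalone sketch of the per-token reward assembly in the loop above: start from the KL penalty per response token, then either add a single scalar score at the final token (the sparse case) or strip the -1 padding from the dense score row and add it position-wise. The tensors are illustrative values, not taken from a real rollout.

import torch

kl_penalty = torch.tensor([-0.01, -0.02, -0.01, -0.03])  # one entry per response token
score_row = torch.tensor([0.4, 0.0, 0.9, -1.0])          # dense scores, right-padded with -1

rewards = kl_penalty.clone()
if score_row.numel() == 1:
    # Sparse case: a single scalar reward lands on the final (EOS) token
    rewards[-1] += score_row[0]
else:
    score_right_padding = torch.sum(score_row != -1)
    score = score_row[:score_right_padding]
    p_score = torch.zeros_like(rewards)
    p_score[:score.shape[0]] += score
    rewards += p_score
print(rewards)  # tensor([ 0.3900, -0.0200,  0.8900, -0.0300])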
