import torch
import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
import transformers
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
@@ -297,21 +298,24 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
                )

                rollout_score_time = time()
-                all_scores = torch.tensor(
-                    self.reward_fn(
-                        samples=all_str_samples, prompts=all_str_prompts, outputs=all_str_outputs, **metadata
-                    ),
-                    dtype=torch.float,
-                    device=device,
-                )
+                # reward_fn should return a list of per-token rewards for each sample
+                # NOTE: all_scores[0][i] is the reward attributed to token (action) i of prompt + response (because of how the KL is computed)
+                all_scores = self.reward_fn(samples=all_str_samples, prompts=all_str_prompts, outputs=all_str_outputs, model_tok=self.tokenizer, **metadata)
+                all_scores = [torch.tensor(score, dtype=torch.float, device=device).view(-1) for score in all_scores]
+                # Right-pad the per-sample score sequences with -1 (used below as a sentinel for positions with no reward)
+                all_scores = pad_sequence(all_scores, batch_first=True, padding_value=-1)
+                max_len = torch.tensor(all_scores.shape[1], dtype=torch.long, device=device)
+
                stats["time/rollout_score"] = time() - rollout_score_time

-                all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1).unbind())
+                all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1, max_len).unbind())
            else:
                all_scores = None
+                max_len = torch.tensor(0, dtype=torch.long, device=device)

            if torch.distributed.is_initialized():
-                scores = torch.empty(len(samples), device=device)
+                torch.distributed.broadcast(max_len, 0)
+                scores = torch.empty((len(samples), max_len), device=device)
                torch.distributed.scatter(scores, all_scores)
            else:
                scores = all_scores[0].clone().detach()
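
Not part of the diff: a minimal sketch of the shape contract the new code assumes for reward_fn, i.e. one list of per-token rewards per sample, right-padded with the -1 sentinel before being scattered. The toy_reward_fn below and its whitespace tokenization are hypothetical.

import torch
from torch.nn.utils.rnn import pad_sequence

def toy_reward_fn(samples, prompts, outputs, model_tok=None, **metadata):
    # hypothetical dense reward: one scalar per (whitespace) token of each sample
    return [[0.1] * len(s.split()) for s in samples]

all_scores = toy_reward_fn(samples=["a b c", "d e"], prompts=["a", "d"], outputs=["b c", "e"])
all_scores = [torch.tensor(s, dtype=torch.float).view(-1) for s in all_scores]
all_scores = pad_sequence(all_scores, batch_first=True, padding_value=-1)
print(all_scores.shape)  # torch.Size([2, 3]); the shorter row ends with the -1 pad sentinel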
@@ -342,7 +346,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq

            # store statistics of the initial rollout as reference
            if self.ref_mean is None:
-                self.ref_mean, self.ref_std = scores.mean(), scores.std()
+                self.ref_mean, self.ref_std = scores.sum(dim=1).mean(), scores.sum(dim=1).std()
            all_scores_mean, all_scores_std = self.running_moments.update(scores)
            stats["rollout_scores/mean"] = all_scores_mean.item()
            stats["rollout_scores/std"] = all_scores_std.item()
@@ -415,6 +419,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
                logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:])
                ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:])
            else:
+                # NOTE: logprobs[i] is the log-probability with which all_tokens[i+1] was sampled
                logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:])
                ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:])

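Not part of the diff: a small sketch of the shift-by-one alignment the NOTE above describes, under the assumption that logprobs_of_labels gathers log-softmax(logits) at the given label ids (as in trlx's modeling utils).

import torch
import torch.nn.functional as F

def logprobs_of_labels(logits, labels):
    # gather the log-probability assigned to each label token
    logprobs = F.log_softmax(logits, dim=-1)
    return torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

logits = torch.randn(1, 5, 10)            # 5 positions, vocabulary of 10
all_tokens = torch.randint(0, 10, (1, 5))
# logits[:, i] predicts all_tokens[:, i + 1], hence the off-by-one slices:
logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:])
print(logprobs.shape)  # torch.Size([1, 4])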
@@ -425,6 +430,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
                attention_mask = sample_outputs != self.tokenizer.pad_token_id
                start = 0
            else:
+                # NOTE: -1 because kl[prompt_tensors.shape[1]] is the KL of the second token of the response
                start = prompt_tensors.shape[1] - 1

            log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1]
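
Not part of the diff: a worked example of the index bookkeeping behind start = prompt_tensors.shape[1] - 1, under the assumption that all_tokens is the prompt followed by the response. With a prompt of length P and a response of length R, all_tokens has P + R tokens while logprobs/kl have P + R - 1 entries (one per predicted next token).

import torch

P, R = 4, 3
all_tokens = torch.arange(P + R)    # stand-in token ids: prompt = 0..3, response = 4..6
logprob_targets = all_tokens[1:]    # logprobs[i] scores all_tokens[i + 1]
start = P - 1
print(logprob_targets[start:])      # tensor([4, 5, 6]) -> exactly the response tokens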
@@ -436,12 +442,16 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
            ref_logprobs = ref_logprobs.cpu()
            prompt_tensors = prompt_tensors.cpu()
            sample_outputs = sample_outputs.cpu()
+            # TODO(dahoas): Why [:, :-1]? Redundant with the clipping via start : ends[ix]?
+            # Actually I think it's just wrong?
            values = values.cpu()[:, :-1]

            # Get the logprobs and values, for tokens that are not padding,
-            # from the start of the prompt up to the <eos> token, while also including the latter
+            # from the end of the prompt up to the <eos> token, while also including the latter
            # (these are taken from the student model and not the reference model)
            ends = start + attention_mask[:, start:].sum(1) + 1
+            # NOTE: values[i] is the value of the state after response token i
+            # TODO(dahoas): Does it actually make sense to get the rewards one step early?
            all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
            all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)]

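Not part of the diff: a toy example of how ends clips each sample at its <eos>; the attention_mask values below are made up. attention_mask is 1 for real tokens and 0 for padding, so summing it from start counts the non-padding positions, and the + 1 keeps the <eos> position as described in the comment above.

import torch

start = 3
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0],    # short response, then padding
                               [1, 1, 1, 1, 1, 1, 1, 0]])   # longer response
ends = start + attention_mask[:, start:].sum(1) + 1
print(ends)  # tensor([7, 8])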
@@ -451,8 +461,20 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
            rollout_count = 0

            for sample_idx in range(n_samples):
+                # To compute the per-token reward, first add in the KL penalties over the trajectory
+                # NOTE: kl_penalty[i] is the KL difference at token i+1 of the output (excluding EOS)
                rewards = kl_penalty[sample_idx]
-                rewards[-1] += scores[sample_idx].cpu()
+                # Then add in the rewards from the reward function
+                if scores.shape[1] == 1:
+                    # NOTE: the final reward is given at the EOS token, following HHH practice
+                    rewards[-1] += scores[sample_idx][0].cpu()
+                else:
+                    score = scores[sample_idx]
+                    score_right_padding = torch.sum(score != -1)
+                    score = score[:score_right_padding].cpu()
+                    p_score = torch.zeros_like(rewards)
+                    p_score[:score.shape[0]] += score
+                    rewards += p_score

                ppo_rl_elements.append(
                    PPORLElement(
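
Not part of the diff: a sketch of how the else-branch above folds a right-padded per-token score row into the KL penalties, instead of adding a single scalar at the EOS position; the numbers are made up.

import torch

rewards = torch.tensor([-0.01, -0.02, -0.03, -0.01])    # kl_penalty for one sample
score = torch.tensor([0.5, 0.2, -1.0, -1.0])            # per-token scores, right-padded with -1
score_right_padding = torch.sum(score != -1)            # 2 real scores
score = score[:score_right_padding]
p_score = torch.zeros_like(rewards)
p_score[:score.shape[0]] += score
rewards += p_score
print(rewards)  # ~ [0.49, 0.18, -0.03, -0.01]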