Skip to content

Commit be8bc1a

Browse files
SharathRaparthyDahoas
authored andcommitted
summed along dim=1
1 parent f58170d commit be8bc1a

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

trlx/trainer/accelerate_ppo_trainer.py

+2
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
515515

516516
@staticmethod
517517
def get_topk_indices(input_tensor, window_size: int, k: int, device):
518+
# Sum the scores along dim 1
519+
input_tensor = input_tensor.sum(1)
518520
# Use unfold to create the sliding windows
519521
unfolded = input_tensor.unfold(0, window_size, window_size)
520522
# Find the topk values and indices along the unfolded dimension

0 commit comments

Comments
 (0)