From d1942a7b8ee859b328e9be5108092e805b44621f Mon Sep 17 00:00:00 2001 From: Calvin Xu Date: Sat, 7 Jun 2025 04:45:02 -0700 Subject: [PATCH 1/3] fix return type annotation of run_get_response_log_probs --- tests/adapters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/adapters.py b/tests/adapters.py index 22c6e76f..b3eefca0 100644 --- a/tests/adapters.py +++ b/tests/adapters.py @@ -90,7 +90,7 @@ def run_get_response_log_probs( input_ids: torch.Tensor, labels: torch.Tensor, return_token_entropy: bool, -) -> torch.Tensor: +) -> dict[str, torch.Tensor]: """Get the conditional log-probs of the response given the prompt, and optionally the entropy of the next token predictions. From 8c40f7e4b4fc1b1b0fbfb816f1a3986065e9cb4e Mon Sep 17 00:00:00 2001 From: Calvin Xu Date: Sat, 7 Jun 2025 05:18:10 -0700 Subject: [PATCH 2/3] fix type annotation in run_sft_microbatch_train_step --- tests/adapters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/adapters.py b/tests/adapters.py index b3eefca0..d72c4a7b 100644 --- a/tests/adapters.py +++ b/tests/adapters.py @@ -199,7 +199,7 @@ def run_sft_microbatch_train_step( policy_log_probs: torch.Tensor, response_mask: torch.Tensor, gradient_accumulation_steps: int, - normalize_constant: int | None = 1.0, + normalize_constant: float | None = 1.0, ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """Compute the policy gradient loss and backprop its gradients for a microbatch. """ From 894b39aca9e9c91966ddb0654d7518201d71ba38 Mon Sep 17 00:00:00 2001 From: Calvin Xu Date: Sun, 8 Jun 2025 00:45:23 -0700 Subject: [PATCH 3/3] fix return type annotation of run_compute_group_normalized_rewards --- tests/adapters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/adapters.py b/tests/adapters.py index d72c4a7b..b8c15f0c 100644 --- a/tests/adapters.py +++ b/tests/adapters.py @@ -41,7 +41,7 @@ def run_compute_group_normalized_rewards( group_size: int, advantage_eps: float, normalize_by_std: bool, -) -> tuple[torch.Tensor, dict[str, float]]: +) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]: """ Compute rewards for each group of rollout responses, normalized by the group size.