diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py
index 7a816ad686..121f41a2f0 100644
--- a/open_instruct/grpo_fast.py
+++ b/open_instruct/grpo_fast.py
@@ -274,14 +274,16 @@ class Args:
     mask_truncated_completions: bool = False
     """Whether to mask out truncated completions. Also called overlong filtering, from DAPO (https://arxiv.org/abs/2503.14476)."""

-    fill_completions: bool = False
-    """Whether to refill the batchsize with after filtering."""
-
     record_entropy: bool = False
     """whether to record the entropy of the policy during training. Uses extra memory."""

     use_vllm_logprobs: bool = False
     """whether to use vLLM's logprobs for training instead of calculating them via forward pass"""
+    active_sampling: bool = False
+    """Whether to refill the batch with *new prompts/completions* after filtering."""
+    active_sampling_max_attempts: int = 3
+    """How many times to attempt to refill the batch with newly sampled prompts before giving up."""
+
     # Reward
     # -- r1 style format reward
     apply_r1_style_format_reward: bool = False
@@ -576,14 +578,18 @@ def __init__(self, data: np.ndarray, batch_size: int, seed: Optional[int] = None
         self.rng.shuffle(self.data)

         # Ensure the effective dataset size is divisible by batch_size
-        self.effective_size = len(self.data) - (len(self.data) % batch_size)
+        self._update_effective_size()

     def __iter__(self) -> Iterator[List[int]]:
         return self

     def __next__(self) -> List[int]:
+        if self.effective_size == 0:
+            raise StopIteration("No data available to sample from the iterator.")
+
         if self.index >= self.effective_size:
             self.index = 0
+            self._update_effective_size()
             self.epoch_number += 1
             self.rng.shuffle(self.data)

@@ -608,6 +614,10 @@ def set_state(self, state: Dict[str, Any]) -> None:
         self.epoch_number = state["epoch_number"]
         self.data = state["data"].copy()
         self.rng.bit_generator.state = state["rng_state"]
+        self._update_effective_size()
+
+    def _update_effective_size(self) -> None:
+        self.effective_size = len(self.data) - (len(self.data) % self.batch_size)


 @ray.remote(num_gpus=1)
@@ -1556,7 +1566,7 @@ def accumulate_inference_batches(
     model_dims: utils.ModelDims,
     actor_manager=None,
     timeout: Optional[float] = None,
-) -> tuple[GenerationResult, Batch]:
+) -> tuple[GenerationResult, Batch, list[int], list[int]]:
     """Accumulate multiple inference results into a single training batch.

     Args:
@@ -1686,17 +1696,70 @@ def accumulate_inference_batches(
     if actor_manager is not None:
         ray.get(actor_manager.report_token_statistics.remote(accumulated_stats))

-    # Note: We don't have dataset_indices here, but they're not needed for the returned batch
     batch = Batch(
         queries=all_queries,
         ground_truths=all_ground_truths,
         datasets=all_datasets,
         raw_queries=all_raw_queries,
-        indices=None,  # Not meaningful for combined results
+        indices=None,
     )
     return combined_result, batch, prompt_lengths, response_lengths


+def combine_reward_metrics(metric_records: list[tuple[Dict[str, Any], int]]) -> Dict[str, Any]:
+    """Merge per-chunk reward metrics: list/array values are concatenated, numeric values are averaged weighted by record count."""
+    buckets = defaultdict(list)
+    total_num_records = 0
+    for metrics, num_records in metric_records:
+        total_num_records += num_records
+        for key, value in metrics.items():
+            buckets[key].append((value, num_records))
+
+    combined: Dict[str, Any] = {}
+    for key, records in buckets.items():
+        sample_value = records[0][0]
+        if isinstance(sample_value, np.ndarray):
+            combined[key] = [x for value, _ in records for x in value]
+        elif isinstance(sample_value, (list, tuple)):
+            concatenated: list[Any] = []
+            for value, _ in records:
+                concatenated.extend(list(value))
+            combined[key] = concatenated
+        elif isinstance(sample_value, (int, float, bool, np.integer, np.floating)):
+            weighted_sum = sum(float(value) * num_records for value, num_records in records)
+            combined[key] = weighted_sum / total_num_records if total_num_records > 0 else sample_value
+        else:
+            # Fallback: keep the latest value if aggregation strategy is unclear.
+            combined[key] = records[-1][0]
+    return combined
+
+
+@dataclass
+class FilteredChunk:
+    result: GenerationResult
+    batch: Batch
+    kept_indices: list[int]
+    prompt_lengths: list[int]
+    response_lengths: list[int]
+    scores: np.ndarray
+    advantages: np.ndarray
+    reward_metrics: tuple[Dict[str, Any], int]
+
+
+@dataclass
+class AggregationState:
+    chunks: list[FilteredChunk] = field(default_factory=list)
+    token_stats: list[TokenStatistics] = field(default_factory=list)
+    prompts_filtered_counts: list[int] = field(default_factory=list)
+
+    def add_chunk(
+        self, *, chunk: FilteredChunk, token_statistics: Optional[TokenStatistics], prompts_filtered: int
+    ) -> None:
+        self.chunks.append(chunk)
+        if token_statistics is not None:
+            self.token_stats.append(token_statistics)
+        self.prompts_filtered_counts.append(prompts_filtered)
+
+
 def data_preparation_thread(
     reward_fn: Callable,
     inference_results_Q: ray_queue.Queue,  # Ray queue
@@ -1711,188 +1774,328 @@ def data_preparation_thread(
     model_dims: utils.ModelDims = None,
 ):
     for training_step in range(resume_training_step, num_training_steps + 1):
-        # Streaming accumulation: collect results as they arrive
-        with Timer("🚀 [Data Preparation Thread] Getting response ids") as timer:
-            result, batch, prompt_lengths, response_lengths = accumulate_inference_batches(
-                inference_results_Q,
-                pending_queries_map,
-                args,
-                generation_config,
-                num_prompts=args.num_unique_prompts_rollout,
-                model_dims=model_dims,
-                actor_manager=actor_manager,
-            )
-            if isinstance(result, ShutdownSentinel):
-                logger.info("[Data Preparation Thread] Received shutdown sentinel, exiting")
-                return
+        per_prompt_rollout = args.num_samples_per_prompt_rollout
+        target_prompt_count = args.num_unique_prompts_rollout
+
+        aggregation = AggregationState()
+        prompts_to_request = target_prompt_count
+        getting_response_time = 0.0
+
+        fill_iteration = 0
+        while prompts_to_request > 0:
+            fill_iteration += 1
+            # Streaming accumulation: collect results as they arrive
+            with Timer(f"🚀 [Data Preparation Thread] Getting {prompts_to_request} response ids") as timer:
+                result, batch, prompt_lengths, response_lengths = accumulate_inference_batches(
+                    inference_results_Q,
+                    pending_queries_map,
+                    args,
+                    generation_config,
+                    num_prompts=prompts_to_request,
+                    model_dims=model_dims,
+                    actor_manager=actor_manager,
+                )
+                if isinstance(result, ShutdownSentinel):
+                    logger.info("[Data Preparation Thread] Received shutdown sentinel, exiting")
+                    return
+
+            getting_response_time += timer.duration
+
+            if args.num_samples_per_prompt_rollout > 1:
+                batch = Batch(
+                    queries=repeat_each(batch.queries, per_prompt_rollout),
+                    ground_truths=repeat_each(batch.ground_truths, per_prompt_rollout),
+                    datasets=repeat_each(batch.datasets, per_prompt_rollout),
+                    raw_queries=repeat_each(batch.raw_queries, per_prompt_rollout),
+                    indices=repeat_each(batch.indices, per_prompt_rollout) if batch.indices else None,
+                )

-        getting_response_time = timer.duration
+            for i in range(len(result.finish_reasons)):
+                if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0:
+                    result.responses[i].append(tokenizer.eos_token_id)
+                    result.masks[i].append(1)
+                    result.logprobs[i].append(float("nan"))
+
+            with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
+                decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True)
+                decoded_queries = batch.raw_queries
+
+            with Timer("💰 [Data Preparation Thread] Calculating rewards and advantages"):
+                scores, reward_metrics = asyncio.run(
+                    reward_fn(
+                        result.responses,
+                        decoded_responses,
+                        batch,
+                        result.finish_reasons,
+                        result.request_info,
+                        decoded_queries,
+                    )
+                )
+                scores = np.array(scores)
+                scores_per_prompt = scores.reshape(-1, per_prompt_rollout)
+                mean_grouped_rewards = scores_per_prompt.mean(axis=-1)
+                mean_grouped_rewards = np.repeat(mean_grouped_rewards, per_prompt_rollout, axis=0)
+                std_grouped_rewards = scores_per_prompt.std(axis=-1)
+                std_grouped_rewards = np.repeat(std_grouped_rewards, per_prompt_rollout, axis=0)
+                if args.advantage_normalization_type == "standard":
+                    advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8)
+                elif args.advantage_normalization_type == "centered":
+                    advantages = scores - mean_grouped_rewards
+                else:
+                    raise ValueError(f"Invalid advantage normalization type: {args.advantage_normalization_type}")
+
+            # Filtering (zero-gradient, truncated completions)
+            with Timer("📦 [Data Preparation Thread] Filtering sequences"):
+                max_possible_score = 0
+                if args.apply_verifiable_reward:
+                    max_possible_score += args.verification_reward
+                if args.apply_r1_style_format_reward and args.additive_format_reward:
+                    max_possible_score += args.r1_style_format_reward
+
+                non_zero_std_mask = scores_per_prompt.std(axis=-1) != 0
+                expanded_mask = np.repeat(non_zero_std_mask, per_prompt_rollout)
+                non_zero_indices = np.where(expanded_mask)[0]
+
+                num_zero_std_prompts = (~non_zero_std_mask).sum()
+                num_filtered_responses = len(scores) - len(non_zero_indices)
+                if num_filtered_responses > 0:
+                    logger.info(
+                        f"[Zero-gradient filtering] Filtered {num_zero_std_prompts} prompts with zero std "
+                        f"({num_filtered_responses} responses). Retention rate: {len(non_zero_indices) / len(scores):.2%}"
+                    )

-        # ------------------------------------------------------------------------------------------------
-        # Pack sequences
-        if args.num_samples_per_prompt_rollout > 1:
-            batch = Batch(
-                queries=repeat_each(batch.queries, args.num_samples_per_prompt_rollout),
-                ground_truths=repeat_each(batch.ground_truths, args.num_samples_per_prompt_rollout),
-                datasets=repeat_each(batch.datasets, args.num_samples_per_prompt_rollout),
-                raw_queries=repeat_each(batch.raw_queries, args.num_samples_per_prompt_rollout),
-                indices=repeat_each(batch.indices, args.num_samples_per_prompt_rollout) if batch.indices else None,
+                advantages = advantages[non_zero_indices]
+                scores = scores[non_zero_indices]
+                kept_indices = non_zero_indices.tolist()
+                batch = batch[kept_indices]
+
+                prompt_indices_kept = np.where(non_zero_std_mask)[0].tolist()
+                prompt_lengths_kept = [prompt_lengths[i] for i in prompt_indices_kept]
+                response_lengths_kept = [response_lengths[i] for i in kept_indices]
+
+                chunk = FilteredChunk(
+                    result=result,
+                    batch=batch,
+                    kept_indices=kept_indices,
+                    prompt_lengths=prompt_lengths_kept,
+                    response_lengths=response_lengths_kept,
+                    scores=scores,
+                    advantages=advantages,
+                    reward_metrics=(reward_metrics, len(scores)),
                 )
-            good_outputs = [
-                len(result.request_info.tool_outputs[i]) > 0
-                and result.request_info.tool_calleds[i]
-                and not result.request_info.timeouts[i]
-                and not result.request_info.tool_errors[i]
-                for i in range(len(result.request_info.tool_outputs))
-            ]
-        for i in range(len(result.finish_reasons)):
-            if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0:
-                result.responses[i].append(tokenizer.eos_token_id)
-                result.masks[i].append(1)
-                result.logprobs[i].append(float("nan"))
-        with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
-            decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True)
-            decoded_queries = batch.raw_queries
-        stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len(
-            result.finish_reasons
+            aggregation.add_chunk(
+                chunk=chunk, token_statistics=result.token_statistics, prompts_filtered=int(num_zero_std_prompts)
             )
-        with Timer("💰 [Data Preparation Thread] Calculating rewards and advantages"):
-            scores, reward_metrics = asyncio.run(
-                reward_fn(
-                    result.responses,
-                    decoded_responses,
-                    batch,
-                    result.finish_reasons,
-                    result.request_info,
-                    decoded_queries,
+            prompts_kept_this_iter = len(scores) // per_prompt_rollout if per_prompt_rollout > 0 else 0
+
+            if not args.active_sampling:
+                break
+
+            collected_samples = sum(len(chunk.scores) for chunk in aggregation.chunks)
+            collected_prompts = collected_samples // per_prompt_rollout if per_prompt_rollout > 0 else 0
+            prompts_remaining = target_prompt_count - collected_prompts
+            if prompts_remaining <= 0:
+                break
+
+            if prompts_kept_this_iter == 0 and fill_iteration > args.active_sampling_max_attempts:
+                total_responses = sum(len(chunk.kept_indices) for chunk in aggregation.chunks)
+                logger.warning(
+                    "[Active sampling] No prompts with non-zero advantage were collected in iteration %d, which "
+                    "exceeds args.active_sampling_max_attempts; proceeding with the existing batch of %d responses.",
+                    fill_iteration,
+                    total_responses,
                 )
+                break
+
+            prompts_to_request = prompts_remaining
+
+        # Build aggregated containers from collected data
+        if aggregation.chunks:
+            queries: list[List[int]] = []
+            ground_truths: list[List[int]] = []
+            datasets: list[str] = []
+            raw_queries_list: Optional[list[str]] = None
+            indices_list: Optional[list[int]] = None
+
+            raw_queries_present = aggregation.chunks[0].batch.raw_queries is not None
+            indices_present = aggregation.chunks[0].batch.indices is not None
+            if raw_queries_present:
+                raw_queries_list = []
+            if indices_present:
+                indices_list = []
+
+            for chunk in aggregation.chunks:
+                chunk_batch = chunk.batch
+                queries.extend(chunk_batch.queries)
+                ground_truths.extend(chunk_batch.ground_truths)
+                datasets.extend(chunk_batch.datasets)
+                if raw_queries_present:
+                    assert chunk_batch.raw_queries is not None
+                    raw_queries_list.extend(chunk_batch.raw_queries)
+                if indices_present:
+                    assert chunk_batch.indices is not None
+                    indices_list.extend(chunk_batch.indices)
+
+            batch = Batch(
+                queries=queries,
+                ground_truths=ground_truths,
+                datasets=datasets,
+                raw_queries=raw_queries_list if raw_queries_present else None,
+                indices=indices_list if indices_present else None,
             )
-            scores = np.array(scores)
-            scores_per_prompt = scores.reshape(-1, args.num_samples_per_prompt_rollout)
-            mean_grouped_rewards = scores_per_prompt.mean(axis=-1)
-            mean_grouped_rewards = np.repeat(mean_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0)
-            std_grouped_rewards = scores_per_prompt.std(axis=-1)
-            std_grouped_rewards = np.repeat(std_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0)
-            if args.advantage_normalization_type == "standard":
-                advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8)
-            elif args.advantage_normalization_type == "centered":
-                advantages = scores - mean_grouped_rewards
+        else:
+            batch = Batch(queries=[], ground_truths=[], datasets=[], raw_queries=[], indices=[])
+
+        prompt_lengths = [length for chunk in aggregation.chunks for length in chunk.prompt_lengths]
+        response_lengths = [length for chunk in aggregation.chunks for length in chunk.response_lengths]
+
+        responses: list[List[int]] = []
+        masks: list[List[int]] = []
+        finish_reasons: list[str] = []
+        vllm_logprobs: list[List[float]] = []
+        num_calls: list[int] = []
+        timeouts: list[int] = []
+        tool_errors: list[str] = []
+        tool_outputs: list[str] = []
+        tool_runtimes: list[float] = []
+        tool_calleds: list[bool] = []
+
+        for chunk in aggregation.chunks:
+            result_chunk = chunk.result
+            indices = chunk.kept_indices
+
+            responses.extend([result_chunk.responses[i] for i in indices])
+            masks.extend([result_chunk.masks[i] for i in indices])
+            finish_reasons.extend([result_chunk.finish_reasons[i] for i in indices])
+            if result_chunk.logprobs is not None:
+                vllm_logprobs.extend([result_chunk.logprobs[i] for i in indices])
             else:
-                raise ValueError(f"Invalid advantage normalization type: {args.advantage_normalization_type}")
+                vllm_logprobs.extend([[] for _ in indices])
+
+            request_info = result_chunk.request_info
+            num_calls.extend([request_info.num_calls[i] for i in indices])
+            timeouts.extend([request_info.timeouts[i] for i in indices])
+            tool_errors.extend([request_info.tool_errors[i] for i in indices])
+            tool_outputs.extend([request_info.tool_outputs[i] for i in indices])
+            tool_runtimes.extend([request_info.tool_runtimes[i] for i in indices])
+            tool_calleds.extend([request_info.tool_calleds[i] for i in indices])
+
+        token_stats = aggregation.token_stats
+        total_prompt_tokens = sum(stat.num_prompt_tokens for stat in token_stats)
+        total_response_tokens = sum(stat.num_response_tokens for stat in token_stats)
+        total_generation_time = sum(stat.generation_time for stat in token_stats)
+        earliest_start_candidates = [
+            stat.earliest_start_time for stat in token_stats if stat.earliest_start_time is not None
+        ]
+        earliest_start_time = min(earliest_start_candidates) if earliest_start_candidates else None

-        with Timer("📦 [Data Preparation Thread] Filtering sequences"):
-            # Here we get the max possible score for each prompt, and see how many prompts are unsolved
-            max_possible_score = 0
-            if args.apply_verifiable_reward:
-                max_possible_score += args.verification_reward
-            if args.apply_r1_style_format_reward and args.additive_format_reward:
-                max_possible_score += args.r1_style_format_reward
-            unsolved_batch_size_ratio = ((scores != max_possible_score) > 0).sum() / len(scores)
-            # In GRPO, if the std of grouped rewards is 0, then there is zero gradient for the batch
-            # of args.num_samples_per_prompt_rollout responses, so we need to filter out those batches
-            non_zero_std_mask = scores_per_prompt.std(axis=-1) != 0
-            real_batch_size_ratio = non_zero_std_mask.sum() * args.num_samples_per_prompt_rollout / len(scores)
-            expanded_mask = np.repeat(non_zero_std_mask, args.num_samples_per_prompt_rollout)
-            non_zero_gradient_index = np.where(expanded_mask)[0]
-
-            # Log zero-gradient filtering statistics
-            num_zero_std_prompts = (~non_zero_std_mask).sum()
-            num_filtered_responses = len(scores) - len(non_zero_gradient_index)
-            if num_filtered_responses > 0:
+        token_statistics = TokenStatistics(
+            num_prompt_tokens=total_prompt_tokens,
+            num_response_tokens=total_response_tokens,
+            generation_time=total_generation_time,
+            earliest_start_time=earliest_start_time,
+        )
+        result = GenerationResult(
+            responses=responses,
+            finish_reasons=finish_reasons,
+            masks=masks,
+            request_info=RequestInfo(
+                num_calls=num_calls,
+                timeouts=timeouts,
+                tool_errors=tool_errors,
+                tool_outputs=tool_outputs,
+                tool_runtimes=tool_runtimes,
+                tool_calleds=tool_calleds,
+            ),
+            token_statistics=token_statistics,
+            logprobs=vllm_logprobs,
+        )
+
+        good_outputs = [
+            len(result.request_info.tool_outputs[i]) > 0
+            and result.request_info.tool_calleds[i]
+            and not result.request_info.timeouts[i]
+            and not result.request_info.tool_errors[i]
+            for i in range(len(result.request_info.tool_outputs))
+        ]
+        total_completions = len(finish_reasons)
+        total_stop_completions = sum(int(reason == "stop") for reason in finish_reasons)
+        stop_rate = (total_stop_completions / total_completions) if total_completions > 0 else 0
+        total_prompts_filtered = sum(aggregation.prompts_filtered_counts)
+
+        scores = np.concatenate([chunk.scores for chunk in aggregation.chunks]) if aggregation.chunks else np.array([])
+        advantages = (
+            np.concatenate([chunk.advantages for chunk in aggregation.chunks]) if aggregation.chunks else np.array([])
+        )
+        reward_metrics = combine_reward_metrics([chunk.reward_metrics for chunk in aggregation.chunks])
+
+        if args.mask_truncated_completions and len(finish_reasons) > 0:
+            truncated_mask = np.array([reason != "stop" for reason in finish_reasons], dtype=bool)
+            num_truncated = int(truncated_mask.sum())
+            if num_truncated > 0:
+                if len(scores) > 0:
+                    scores = scores.copy()
+                    scores[truncated_mask] = 0.0
+                if len(advantages) > 0:
+                    advantages = advantages.copy()
+                    advantages[truncated_mask] = 0.0
+                for idx in np.where(truncated_mask)[0]:
+                    masks[idx] = [0 for _ in masks[idx]]
                 logger.info(
-                    f"[Zero-gradient filtering] Filtered {num_zero_std_prompts} prompts with zero std "
-                    f"({num_filtered_responses} responses). Retention rate: {len(non_zero_gradient_index) / len(scores):.2%}"
+                    f"[Truncated completions masking] Zeroed scores/advantages and response masks for {num_truncated} responses that didn't finish with 'stop'."
                 )

-            advantages = advantages[non_zero_gradient_index]
-            original_batch_size = len(scores)
-            scores = scores[non_zero_gradient_index]
-            responses = [result.responses[i] for i in non_zero_gradient_index]
-            masks = [result.masks[i] for i in non_zero_gradient_index]
-            batch = batch[non_zero_gradient_index.tolist()]
-            finish_reasons = [result.finish_reasons[i] for i in non_zero_gradient_index]
-            vllm_logprobs = [result.logprobs[i] for i in non_zero_gradient_index]
-            if args.mask_truncated_completions:
-                stop_idxes = torch.tensor([i for i in range(len(finish_reasons)) if finish_reasons[i] == "stop"])
-                num_truncated = len(finish_reasons) - len(stop_idxes)
-                if num_truncated > 0:
-                    logger.info(
-                        f"[Truncated completions filtering] Filtered {num_truncated} responses that didn't finish with 'stop'. "
-                        f"Retention rate: {len(stop_idxes) / len(finish_reasons):.2%}"
-                    )
-                    scores = scores[stop_idxes]
-                    advantages = advantages[stop_idxes]
-                    responses = [responses[i] for i in stop_idxes]
-                    masks = [masks[i] for i in stop_idxes]
-                    batch = batch[stop_idxes.tolist()]
-                    finish_reasons = [finish_reasons[i] for i in stop_idxes]
-                    vllm_logprobs = [vllm_logprobs[i] for i in stop_idxes]
-
-            if args.fill_completions:
-                with Timer("⏱ [Data Preparation Thread] Refill completions"):
-                    current_batch_size = len(scores)
-                    original_prompt_cnt = original_batch_size // args.num_samples_per_prompt_rollout
-                    current_prompt_cnt = current_batch_size // args.num_samples_per_prompt_rollout
-                    need_to_fill_prompt = original_prompt_cnt - current_prompt_cnt
-                    k = args.num_samples_per_prompt_rollout
-
-                    if need_to_fill_prompt > 0 and current_prompt_cnt > 0:
-                        scores_matrix = scores.reshape(current_prompt_cnt, k)
-                        stds = scores_matrix.std(axis=1) + 1e-8
-                        probs = stds / stds.sum()
-
-                        logger.info(
-                            f"[Refill completions] Need to fill {need_to_fill_prompt} prompts to maintain batch size. "
-                            f"Original: {original_prompt_cnt}, Current: {current_prompt_cnt}"
-                        )
-
-                        sampled_prompt_ids = np.random.choice(
-                            current_prompt_cnt, size=need_to_fill_prompt, replace=True, p=probs
-                        )
-
-                        sampled_indices = []
-                        for pid in sampled_prompt_ids:
-                            start = pid * k
-                            sampled_indices.extend(range(start, start + k))
+        if len(scores) == 0:
+            logger.warning(f"No responses with non-zero advantages in batch {training_step}.")

-                        advantages = np.concatenate([advantages, advantages[sampled_indices]])
-                        scores = np.concatenate([scores, scores[sampled_indices]])
-                        responses += [responses[i] for i in sampled_indices]
-                        masks += [masks[i] for i in sampled_indices]
-
-                        sampled_batch = batch[sampled_indices]
+        if len(scores) == 0 or per_prompt_rollout == 0:
+            scores_per_prompt = np.zeros((0, per_prompt_rollout))
+            unsolved_batch_size_ratio = 0.0
+            real_batch_size_ratio = 0.0
+        else:
+            scores_per_prompt = scores.reshape(-1, per_prompt_rollout)

-                        batch = Batch(
-                            queries=batch.queries + sampled_batch.queries,
-                            ground_truths=batch.ground_truths + sampled_batch.ground_truths,
-                            datasets=batch.datasets + sampled_batch.datasets,
-                            indices=batch.indices + sampled_batch.indices if batch.indices is not None else None,
-                        )
+            max_possible_score = 0
+            if args.apply_verifiable_reward:
+                max_possible_score += args.verification_reward
+            if args.apply_r1_style_format_reward and args.additive_format_reward:
+                max_possible_score += args.r1_style_format_reward

-                        finish_reasons += [finish_reasons[i] for i in sampled_indices]
-                        vllm_logprobs += [vllm_logprobs[i] for i in sampled_indices]
+            unsolved_batch_size_ratio = ((scores != max_possible_score) > 0).sum() / len(scores)
+            original_batch_size = target_prompt_count * per_prompt_rollout
+            real_batch_size_ratio = len(scores) / original_batch_size if original_batch_size > 0 else 0

-                        logger.info(
-                            f"📊 Duplicated {need_to_fill_prompt} prompts from {len(sampled_indices)} total responses"
-                        )
+        responses_list = responses
+        masks_list = masks
+        finish_reasons_list = finish_reasons
+        vllm_logprobs_list = vllm_logprobs

-            # Count groups with all zero rewards
-            all_zero_groups = (scores_per_prompt == 0).all(axis=-1).sum()
-            total_groups = len(scores_per_prompt)
-            logger.info(
-                f"[Reward Summary] Groups with all zero rewards: {all_zero_groups}/{total_groups} "
-                f"({all_zero_groups / total_groups:.1%})"
-            )
+        # Count groups with all zero rewards
+        all_zero_groups = (scores_per_prompt == 0).all(axis=-1).sum()
+        max_possible_score = 0
+        if args.apply_verifiable_reward:
+            max_possible_score += args.verification_reward
+        if args.apply_r1_style_format_reward and args.additive_format_reward:
+            max_possible_score += args.r1_style_format_reward
+        all_solved_groups = (scores_per_prompt == max_possible_score).all(axis=-1).sum()
+        total_groups = len(scores_per_prompt)
+        zero_groups_ratio = all_zero_groups / total_groups if total_groups > 0 else 0.0
+        solved_groups_ratio = all_solved_groups / total_groups if total_groups > 0 else 0.0
+        logger.info(
+            f"[Reward Summary] Groups with all zero rewards: {all_zero_groups}/{total_groups} ({zero_groups_ratio:.1%}); "
+            f"groups with all rewards solved: {all_solved_groups}/{total_groups} ({solved_groups_ratio:.1%})"
+        )

         with Timer("📦 [Data Preparation Thread] Packing sequences"):
             packed_sequences = pack_sequences(
                 queries=batch.queries,
-                responses=responses,
-                masks=masks,
+                responses=responses_list,
+                masks=masks_list,
                 pack_length=args.pack_length,
                 pad_token_id=tokenizer.pad_token_id,
-                vllm_logprobs=vllm_logprobs,
+                vllm_logprobs=vllm_logprobs_list,
             )
             num_new_tokens = sum(len(seq) for seq in packed_sequences.query_responses)

         # Vectorized advantage calculation: create a lookup array where each index corresponds to a response mask value
@@ -1991,12 +2194,12 @@ def data_preparation_thread(
         )

         # Create a result package with metrics and data
-        if len(responses) == 0:
+        if len(responses_list) == 0:
             # Handle empty responses case
             # in this case, we won't log metrics, so it should be fine.
             metrics = {}
         else:
-            sequence_lengths = np.array([len(response) for response in responses])
+            sequence_lengths = np.array([len(response) for response in responses_list])
             sequence_length_solved = (
                 np.array([]) if np.all(scores == 0) else np.array(sequence_lengths[scores == max_possible_score])
             )
@@ -2004,16 +2207,18 @@ def data_preparation_thread(
                 np.array([]) if np.all(scores == max_possible_score) else np.array(sequence_lengths[scores == 0])
             )

-            # Use the already calculated reward summary metrics for wandb
-            all_zero_groups_ratio = all_zero_groups / total_groups if total_groups > 0 else 0
-
             metrics = {
                 "scores": np.array(scores).mean(),
                 "real_batch_size_ratio": real_batch_size_ratio,
                 "unsolved_batch_size_ratio": unsolved_batch_size_ratio,
-                "packed_ratio": len(packed_sequences.query_responses) / len(responses) if len(responses) > 0 else 0,
+                "packed_ratio": len(packed_sequences.query_responses) / len(responses_list)
+                if len(responses_list) > 0
+                else 0,
                 "val/all_zero_reward_groups": all_zero_groups,
-                "val/all_zero_reward_groups_ratio": all_zero_groups_ratio,
+                "val/all_zero_reward_groups_ratio": zero_groups_ratio,
+                "val/all_solved_reward_groups": all_solved_groups,
+                "val/all_solved_reward_groups_ratio": solved_groups_ratio,
+                "val/total_filtered_groups": total_prompts_filtered,
                 "val/total_reward_groups": total_groups,
                 "val/sequence_lengths": sequence_lengths.mean(),
                 "val/sequence_lengths_min": sequence_lengths.min(),
@@ -2042,13 +2247,14 @@ def data_preparation_thread(
             }

             total_tokens = result.token_statistics.num_prompt_tokens + result.token_statistics.num_response_tokens
-            metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time
+            if result.token_statistics.generation_time > 0:
+                metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time

         if args.save_traces:
             traces = {
                 "scores": scores.tolist(),
-                "finish_reasons": finish_reasons,
-                "responses": responses,
+                "finish_reasons": finish_reasons_list,
+                "responses": responses_list,
                 "training_step": training_step,
                 **asdict(batch),  # Unpack all batch fields
                 **reward_metrics,
@@ -2067,7 +2273,7 @@ def data_preparation_thread(
             "packed_sequences": packed_sequences,  # for debugging purposes
             "collated_data": collated_data,
             "metrics": metrics,
-            "responses_count": len(responses),
+            "responses_count": len(responses_list),
             "num_new_tokens": num_new_tokens,
             "B": B,
             "prompt_lengths": prompt_lengths,
@@ -2473,7 +2679,7 @@ def one_training_step(
     )

     save_time = 0
-    if args.save_freq > 0 and training_step % args.save_freq == 0 and (args.eval_on_step_0 or training_step > 1):
+    if args.save_freq > 0 and (training_step % args.save_freq == 0 or (training_step == 1 and args.eval_on_step_0)):
         with Timer("[Main Thread] 🗡️ Saving model") as timer:
             checkpoint_dir = f"{args.output_dir}_checkpoints"
             step_dir = os.path.join(checkpoint_dir, f"step_{training_step}")
@@ -2912,6 +3118,7 @@ def health_check_fn():
     else:
         num_total_tokens = 0

+    filtered_prompts_count = 0
     training_start_time = time.perf_counter()  # Track overall training start time
     for training_step in range(resume_training_step, args.num_training_steps + 1):
         start_time = time.perf_counter()
@@ -2947,6 +3154,23 @@ def health_check_fn():
             generation_configs["train"],
             is_eval=False,
         )
+        i = 0.1
+        while filtered_prompts_count > args.num_unique_prompts_rollout:
+            another_batch = next_batch(next(iter_dataloader), train_dataset)
+            # Small hack: give this extra batch a distinct training step, otherwise vLLM requests can break.
+            # Reusing the same training step might also break if we happen to sample the same dataset index twice.
+            split_and_insert_batch(
+                another_batch,
+                iter_dataloader.epoch_number,
+                training_step + i,
+                pending_queries_map,
+                param_prompt_Q,
+                generation_configs["train"],
+                is_eval=False,
+            )
+            filtered_prompts_count -= args.num_unique_prompts_rollout
+            i += 0.1
+
         if (
             training_step % args.local_eval_every == 0
             and eval_batch is not None
@@ -2962,12 +3186,14 @@ def health_check_fn():
                 is_eval=True,
             )

-        collated_data, data_thread_metrics, num_total_tokens, num_step_tokens, prompt_lengths, response_lengths = (
+        (collated_data, data_thread_metrics, num_total_tokens, num_step_tokens, prompt_lengths, response_lengths) = (
             load_data_from_packing_thread(packed_sequences_Q, num_total_tokens, stop_event, health_check_fn)
         )
         if collated_data is None:
            continue

+        filtered_prompts_count += data_thread_metrics["val/total_filtered_groups"]
+
        for metrics_Q in [generate_metrics_Q, weight_sync_metrics_Q]:
            try:
                data_thread_metrics |= metrics_Q.get_nowait()
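
Reviewer note (not part of the patch): the standalone Python sketch below restates the merging rule implemented by the new combine_reward_metrics helper, assuming only the behavior visible in this diff: numeric metrics are averaged, weighted by each chunk's record count, while list/array metrics are concatenated across chunks. The function name merge_weighted and the metric values are illustrative only.

from collections import defaultdict
from typing import Any, Dict

import numpy as np


def merge_weighted(metric_records: list[tuple[Dict[str, Any], int]]) -> Dict[str, Any]:
    # Group each metric's (value, num_records) pairs by key.
    buckets = defaultdict(list)
    total = sum(num_records for _, num_records in metric_records)
    for metrics, num_records in metric_records:
        for key, value in metrics.items():
            buckets[key].append((value, num_records))

    merged: Dict[str, Any] = {}
    for key, records in buckets.items():
        sample = records[0][0]
        if isinstance(sample, (list, tuple, np.ndarray)):
            # Per-response values: concatenate across chunks.
            merged[key] = [x for value, _ in records for x in value]
        else:
            # Scalar values: weighted mean by each chunk's record count.
            merged[key] = sum(float(value) * n for value, n in records) / total
    return merged


if __name__ == "__main__":
    # A chunk with 2 kept responses and a chunk with 6 kept responses.
    chunk_a = ({"score_mean": 1.0, "per_response": [1, 1]}, 2)
    chunk_b = ({"score_mean": 0.0, "per_response": [0, 0, 0, 0, 0, 0]}, 6)
    out = merge_weighted([chunk_a, chunk_b])
    print(out["score_mean"])    # 0.25, i.e. (1.0 * 2 + 0.0 * 6) / 8
    print(out["per_response"])  # [1, 1, 0, 0, 0, 0, 0, 0]

This weighting is what keeps scalar reward metrics comparable when active sampling accumulates filtered chunks of different sizes across refill iterations.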