diff --git a/.gitignore b/.gitignore index 2a9250276..f0d64d70f 100644 --- a/.gitignore +++ b/.gitignore @@ -159,3 +159,4 @@ dmypy.json cache/ local_dataset_cache/ scratch/ +vllm_olmo2.5/ diff --git a/Dockerfile b/Dockerfile index 2f6dae3c4..098fa0b95 100644 --- a/Dockerfile +++ b/Dockerfile @@ -78,7 +78,7 @@ COPY configs configs COPY scripts scripts COPY mason.py mason.py # Copy oe-eval-internal if it exists (wildcard pattern won't fail if missing) -COPY oe-eval-interna[l] oe-eval-internal/ +COPY oe-eval-internal oe-eval-internal COPY open_instruct open_instruct # Add build arguments for git information diff --git a/Makefile b/Makefile index 55ce7d63e..17f2ffa37 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: style quality +.PHONY: style quality setup docker # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) export PYTHONPATH = open_instruct @@ -16,3 +16,13 @@ style-check: ## *fail* if anything needs rewriting quality-check: ## *fail* if any rewrite was needed uv run ruff check --exit-non-zero-on-fix $(check_dirs) + +setup: + git clone -b shanea/olmo2-retrofit https://github.com/2015aroras/vllm.git vllm_olmo2.5 + +docker: + DOCKER_BUILDKIT=1 docker build -f Dockerfile --build-arg UV_CACHE_DIR=$(UV_CACHE_DIR) -t open_instruct_rlzero . + # if you are internally at AI2, you can create an image like this: + $(eval beaker_user := $(shell beaker account whoami --format json | jq -r '.[0].name')) + beaker image delete $(beaker_user)/open_instruct_rlzero + beaker image create open_instruct_rlzero -n open_instruct_rlzero -w ai2/$(beaker_user) diff --git a/generate_olmo25.sh b/generate_olmo25.sh new file mode 100755 index 000000000..ea80ed685 --- /dev/null +++ b/generate_olmo25.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +MODEL_NAME_OR_PATH="/weka/oe-training-default/ai2-llm/checkpoints/tylerr/long-context/olmo25_7b_lc_64k_6T_M100B_round5-sparkle_6634-pre_s2pdf_gzip2080_cweN-yake-all-olmo_packing_yarn-fullonly_50B-fb13a737/step11921-hf" +# DATASET="mnoukhov/DAPO-Math-14k-Processed-RLVR" +DATASET="TTTXXX01/MATH_3000_Filtered" +EXP_NAME="generate_olmo25_teng3k" + +python mason.py \ + --task_name ${EXP_NAME} \ + --cluster ai2/jupiter \ + --image ${1:-michaeln/open_instruct_olmo2_retrofit} \ + --workspace ai2/tulu-thinker \ + --priority high \ + --pure_docker_mode \ + --preemptible \ + --gpus 2 \ + --num_nodes 1 \ + --max_retries 0 \ + --budget ai2/oe-adapt \ + --env VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \ + --env VLLM_ATTENTION_BACKEND="FLASH_ATTN" \ + -- \ +python scripts/data/rlvr/filtering_vllm.py \ + --model $MODEL_NAME_OR_PATH \ + --dataset $DATASET \ + --split train \ + --temperature 0.7 \ + --top_p 0.95 \ + --offset 0 \ + --size 100000 \ + --chat_template olmo_thinker_r1_style_nochat \ + --output-file filtered_datasets/olmo25_7b_lc_dapo.jsonl \ + --number_samples 16 diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 52a414042..56188e255 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -442,6 +442,33 @@ def visualize_token_role(tokens: list[int], masks: list[int], tokenizer: PreTrai "{% endif %}" "{% endfor %}" ), + "olmo_thinker_r1_style_nochat": ( + "Solve the following math problem step by step. " + "Reason about the question in <think> </think> tags " + "then provide the final answer in <answer> </answer> tags " + "so the full response is <think> reasoning process here </think> " + "<answer> answer here </answer>."
+ "\n\n" + "{% for message in messages %}" + "{{ '\n\n' if not loop.first else '' }}" + "{{ message['content'] + '\n' }}" + "{% if loop.last and add_generation_prompt %}" + "{{ 'Solving step by step\n' }}" + "{% endif %}" + "{% endfor %}" + ), + "olmo_thinker_dapo": ( + "Solve the following math problem step by step. " + "The last line of your response should be the answer to the problem in form Answer: $Answer (without quotes) where $Answer is the answer to the problem." + "\n\n" + "{% for message in messages %}" + "{{ '\n\n' if not loop.first else '' }}" + "{{ message['content'] + '\n' }}" + "{% if loop.last and add_generation_prompt %}" + "{{ '\nRemember to put your answer on its own line after \"Answer:\"' }}" + "{% endif %}" + "{% endfor %}" + ), # template is taken from https://arxiv.org/abs/2501.12948. "r1_simple_chat": ( "A conversation between User and Assistant. " diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 7a816ad68..e9c7ac758 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -282,6 +282,13 @@ class Args: use_vllm_logprobs: bool = False """whether to use vLLM's logprobs for training instead of calculating them via forward pass""" + active_fill_completions: bool = False + """Whether to refill the batch with *new prompts/completions* after filtering.""" + active_fill_max_attempts: int = 3 + """How many times to attempt to fill""" + no_resample_solve_rate: float = None + """If a prompt is solved at more than this rate across K completions, don't resample it for next epoch""" + # Reward # -- r1 style format reward apply_r1_style_format_reward: bool = False @@ -411,6 +418,8 @@ class Args: """the max generation length for evaluation for oe-eval""" oe_eval_beaker_image: Optional[str] = None """the docker image for evaluation for oe-eval""" + oe_eval_gpu_multiplier: Optional[int] = 1 + """gpu mulitplier for eval jobs""" eval_priority: Literal["low", "normal", "high", "urgent"] = "normal" """the priority of auto-launched evaluation jobs""" @@ -574,16 +583,21 @@ def __init__(self, data: np.ndarray, batch_size: int, seed: Optional[int] = None self.epoch_number = 0 self.rng = np.random.default_rng(seed) self.rng.shuffle(self.data) + self.exclude_list = [] # Ensure the effective dataset size is divisible by batch_size - self.effective_size = len(self.data) - (len(self.data) % batch_size) + self._update_effective_size() def __iter__(self) -> Iterator[List[int]]: return self def __next__(self) -> List[int]: + if self.effective_size == 0: + raise StopIteration("No data available to sample from the iterator.") + if self.index >= self.effective_size: self.index = 0 + self._update_effective_size() self.epoch_number += 1 self.rng.shuffle(self.data) @@ -593,6 +607,10 @@ def __next__(self) -> List[int]: return batch + def exclude_indices(self, exclude_list: List) -> None: + """Exclude provided data points from future sampling.""" + self.exclude_list.extend(exclude_list) + def get_state(self) -> Dict[str, Any]: """Get the current state of the iterator for checkpointing.""" return { @@ -600,6 +618,7 @@ def get_state(self) -> Dict[str, Any]: "epoch_number": self.epoch_number, "data": self.data.copy(), "rng_state": self.rng.bit_generator.state, + "exclude_list": self.exclude_list.copy(), } def set_state(self, state: Dict[str, Any]) -> None: @@ -608,6 +627,18 @@ def set_state(self, state: Dict[str, Any]) -> None: self.epoch_number = state["epoch_number"] self.data = state["data"].copy() self.rng.bit_generator.state = state["rng_state"] + 
self.exclude_list = state.get("exclude_list", []) + self._update_effective_size() + + def _update_effective_size(self) -> None: + if self.exclude_list: + mask = ~np.isin(self.data, self.exclude_list) + if mask.all(): + return + + self.data = self.data[mask] + self.exclude_list = [] + self.effective_size = len(self.data) - (len(self.data) % self.batch_size) @ray.remote(num_gpus=1) @@ -1356,6 +1387,7 @@ def launch_ai2_evals_on_weka_wrapper(self, step_dir, leaderboard_name, wandb_url args.gs_bucket_path, args.eval_priority, args.oe_eval_beaker_image, + args.oe_eval_gpu_multiplier, ) @@ -1579,6 +1611,7 @@ def accumulate_inference_batches( all_ground_truths = [] all_datasets = [] all_raw_queries = [] + all_indices = [] for i in tqdm( range(num_prompts), total=num_prompts, @@ -1605,6 +1638,7 @@ def accumulate_inference_batches( all_ground_truths.append(ground_truth) all_datasets.append(dataset) all_raw_queries.append(raw_query) + all_indices.append(result.dataset_index) # Combine all results into a single GenerationResult combined_responses = [] @@ -1686,13 +1720,13 @@ def accumulate_inference_batches( if actor_manager is not None: ray.get(actor_manager.report_token_statistics.remote(accumulated_stats)) - # Note: We don't have dataset_indices here, but they're not needed for the returned batch + # Preserve dataset indices so downstream filtering keeps track of original prompts batch = Batch( queries=all_queries, ground_truths=all_ground_truths, datasets=all_datasets, raw_queries=all_raw_queries, - indices=None, # Not meaningful for combined results + indices=all_indices, ) return combined_result, batch, prompt_lengths, response_lengths @@ -1709,181 +1743,351 @@ def data_preparation_thread( resume_training_step: int, actor_manager=None, model_dims: utils.ModelDims = None, + train_iterator: ShufflingIterator = None, ): - for training_step in range(resume_training_step, num_training_steps + 1): - # Streaming accumulation: collect results as they arrive - with Timer("🚀 [Data Preparation Thread] Getting response ids") as timer: - result, batch, prompt_lengths, response_lengths = accumulate_inference_batches( - inference_results_Q, - pending_queries_map, - args, - generation_config, - num_prompts=args.num_unique_prompts_rollout, - model_dims=model_dims, - actor_manager=actor_manager, - ) - if isinstance(result, ShutdownSentinel): - logger.info("[Data Preparation Thread] Received shutdown sentinel, exiting") - return - - getting_response_time = timer.duration - - # ------------------------------------------------------------------------------------------------ - # Pack sequences - if args.num_samples_per_prompt_rollout > 1: - batch = Batch( - queries=repeat_each(batch.queries, args.num_samples_per_prompt_rollout), - ground_truths=repeat_each(batch.ground_truths, args.num_samples_per_prompt_rollout), - datasets=repeat_each(batch.datasets, args.num_samples_per_prompt_rollout), - raw_queries=repeat_each(batch.raw_queries, args.num_samples_per_prompt_rollout), - indices=repeat_each(batch.indices, args.num_samples_per_prompt_rollout) if batch.indices else None, - ) - good_outputs = [ - len(result.request_info.tool_outputs[i]) > 0 - and result.request_info.tool_calleds[i] - and not result.request_info.timeouts[i] - and not result.request_info.tool_errors[i] - for i in range(len(result.request_info.tool_outputs)) - ] - for i in range(len(result.finish_reasons)): - if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0: - result.responses[i].append(tokenizer.eos_token_id) - 
result.masks[i].append(1) - result.logprobs[i].append(float("nan")) - with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True): - decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True) - decoded_queries = batch.raw_queries - stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len( - result.finish_reasons - ) + def combine_reward_metrics(metric_records: list[tuple[Dict[str, Any], int]]) -> Dict[str, Any]: + if not metric_records: + return {} + + buckets: Dict[str, list[tuple[Any, int]]] = {} + for metrics, weight in metric_records: + if not metrics: + continue + for key, value in metrics.items(): + buckets.setdefault(key, []).append((value, weight)) + + combined: Dict[str, Any] = {} + for key, records in buckets.items(): + sample_value = records[0][0] + if isinstance(sample_value, np.ndarray): + combined[key] = np.concatenate([np.asarray(value) for value, _ in records]) + elif isinstance(sample_value, (list, tuple)): + concatenated: list[Any] = [] + for value, _ in records: + concatenated.extend(list(value)) + combined[key] = concatenated + elif isinstance(sample_value, (int, float, bool, np.integer, np.floating)): + total_weight = sum(weight for _, weight in records) + if total_weight == 0: + combined[key] = float(sample_value) + else: + combined[key] = sum(float(value) * weight for value, weight in records) / total_weight + else: + # Fallback: keep the latest value if aggregation strategy is unclear. + combined[key] = records[-1][0] + return combined - with Timer("💰 [Data Preparation Thread] Calculating rewards and advantages"): - scores, reward_metrics = asyncio.run( - reward_fn( - result.responses, - decoded_responses, - batch, - result.finish_reasons, - result.request_info, - decoded_queries, + for training_step in range(resume_training_step, num_training_steps + 1): + per_prompt_rollout = args.num_samples_per_prompt_rollout + target_prompt_count = args.num_unique_prompts_rollout + + aggregated_scores: list[np.ndarray] = [] + aggregated_advantages: list[np.ndarray] = [] + aggregated_reward_metrics: list[tuple[Dict[str, Any], int]] = [] + aggregated_responses: list[List[int]] = [] + aggregated_masks: list[List[int]] = [] + aggregated_finish_reasons: list[str] = [] + aggregated_vllm_logprobs: list[List[float]] = [] + aggregated_batches: list[Batch] = [] + aggregated_prompt_lengths: list[int] = [] + aggregated_response_lengths: list[int] = [] + aggregated_num_calls: list[int] = [] + aggregated_timeouts: list[int] = [] + aggregated_tool_errors: list[str] = [] + aggregated_tool_outputs: list[str] = [] + aggregated_tool_runtimes: list[float] = [] + aggregated_tool_calleds: list[bool] = [] + + total_prompt_tokens = 0 + total_response_tokens = 0 + total_generation_time = 0.0 + earliest_start_time: Optional[float] = None + + total_stop_completions = 0 + total_completions = 0 + + total_prompts_kept = 0 + total_prompts_filtered = 0 + solved_prompts_filtered = 0 + zero_prompts_filtered = 0 + noresample_prompts_filtered = 0 + prompts_to_request = target_prompt_count + getting_response_time = 0.0 + + fill_iteration = 0 + while prompts_to_request > 0: + fill_iteration += 1 + # Streaming accumulation: collect results as they arrive + with Timer(f"🚀 [Data Preparation Thread] Getting {prompts_to_request} response ids") as timer: + result_step, batch_step, prompt_lengths_step, response_lengths_step = accumulate_inference_batches( + inference_results_Q, + pending_queries_map, + args, + generation_config, + 
num_prompts=prompts_to_request, + model_dims=model_dims, + actor_manager=actor_manager, ) - ) - scores = np.array(scores) - scores_per_prompt = scores.reshape(-1, args.num_samples_per_prompt_rollout) - mean_grouped_rewards = scores_per_prompt.mean(axis=-1) - mean_grouped_rewards = np.repeat(mean_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0) - std_grouped_rewards = scores_per_prompt.std(axis=-1) - std_grouped_rewards = np.repeat(std_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0) - if args.advantage_normalization_type == "standard": - advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) - elif args.advantage_normalization_type == "centered": - advantages = scores - mean_grouped_rewards - else: - raise ValueError(f"Invalid advantage normalization type: {args.advantage_normalization_type}") - - with Timer("📦 [Data Preparation Thread] Filtering sequences"): - # Here we get the max possible score for each prompt, and see how many prompts are unsolved - max_possible_score = 0 - if args.apply_verifiable_reward: - max_possible_score += args.verification_reward - if args.apply_r1_style_format_reward and args.additive_format_reward: - max_possible_score += args.r1_style_format_reward - unsolved_batch_size_ratio = ((scores != max_possible_score) > 0).sum() / len(scores) - # In GRPO, if the std of grouped rewards is 0, then there is zero gradient for the batch - # of args.num_samples_per_prompt_rollout responses, so we need to filter out those batches - non_zero_std_mask = scores_per_prompt.std(axis=-1) != 0 - real_batch_size_ratio = non_zero_std_mask.sum() * args.num_samples_per_prompt_rollout / len(scores) - expanded_mask = np.repeat(non_zero_std_mask, args.num_samples_per_prompt_rollout) - non_zero_gradient_index = np.where(expanded_mask)[0] - - # Log zero-gradient filtering statistics - num_zero_std_prompts = (~non_zero_std_mask).sum() - num_filtered_responses = len(scores) - len(non_zero_gradient_index) - if num_filtered_responses > 0: - logger.info( - f"[Zero-gradient filtering] Filtered {num_zero_std_prompts} prompts with zero std " - f"({num_filtered_responses} responses). 
Retention rate: {len(non_zero_gradient_index) / len(scores):.2%}" + if isinstance(result_step, ShutdownSentinel): + logger.info("[Data Preparation Thread] Received shutdown sentinel, exiting") + return + + getting_response_time += timer.duration + + if args.num_samples_per_prompt_rollout > 1: + batch_step = Batch( + queries=repeat_each(batch_step.queries, per_prompt_rollout), + ground_truths=repeat_each(batch_step.ground_truths, per_prompt_rollout), + datasets=repeat_each(batch_step.datasets, per_prompt_rollout), + raw_queries=repeat_each(batch_step.raw_queries, per_prompt_rollout), + indices=repeat_each(batch_step.indices, per_prompt_rollout) if batch_step.indices else None, ) - advantages = advantages[non_zero_gradient_index] - original_batch_size = len(scores) - scores = scores[non_zero_gradient_index] - responses = [result.responses[i] for i in non_zero_gradient_index] - masks = [result.masks[i] for i in non_zero_gradient_index] - batch = batch[non_zero_gradient_index.tolist()] - finish_reasons = [result.finish_reasons[i] for i in non_zero_gradient_index] - vllm_logprobs = [result.logprobs[i] for i in non_zero_gradient_index] - if args.mask_truncated_completions: - stop_idxes = torch.tensor([i for i in range(len(finish_reasons)) if finish_reasons[i] == "stop"]) - num_truncated = len(finish_reasons) - len(stop_idxes) - if num_truncated > 0: + for i in range(len(result_step.finish_reasons)): + if result_step.finish_reasons[i] == "stop" and len(result_step.responses[i]) == 0: + result_step.responses[i].append(tokenizer.eos_token_id) + result_step.masks[i].append(1) + result_step.logprobs[i].append(float("nan")) + + with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True): + decoded_responses_step = tokenizer.batch_decode(result_step.responses, skip_special_tokens=True) + decoded_queries_step = batch_step.raw_queries + + with Timer("💰 [Data Preparation Thread] Calculating rewards and advantages"): + scores_step, reward_metrics_step = asyncio.run( + reward_fn( + result_step.responses, + decoded_responses_step, + batch_step, + result_step.finish_reasons, + result_step.request_info, + decoded_queries_step, + ) + ) + scores_step = np.array(scores_step) + scores_per_prompt_step = scores_step.reshape(-1, per_prompt_rollout) + mean_grouped_rewards = scores_per_prompt_step.mean(axis=-1) + mean_grouped_rewards = np.repeat(mean_grouped_rewards, per_prompt_rollout, axis=0) + std_grouped_rewards = scores_per_prompt_step.std(axis=-1) + std_grouped_rewards = np.repeat(std_grouped_rewards, per_prompt_rollout, axis=0) + if args.advantage_normalization_type == "standard": + advantages_step = (scores_step - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) + elif args.advantage_normalization_type == "centered": + advantages_step = scores_step - mean_grouped_rewards + else: + raise ValueError(f"Invalid advantage normalization type: {args.advantage_normalization_type}") + + # Filtering (zero-gradient, truncated completions) + with Timer("📦 [Data Preparation Thread] Filtering sequences"): + max_possible_score = 0 + if args.apply_verifiable_reward: + max_possible_score += args.verification_reward + if args.apply_r1_style_format_reward and args.additive_format_reward: + max_possible_score += args.r1_style_format_reward + + if args.no_resample_solve_rate is not None: + percent_solved_prompt = scores_per_prompt_step.mean(axis=-1) / max_possible_score + solved_mask = percent_solved_prompt >= args.no_resample_solve_rate + solved_expanded_mask = np.repeat(solved_mask, per_prompt_rollout) + 
solved_indices = np.where(solved_expanded_mask)[0] + solved_prompt_indices = batch_step[solved_indices.tolist()].indices + train_iterator.exclude_indices(list(set(solved_prompt_indices))) + noresample_prompts_filtered += solved_mask.sum() + + non_zero_std_mask = scores_per_prompt_step.std(axis=-1) != 0 + expanded_mask = np.repeat(non_zero_std_mask, per_prompt_rollout) + non_zero_indices = np.where(expanded_mask)[0] + + num_zero_std_prompts = (~non_zero_std_mask).sum() + num_filtered_responses = len(scores_step) - len(non_zero_indices) + + # Count groups with all zero rewards + zero_prompts_filtered += (scores_per_prompt_step == 0).all(axis=-1).sum() + solved_prompts_filtered += (scores_per_prompt_step == max_possible_score).all(axis=-1).sum() + total_prompts_filtered += num_zero_std_prompts + + if num_filtered_responses > 0: logger.info( - f"[Truncated completions filtering] Filtered {num_truncated} responses that didn't finish with 'stop'. " - f"Retention rate: {len(stop_idxes) / len(finish_reasons):.2%}" + f"[Zero-gradient filtering] Filtered {num_zero_std_prompts} prompts with zero std " + f"({num_filtered_responses} responses). Retention rate: {len(non_zero_indices) / len(scores_step):.2%}" ) - scores = scores[stop_idxes] - advantages = advantages[stop_idxes] - responses = [responses[i] for i in stop_idxes] - masks = [masks[i] for i in stop_idxes] - batch = batch[stop_idxes.tolist()] - finish_reasons = [finish_reasons[i] for i in stop_idxes] - vllm_logprobs = [vllm_logprobs[i] for i in stop_idxes] - - if args.fill_completions: - with Timer("⏱ [Data Preparation Thread] Refill completions"): - current_batch_size = len(scores) - original_prompt_cnt = original_batch_size // args.num_samples_per_prompt_rollout - current_prompt_cnt = current_batch_size // args.num_samples_per_prompt_rollout - need_to_fill_prompt = original_prompt_cnt - current_prompt_cnt - k = args.num_samples_per_prompt_rollout - - if need_to_fill_prompt > 0 and current_prompt_cnt > 0: - scores_matrix = scores.reshape(current_prompt_cnt, k) - stds = scores_matrix.std(axis=1) + 1e-8 - probs = stds / stds.sum() + advantages_step = advantages_step[non_zero_indices] + scores_step = scores_step[non_zero_indices] + responses_step = [result_step.responses[i] for i in non_zero_indices] + masks_step = [result_step.masks[i] for i in non_zero_indices] + batch_step = batch_step[non_zero_indices.tolist()] + finish_reasons_step = [result_step.finish_reasons[i] for i in non_zero_indices] + vllm_logprobs_step = [result_step.logprobs[i] for i in non_zero_indices] + + prompt_indices_kept = np.where(non_zero_std_mask)[0] + prompt_lengths_kept = [prompt_lengths_step[i] for i in prompt_indices_kept] + response_lengths_kept = [response_lengths_step[i] for i in non_zero_indices] + + num_calls_kept = [result_step.request_info.num_calls[i] for i in non_zero_indices] + timeouts_kept = [result_step.request_info.timeouts[i] for i in non_zero_indices] + tool_errors_kept = [result_step.request_info.tool_errors[i] for i in non_zero_indices] + tool_outputs_kept = [result_step.request_info.tool_outputs[i] for i in non_zero_indices] + tool_runtimes_kept = [result_step.request_info.tool_runtimes[i] for i in non_zero_indices] + tool_calleds_kept = [result_step.request_info.tool_calleds[i] for i in non_zero_indices] + + if args.mask_truncated_completions: + truncated_mask = np.array([reason != "stop" for reason in finish_reasons_step], dtype=bool) + num_truncated = int(truncated_mask.sum()) + if num_truncated > 0: + scores_step = scores_step.copy() + 
scores_step[truncated_mask] = 0.0 + advantages_step = advantages_step.copy() + advantages_step[truncated_mask] = 0.0 + for idx in np.where(truncated_mask)[0]: + masks_step[idx] = [0 for _ in masks_step[idx]] logger.info( - f"[Refill completions] Need to fill {need_to_fill_prompt} prompts to maintain batch size. " - f"Original: {original_prompt_cnt}, Current: {current_prompt_cnt}" + f"[Truncated completions masking] Zeroed scores/advantages and response masks for {num_truncated} responses that didn't finish with 'stop'." ) - - sampled_prompt_ids = np.random.choice( - current_prompt_cnt, size=need_to_fill_prompt, replace=True, p=probs + aggregated_scores.append(scores_step) + aggregated_advantages.append(advantages_step) + aggregated_reward_metrics.append((reward_metrics_step, len(scores_step))) + aggregated_responses.extend(responses_step) + aggregated_masks.extend(masks_step) + aggregated_finish_reasons.extend(finish_reasons_step) + aggregated_vllm_logprobs.extend(vllm_logprobs_step) + aggregated_batches.append(batch_step) + aggregated_prompt_lengths.extend(prompt_lengths_kept) + aggregated_response_lengths.extend(response_lengths_kept) + aggregated_num_calls.extend(num_calls_kept) + aggregated_timeouts.extend(timeouts_kept) + aggregated_tool_errors.extend(tool_errors_kept) + aggregated_tool_outputs.extend(tool_outputs_kept) + aggregated_tool_runtimes.extend(tool_runtimes_kept) + aggregated_tool_calleds.extend(tool_calleds_kept) + + total_stop_completions += sum(int(reason == "stop") for reason in finish_reasons_step) + total_completions += len(finish_reasons_step) + + if result_step.token_statistics is not None: + total_prompt_tokens += result_step.token_statistics.num_prompt_tokens + total_response_tokens += result_step.token_statistics.num_response_tokens + total_generation_time += result_step.token_statistics.generation_time + if result_step.token_statistics.earliest_start_time is not None: + if earliest_start_time is None: + earliest_start_time = result_step.token_statistics.earliest_start_time + else: + earliest_start_time = min( + earliest_start_time, result_step.token_statistics.earliest_start_time ) - sampled_indices = [] - for pid in sampled_prompt_ids: - start = pid * k - sampled_indices.extend(range(start, start + k)) - - advantages = np.concatenate([advantages, advantages[sampled_indices]]) - scores = np.concatenate([scores, scores[sampled_indices]]) - responses += [responses[i] for i in sampled_indices] - masks += [masks[i] for i in sampled_indices] + prompts_kept_this_iter = len(scores_step) // per_prompt_rollout if per_prompt_rollout > 0 else 0 + total_prompts_kept += prompts_kept_this_iter - sampled_batch = batch[sampled_indices] + if not args.active_fill_completions: + break - batch = Batch( - queries=batch.queries + sampled_batch.queries, - ground_truths=batch.ground_truths + sampled_batch.ground_truths, - datasets=batch.datasets + sampled_batch.datasets, - indices=batch.indices + sampled_batch.indices if batch.indices is not None else None, - ) + prompts_remaining = target_prompt_count - total_prompts_kept + if prompts_remaining <= 0: + break - finish_reasons += [finish_reasons[i] for i in sampled_indices] - vllm_logprobs += [vllm_logprobs[i] for i in sampled_indices] + if prompts_kept_this_iter == 0 and fill_iteration > args.active_fill_max_attempts: + logger.warning( + "[Active fill completions] Unable to collect non-zero advantage prompts in iteration %d; " + "set as max by args.active_fill_max_attempts, proceeding with existing batch of size %d.", + fill_iteration, 
+ len(aggregated_responses), + ) + break - logger.info( - f"📊 Duplicated {need_to_fill_prompt} prompts from {len(sampled_indices)} total responses" - ) + prompts_to_request = prompts_remaining + + # Build aggregated containers from collected data + if aggregated_batches: + queries: list[List[int]] = [] + ground_truths: list[List[int]] = [] + datasets: list[str] = [] + raw_queries_list: Optional[list[str]] = None + indices_list: Optional[list[int]] = None + + raw_queries_present = aggregated_batches[0].raw_queries is not None + indices_present = aggregated_batches[0].indices is not None + if raw_queries_present: + raw_queries_list = [] + if indices_present: + indices_list = [] + + for batch_step in aggregated_batches: + queries.extend(batch_step.queries) + ground_truths.extend(batch_step.ground_truths) + datasets.extend(batch_step.datasets) + if raw_queries_present: + assert batch_step.raw_queries is not None + raw_queries_list.extend(batch_step.raw_queries) + if indices_present: + assert batch_step.indices is not None + indices_list.extend(batch_step.indices) - # Count groups with all zero rewards - all_zero_groups = (scores_per_prompt == 0).all(axis=-1).sum() - total_groups = len(scores_per_prompt) - logger.info( - f"[Reward Summary] Groups with all zero rewards: {all_zero_groups}/{total_groups} " - f"({all_zero_groups / total_groups:.1%})" + batch = Batch( + queries=queries, + ground_truths=ground_truths, + datasets=datasets, + raw_queries=raw_queries_list if raw_queries_present else None, + indices=indices_list if indices_present else None, ) + else: + batch = Batch(queries=[], ground_truths=[], datasets=[], raw_queries=[], indices=[]) + + prompt_lengths = aggregated_prompt_lengths + response_lengths = aggregated_response_lengths + + token_statistics = TokenStatistics( + num_prompt_tokens=total_prompt_tokens, + num_response_tokens=total_response_tokens, + generation_time=total_generation_time, + earliest_start_time=earliest_start_time, + ) + result = GenerationResult( + responses=aggregated_responses, + finish_reasons=aggregated_finish_reasons, + masks=aggregated_masks, + request_info=RequestInfo( + num_calls=aggregated_num_calls, + timeouts=aggregated_timeouts, + tool_errors=aggregated_tool_errors, + tool_outputs=aggregated_tool_outputs, + tool_runtimes=aggregated_tool_runtimes, + tool_calleds=aggregated_tool_calleds, + ), + token_statistics=token_statistics, + logprobs=aggregated_vllm_logprobs, + ) + + good_outputs = [ + len(result.request_info.tool_outputs[i]) > 0 + and result.request_info.tool_calleds[i] + and not result.request_info.timeouts[i] + and not result.request_info.tool_errors[i] + for i in range(len(result.request_info.tool_outputs)) + ] + stop_rate = (total_stop_completions / total_completions) if total_completions > 0 else 0 + + scores = np.concatenate(aggregated_scores) if aggregated_scores else np.array([]) + advantages = np.concatenate(aggregated_advantages) if aggregated_advantages else np.array([]) + reward_metrics = combine_reward_metrics(aggregated_reward_metrics) + + if len(scores) == 0: + logger.warning(f"No responses with non-zero advantages in batch {training_step}.") + + max_possible_score = 0 + if args.apply_verifiable_reward: + max_possible_score += args.verification_reward + if args.apply_r1_style_format_reward and args.additive_format_reward: + max_possible_score += args.r1_style_format_reward + + unsolved_batch_size_ratio = ((scores != max_possible_score) > 0).sum() / len(scores) + original_batch_size = target_prompt_count * per_prompt_rollout + 
real_batch_size_ratio = len(scores) / original_batch_size if original_batch_size > 0 else 0 + + responses = aggregated_responses + masks = aggregated_masks + finish_reasons = aggregated_finish_reasons + vllm_logprobs = aggregated_vllm_logprobs with Timer("📦 [Data Preparation Thread] Packing sequences"): packed_sequences = pack_sequences( @@ -2004,17 +2208,15 @@ def data_preparation_thread( np.array([]) if np.all(scores == max_possible_score) else np.array(sequence_lengths[scores == 0]) ) - # Use the already calculated reward summary metrics for wandb - all_zero_groups_ratio = all_zero_groups / total_groups if total_groups > 0 else 0 - metrics = { "scores": np.array(scores).mean(), "real_batch_size_ratio": real_batch_size_ratio, "unsolved_batch_size_ratio": unsolved_batch_size_ratio, "packed_ratio": len(packed_sequences.query_responses) / len(responses) if len(responses) > 0 else 0, - "val/all_zero_reward_groups": all_zero_groups, - "val/all_zero_reward_groups_ratio": all_zero_groups_ratio, - "val/total_reward_groups": total_groups, + "val/zero_filtered_groups": zero_prompts_filtered, + "val/solved_filtered_groups": solved_prompts_filtered, + "val/total_filtered_groups": total_prompts_filtered, + "val/noresample_filtered_groups": noresample_prompts_filtered, "val/sequence_lengths": sequence_lengths.mean(), "val/sequence_lengths_min": sequence_lengths.min(), "val/sequence_lengths_max": sequence_lengths.max(), @@ -2473,7 +2675,7 @@ def one_training_step( ) save_time = 0 - if args.save_freq > 0 and training_step % args.save_freq == 0 and (args.eval_on_step_0 or training_step > 1): + if args.save_freq > 0 and (training_step % args.save_freq == 0 or (training_step == 1 and args.eval_on_step_0)): with Timer("[Main Thread] 🗡️ Saving model") as timer: checkpoint_dir = f"{args.output_dir}_checkpoints" step_dir = os.path.join(checkpoint_dir, f"step_{training_step}") @@ -2883,6 +3085,7 @@ def run_training( resume_training_step, actor_manager, model_dims, + iter_dataloader, ) def health_check_fn(): @@ -2912,6 +3115,7 @@ def health_check_fn(): else: num_total_tokens = 0 + filtered_prompts_count = 0 training_start_time = time.perf_counter() # Track overall training start time for training_step in range(resume_training_step, args.num_training_steps + 1): start_time = time.perf_counter() @@ -2947,6 +3151,23 @@ def health_check_fn(): generation_configs["train"], is_eval=False, ) + i = 0.1 + while filtered_prompts_count > args.num_unique_prompts_rollout: + another_batch = next_batch(next(iter_dataloader), train_dataset) + # little hack to make sure that we don't have the same training step, otherwise vllm requests can break + # putting both batches in the same training step might also break if we accidentally sample the same dataset index twice + split_and_insert_batch( + another_batch, + iter_dataloader.epoch_number, + training_step + i, + pending_queries_map, + param_prompt_Q, + generation_configs["train"], + is_eval=False, + ) + filtered_prompts_count -= args.num_unique_prompts_rollout + i += 0.1 + if ( training_step % args.local_eval_every == 0 and eval_batch is not None @@ -2962,12 +3183,14 @@ def health_check_fn(): is_eval=True, ) - collated_data, data_thread_metrics, num_total_tokens, num_step_tokens, prompt_lengths, response_lengths = ( + (collated_data, data_thread_metrics, num_total_tokens, num_step_tokens, prompt_lengths, response_lengths) = ( load_data_from_packing_thread(packed_sequences_Q, num_total_tokens, stop_event, health_check_fn) ) if collated_data is None: continue + 
filtered_prompts_count += data_thread_metrics["val/total_filtered_groups"] + + for metrics_Q in [generate_metrics_Q, weight_sync_metrics_Q]: try: data_thread_metrics |= metrics_Q.get_nowait() diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 38697c67e..d878170a1 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -1149,6 +1149,7 @@ def launch_ai2_evals_on_weka( gs_bucket_path: Optional[str] = None, eval_priority: Optional[str] = "normal", beaker_image: Optional[str] = None, + oe_eval_gpu_multiplier: Optional[int] = 1, ) -> None: weka_cluster = "ai2/saturn ai2/neptune" gcp_cluster = "ai2/augusta" @@ -1178,6 +1179,7 @@ def launch_ai2_evals_on_weka( command = f"""\ python scripts/submit_eval_jobs.py \ +--gpu_multiplier {oe_eval_gpu_multiplier} \ --model_name {leaderboard_name} \ --location {path} \ --cluster {cluster} \ diff --git a/pyproject.toml b/pyproject.toml index 8dca0b159..b8af0df55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ dependencies = [ "bitsandbytes>=0.44.1; platform_system != 'Darwin'", "datasets>=4.0.0", "debugpy>=1.8.13", - "deepspeed<0.17.6", # version 17.6 and up break our training code right now + "deepspeed<0.17.3", # versions 0.17.3 and up break our training code right now "hf-transfer>=0.1.8", "litellm>=1.72.0,<1.75.2", # avoid needing backoff https://github.com/BerriAI/litellm/issues/13827 "matplotlib>=3.9.3", @@ -95,6 +95,7 @@ target-version = ['py310'] [tool.isort] known_first_party = ["open_instruct"] +known_third_party = ["wandb"] profile = "black" src_paths = ["open_instruct"] diff --git a/scripts/convert_deepspeed_checkpoint_to_hf.py b/scripts/convert_deepspeed_checkpoint_to_hf.py new file mode 100644 index 000000000..4e83ced9d --- /dev/null +++ b/scripts/convert_deepspeed_checkpoint_to_hf.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python3 +""" +Utility to turn a DeepSpeed ZeRO-2/3 checkpoint produced by grpo_fast.py into a Hugging Face model folder. + +Usage example: + python scripts/convert_deepspeed_checkpoint_to_hf.py \ + --checkpoint-dir /path/to/checkpoint_state_dir \ + --output-dir /path/to/exported_model \ + --model-config /path/to/base_model_or_config \ + --tokenizer /path/to/tokenizer \ + --tag global_step16384 +""" + +from __future__ import annotations + +import argparse +import json +import os +from pathlib import Path +from typing import Optional + +import torch +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +try: + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint +except ImportError as exc: # pragma: no cover - makes failure mode obvious to user + raise SystemExit( + "Could not import DeepSpeed. Install deepspeed in the current environment before running this script." + ) from exc + +try: + from open_instruct.dataset_transformation import CHAT_TEMPLATES +except Exception: # pragma: no cover - script is still useful without templates + CHAT_TEMPLATES = {} + +try: + from open_instruct.model_utils import get_olmo3_generation_config +except Exception: # pragma: no cover - optional helper + get_olmo3_generation_config = None + + +def _resolve_checkpoint_leaf(checkpoint_dir: Path, explicit_tag: Optional[str]) -> tuple[Path, Optional[str], Path]: + """ + Determine which folder actually holds the DeepSpeed shards. + + Returns a tuple of (root_dir, tag, leaf_dir).
+ """ + checkpoint_dir = checkpoint_dir.expanduser().resolve() + if not checkpoint_dir.exists(): + raise FileNotFoundError(f"Checkpoint directory {checkpoint_dir} does not exist.") + + tag = explicit_tag + if tag is None: + latest_file = checkpoint_dir / "latest" + if latest_file.is_file(): + tag = latest_file.read_text().strip() + + # If a tag is set, assume that `checkpoint_dir/tag` holds the shard files. + if tag: + leaf_dir = checkpoint_dir / tag + else: + leaf_dir = checkpoint_dir + + if not leaf_dir.exists(): + raise FileNotFoundError(f"Derived checkpoint leaf {leaf_dir} does not exist.") + + shard = leaf_dir / "mp_rank_00_model_states.pt" + if not shard.exists(): + # Some jobs nest the shards a level deeper, so search for a unique match. + matches = list(leaf_dir.glob("**/mp_rank_00_model_states.pt")) + if len(matches) == 1: + leaf_dir = matches[0].parent + elif len(matches) == 0: + raise FileNotFoundError( + f"Could not find mp_rank_00_model_states.pt under {leaf_dir}. " + "Make sure you are pointing at a DeepSpeed ZeRO checkpoint." + ) + else: + raise RuntimeError( + f"Found multiple shard folders in {leaf_dir}. " + "Specify --tag to disambiguate which checkpoint to convert." + ) + + return checkpoint_dir, tag, leaf_dir + + +def _parse_dtype(dtype_str: Optional[str]) -> Optional[torch.dtype]: + if dtype_str is None: + return None + + normalized = dtype_str.lower() + mapping = { + "float16": torch.float16, + "fp16": torch.float16, + "half": torch.float16, + "float32": torch.float32, + "fp32": torch.float32, + "bfloat16": torch.bfloat16, + "bf16": torch.bfloat16, + "float64": torch.float64, + "fp64": torch.float64, + } + if normalized not in mapping: + valid = ", ".join(sorted(set(mapping.keys()))) + raise ValueError(f"Unknown dtype '{dtype_str}'. Expected one of: {valid}") + return mapping[normalized] + + +def _load_state_dict(leaf_dir: Path, tag: Optional[str]) -> dict[str, torch.Tensor]: + """ + Load the aggregated state dict from the DeepSpeed checkpoint. + """ + # The DeepSpeed utility accepts either the leaf directory that holds the shards + # (when tag=None) or the checkpoint root plus an explicit tag. + if tag is None: + state_dict = load_state_dict_from_zero_checkpoint(str(leaf_dir)) + else: + state_dict = load_state_dict_from_zero_checkpoint(str(leaf_dir.parent), tag=tag) + return state_dict + + +def _maybe_set_chat_template(tokenizer, chat_template_name: Optional[str]) -> None: + if not chat_template_name: + return + + template = None + if chat_template_name in CHAT_TEMPLATES: + template = CHAT_TEMPLATES[chat_template_name] + else: + path = Path(chat_template_name) + if path.is_file(): + template = path.read_text() + + if template is None: + raise ValueError( + f"Could not resolve chat template '{chat_template_name}'. " + "Pass a key from CHAT_TEMPLATES or a path to a .jinja template file." 
+ ) + + tokenizer.chat_template = template + + +def _maybe_add_generation_config(model, tokenizer, chat_template_name: Optional[str]) -> None: + if not chat_template_name or get_olmo3_generation_config is None: + return + normalized = chat_template_name.lower() + if "olmo" in normalized: + gen_config = get_olmo3_generation_config(tokenizer) + model.generation_config = gen_config + + +def convert_checkpoint( + checkpoint_dir: Path, + output_dir: Path, + model_config_source: str, + tokenizer_source: Optional[str], + tag: Optional[str], + dtype_str: Optional[str], + trust_remote_code: bool, + chat_template_name: Optional[str], +) -> None: + checkpoint_dir, tag, leaf_dir = _resolve_checkpoint_leaf(checkpoint_dir, tag) + output_dir = output_dir.expanduser().resolve() + output_dir.mkdir(parents=True, exist_ok=True) + + dtype = _parse_dtype(dtype_str) + + state_dict = _load_state_dict(leaf_dir, tag) + + config = AutoConfig.from_pretrained(model_config_source, trust_remote_code=trust_remote_code) + if dtype is not None: + config.torch_dtype = dtype + model = AutoModelForCausalLM.from_config(config, trust_remote_code=trust_remote_code) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + if missing_keys: + print(f"[warning] Missing keys while loading state dict ({len(missing_keys)}): {missing_keys[:8]}") + if unexpected_keys: + print(f"[warning] Unexpected keys in state dict ({len(unexpected_keys)}): {unexpected_keys[:8]}") + model.tie_weights() + + if tokenizer_source is None: + tokenizer_source = model_config_source + tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, trust_remote_code=trust_remote_code) + + _maybe_set_chat_template(tokenizer, chat_template_name) + _maybe_add_generation_config(model, tokenizer, chat_template_name) + + model.save_pretrained(output_dir, safe_serialization=True) + tokenizer.save_pretrained(output_dir) + + metadata = { + "source_checkpoint": str(checkpoint_dir), + "tag": tag, + "model_config_source": model_config_source, + "tokenizer_source": tokenizer_source, + "dtype": dtype_str, + } + with (output_dir / "conversion_metadata.json").open("w") as f: + json.dump(metadata, f, indent=2) + + print(f"Converted checkpoint from {leaf_dir} into {output_dir}") + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--checkpoint-dir", required=True, help="Directory that holds DeepSpeed checkpoint state.") + parser.add_argument( + "--output-dir", + required=True, + help="Where to write the Hugging Face model (created if it does not exist).", + ) + parser.add_argument( + "--model-config", + required=True, + help="Path or identifier for AutoConfig.from_pretrained to define the model architecture.", + ) + parser.add_argument( + "--tokenizer", + default=None, + help="Path or identifier for the tokenizer. Defaults to the same value as --model-config.", + ) + parser.add_argument( + "--tag", + default=None, + help="Optional DeepSpeed tag (e.g., global_step1234). 
If omitted, uses the 'latest' file when present.", + ) + parser.add_argument( + "--torch-dtype", + default=None, + help="Optional torch dtype to record in the config (e.g., bf16, fp16).", + ) + parser.add_argument( + "--chat-template-name", + default=None, + help="Tokenizer chat template key (from CHAT_TEMPLATES) or path to a template file to attach.", + ) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Forward trust_remote_code=True when loading config/tokenizer/model.", + ) + return parser + + +def main() -> None: + parser = build_parser() + args = parser.parse_args() + convert_checkpoint( + checkpoint_dir=Path(args.checkpoint_dir), + output_dir=Path(args.output_dir), + model_config_source=args.model_config, + tokenizer_source=args.tokenizer, + tag=args.tag, + dtype_str=args.torch_dtype, + trust_remote_code=args.trust_remote_code, + chat_template_name=args.chat_template_name, + ) + + +if __name__ == "__main__": + main() + diff --git a/scripts/data/rlvr/filtering_vllm.py b/scripts/data/rlvr/filtering_vllm.py index 160dca91f..bf8c0d700 100644 --- a/scripts/data/rlvr/filtering_vllm.py +++ b/scripts/data/rlvr/filtering_vllm.py @@ -1,4 +1,4 @@ -''' +""" python mason.py \ --cluster ai2/jupiter --image nathanl/open_instruct_auto \ --workspace ai2/tulu-thinker \ @@ -15,7 +15,8 @@ --size 100000 \ --output-file filtered_datasets/qwen2_5_openthoughts2/orz.jsonl \ --number_samples 8 -''' +""" + import argparse import json @@ -27,65 +28,20 @@ def main(): - parser = argparse.ArgumentParser( - description="Bulk-generate N samples per HF dataset record using vLLM." - ) - parser.add_argument( - "--model", - required=True, - help="vLLM model ID (e.g. facebook/opt-125m)" - ) - parser.add_argument( - "--dataset", - required=True, - help="HF dataset name (e.g. squad)" - ) + parser = argparse.ArgumentParser(description="Bulk-generate N samples per HF dataset record using vLLM.") + parser.add_argument("--model", required=True, help="vLLM model ID (e.g. facebook/opt-125m)") + parser.add_argument("--dataset", required=True, help="HF dataset name (e.g. squad)") + parser.add_argument("--split", default="train", help="Which split to load") + parser.add_argument("--offset", type=int, required=True, help="Start index into the split") + parser.add_argument("--size", type=int, required=True, help="Number of records to process") + parser.add_argument("--output-file", default=None, help="Path for output JSONL") parser.add_argument( - "--split", - default="train", - help="Which split to load" - ) - parser.add_argument( - "--offset", - type=int, - required=True, - help="Start index into the split" - ) - parser.add_argument( - "--size", - type=int, - required=True, - help="Number of records to process" - ) - parser.add_argument( - "--output-file", - default=None, - help="Path for output JSONL" - ) - parser.add_argument( - "--push_to_hub", - default=None, - type=str, - help="Give a dataset name to push this data to the hub." - ) - parser.add_argument( - "--chat_template", - type=str, - default=None, - help="Chat template name" - ) - parser.add_argument( - "--number_samples", - type=int, - default=8, - help="Number of samples to generate per record" - ) - parser.add_argument( - "--temperature", - type=float, - default=1.0, - help="Sampling temperature" + "--push_to_hub", default=None, type=str, help="Give a dataset name to push this data to the hub." 
) + parser.add_argument("--chat_template", type=str, default=None, help="Chat template name") + parser.add_argument("--number_samples", type=int, default=8, help="Number of samples to generate per record") + parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature") + parser.add_argument("--top_p", type=float, default=1.0, help="Sampling temperature") args = parser.parse_args() # 1. Load and slice dataset @@ -106,20 +62,14 @@ def main(): tokenizer.apply_chat_template( sample["messages"][:-1] if len(sample["messages"]) > 1 else sample["messages"], add_generation_prompt=True, - tokenize=False + tokenize=False, ) for sample in subset ] # 4. vLLM bulk generate - llm = LLM( - model=args.model, - dtype="bfloat16", - enable_prefix_caching=True - ) + llm = LLM(model=args.model, dtype="bfloat16", enable_prefix_caching=True) sampling_params = SamplingParams( - temperature=args.temperature, - n=args.number_samples, - max_tokens=32768, + temperature=args.temperature, top_p=args.top_p, n=args.number_samples, max_tokens=32768 ) outputs = llm.generate(prompts, sampling_params) diff --git a/scripts/train/debug/grpo_fast.sh b/scripts/train/debug/grpo_fast.sh index 5bf6f791e..13b59000a 100755 --- a/scripts/train/debug/grpo_fast.sh +++ b/scripts/train/debug/grpo_fast.sh @@ -1,11 +1,31 @@ +#!/bin/bash + +python mason.py \ + --task_name grpo_debug_small_active \ + --cluster ai2/jupiter \ + --workspace ai2/olmo-instruct \ + --priority urgent \ + --pure_docker_mode \ + --image michaeln/open_instruct_2.5-rl0 \ + --preemptible \ + --num_nodes 1 \ + --env VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \ + --env VLLM_ATTENTION_BACKEND="FLASH_ATTN" \ + --gpus 2 \ + --budget ai2/oe-adapt \ + -- \ uv run python open_instruct/grpo_fast.py \ --dataset_mixer_list ai2-adapt-dev/rlvr_gsm8k_zs 64 \ --dataset_mixer_list_splits train \ --dataset_mixer_eval_list ai2-adapt-dev/rlvr_gsm8k_zs 16 \ --dataset_mixer_eval_list_splits train \ --max_prompt_token_length 512 \ - --response_length 512 \ - --pack_length 1024 \ + --response_length 2048 \ + --pack_length 4096 \ + --async_steps 4 \ + --inflight_updates \ + --active_fill_completions \ + --truncated_importance_sampling_ratio_cap 2.0 \ --per_device_train_batch_size 1 \ --num_unique_prompts_rollout 8 \ --num_samples_per_prompt_rollout 4 \ @@ -17,20 +37,16 @@ uv run python open_instruct/grpo_fast.py \ --ground_truths_key ground_truth \ --chat_template_name r1_simple_chat_postpend_think \ --learning_rate 3e-7 \ - --total_episodes 200 \ + --total_episodes 1600 \ --deepspeed_stage 2 \ --num_epochs 1 \ --num_learners_per_node 1 \ + --vllm_num_engines 1 \ --vllm_tensor_parallel_size 1 \ - --beta 0.01 \ + --beta 0. 
\ --seed 3 \ - --local_eval_every 1 \ - --vllm_sync_backend gloo \ - --vllm_gpu_memory_utilization 0.3 \ - --save_traces \ - --vllm_enforce_eager \ + --local_eval_every 100 \ --gradient_checkpointing \ - --single_gpu_mode \ --push_to_hub false \ --system_prompt_override_file scripts/train/debug/cute_debug_system_prompt.txt \ # --with_tracking diff --git a/scripts/train/rlvr/grpo_rlzero.sh b/scripts/train/rlvr/grpo_rlzero.sh new file mode 100755 index 000000000..680bcb658 --- /dev/null +++ b/scripts/train/rlvr/grpo_rlzero.sh @@ -0,0 +1,90 @@ +#!/bin/bash + +# OLMo 3 model +MODEL_NAME_OR_PATH="/weka/oe-adapt-default/michaeln/checkpoints/olmo3-7b-base" +GS_MODEL_NAME="olmo3_7b_base" + +# english only DAPO +# DATASETS="mnoukhov/DAPO-Math-14k-Processed-RLVR 1.0 TTTXXX01/MATH_3000_Filtered 1.0" +DATASETS="saurabh5/DAPO-Math-17k-Processed_filtered_olmo_completions_new_template_filtered 1.0 saurabh5/MATH_3000_Filtered_olmo_completions_new_template_filtered 1.0" +# DATASETS="mnoukhov/deepscaler_20k_medhard_nolatex_rlvr 1.0" +# DATASETS="" + +# math evals +# EVALS="minerva_math_500::hamish_zs_reasoning_deepseek" +EVALS="aime:zs_cot_r1::pass_at_32_2024_dapo,aime:zs_cot_r1::pass_at_32_2025_dapo" + +# AIME 2024, 2025 local evals +LOCAL_EVALS="mnoukhov/aime2024-25-rlvr 1.0 mnoukhov/aime2024-25-rlvr 1.0" +LOCAL_EVAL_SPLITS="test_2024 test_2024 test_2025 test_2025" +# tengmath3k +# EXP_NAME="grpo_deepscaler20k_k8_${GS_MODEL_NAME}" +EXP_NAME="grpo_17kfilter_${GS_MODEL_NAME}" + +cluster=ai2/jupiter + +python mason.py \ + --task_name ${EXP_NAME} \ + --cluster ${cluster} \ + --workspace ai2/olmo-instruct \ + --priority urgent \ + --pure_docker_mode \ + --image michaeln/open_instruct_rlzero \ + --preemptible \ + --num_nodes 9 \ + --env VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \ + --env VLLM_ATTENTION_BACKEND="FLASH_ATTN" \ + --gs_model_name $GS_MODEL_NAME \ + --gpus 8 \ + --budget ai2/oe-adapt \ + -- \ +source configs/beaker_configs/ray_node_setup.sh \&\& \ +source configs/beaker_configs/code_api_setup.sh \&\& \ +python open_instruct/grpo_fast.py \ + --exp_name ${EXP_NAME} \ + --beta 0.0 \ + --async_steps 4 \ + --inflight_updates \ + --truncated_importance_sampling_ratio_cap 2.0 \ + --advantage_normalization_type centered \ + --active_fill_completions \ + --no_resample_solve_rate 0.9 \ + --num_samples_per_prompt_rollout 16 \ + --num_unique_prompts_rollout 16 \ + --num_mini_batches 1 \ + --learning_rate 1e-6 \ + --per_device_train_batch_size 1 \ + --kl_estimator kl3 \ + --dataset_mixer_list $DATASETS \ + --dataset_mixer_list_splits train \ + --dataset_mixer_eval_list $LOCAL_EVALS \ + --dataset_mixer_eval_list_splits $LOCAL_EVAL_SPLITS \ + --max_prompt_token_length 2048 \ + --response_length 12000 \ + --pack_length 32768 \ + --model_name_or_path ${MODEL_NAME_OR_PATH} \ + --chat_template_name olmo_thinker_dapo \ + --non_stop_penalty False \ + --temperature 1.0 \ + --total_episodes 512000 \ + --deepspeed_stage 3 \ + --num_learners_per_node 8 \ + --vllm_num_engines 64 \ + --vllm_tensor_parallel_size 1 \ + --lr_scheduler_type constant \ + --apply_verifiable_reward true \ + --seed 1 \ + --local_eval_every 100 \ + --save_freq 100 \ + --checkpoint_state_freq 100 \ + --gradient_checkpointing \ + --with_tracking \ + --vllm_enable_prefix_caching \ + --clip_higher 0.272 \ + --mask_truncated_completions True \ + --oe_eval_max_length 32768 \ + --try_launch_beaker_eval_jobs_on_weka True \ + --eval_priority high \ + --oe_eval_tasks $EVALS \ + --oe_eval_gpu_multiplier 4 \ + --oe_eval_beaker_image michaeln/oe_eval_rlzero diff --git 
a/uv.lock b/uv.lock index 5b839989f..a6c81c00b 100644 --- a/uv.lock +++ b/uv.lock @@ -1804,7 +1804,7 @@ requires-dist = [ { name = "bitsandbytes", marker = "sys_platform != 'darwin'", specifier = ">=0.44.1" }, { name = "datasets", specifier = ">=4.0.0" }, { name = "debugpy", specifier = ">=1.8.13" }, - { name = "deepspeed", specifier = "<0.17.6" }, + { name = "deepspeed", specifier = "<0.17.3" }, { name = "fastapi", marker = "extra == 'code'", specifier = ">=0.100.0" }, { name = "flash-attn", marker = "sys_platform != 'darwin'", specifier = ">=2.8.3" }, { name = "hf-transfer", specifier = ">=0.1.8" },