diff --git a/.gitignore b/.gitignore
index 2a9250276..f0d64d70f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -159,3 +159,4 @@ dmypy.json
 cache/
 local_dataset_cache/
 scratch/
+vllm_olmo2.5/
diff --git a/Dockerfile b/Dockerfile
index 2f6dae3c4..098fa0b95 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -78,7 +78,7 @@ COPY configs configs
 COPY scripts scripts
 COPY mason.py mason.py
 # Copy oe-eval-internal if it exists (wildcard pattern won't fail if missing)
-COPY oe-eval-interna[l] oe-eval-internal/
+COPY oe-eval-internal oe-eval-internal
 COPY open_instruct open_instruct
 
 # Add build arguments for git information
diff --git a/Makefile b/Makefile
index 55ce7d63e..17f2ffa37 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,4 @@
-.PHONY: style quality
+.PHONY: style quality docker
 
 # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 export PYTHONPATH = open_instruct
@@ -16,3 +16,13 @@ style-check:   ## *fail* if anything needs rewriting
 
 quality-check: ## *fail* if any rewrite was needed
 	uv run ruff check --exit-non-zero-on-fix $(check_dirs)
+
+setup:
+	git clone -b shanea/olmo2-retrofit https://github.com/2015aroras/vllm.git vllm_olmo2.5
+
+docker:
+	DOCKER_BUILDKIT=1 docker build -f Dockerfile --build-arg UV_CACHE_DIR=$(UV_CACHE_DIR) -t open_instruct_rlzero .
+	# if you are internally at AI2, you can create an image like this:
+	$(eval beaker_user := $(shell beaker account whoami --format json | jq -r '.[0].name'))
+	beaker image delete $(beaker_user)/open_instruct_rlzero
+	beaker image create open_instruct_rlzero -n open_instruct_rlzero -w ai2/$(beaker_user)
diff --git a/generate_olmo25.sh b/generate_olmo25.sh
new file mode 100755
index 000000000..ea80ed685
--- /dev/null
+++ b/generate_olmo25.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+MODEL_NAME_OR_PATH="/weka/oe-training-default/ai2-llm/checkpoints/tylerr/long-context/olmo25_7b_lc_64k_6T_M100B_round5-sparkle_6634-pre_s2pdf_gzip2080_cweN-yake-all-olmo_packing_yarn-fullonly_50B-fb13a737/step11921-hf"
+# DATASET="mnoukhov/DAPO-Math-14k-Processed-RLVR"
+DATASET="TTTXXX01/MATH_3000_Filtered"
+EXP_NAME="generate_olmo25_teng3k"
+
+python mason.py \
+    --task_name ${EXP_NAME} \
+    --cluster ai2/jupiter \
+    --image ${1:-michaeln/open_instruct_olmo2_retrofit} \
+    --workspace ai2/tulu-thinker \
+    --priority high \
+    --pure_docker_mode \
+    --preemptible \
+    --gpus 2 \
+    --num_nodes 1 \
+    --max_retries 0 \
+    --budget ai2/oe-adapt \
+    --env VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \
+    --env VLLM_ATTENTION_BACKEND="FLASH_ATTN" \
+    -- \
+python scripts/data/rlvr/filtering_vllm.py \
+    --model $MODEL_NAME_OR_PATH \
+    --dataset $DATASET \
+    --split train \
+    --temperature 0.7 \
+    --top_p 0.95 \
+    --offset 0 \
+    --size 100000 \
+    --chat_template olmo_thinker_r1_style_nochat \
+    --output-file filtered_datasets/olmo25_7b_lc_dapo.jsonl \
+    --number_samples 16
diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py
index 52a414042..56188e255 100644
--- a/open_instruct/dataset_transformation.py
+++ b/open_instruct/dataset_transformation.py
@@ -442,6 +442,33 @@ def visualize_token_role(tokens: list[int], masks: list[int], tokenizer: PreTrai
         "{% endif %}"
         "{% endfor %}"
     ),
+    "olmo_thinker_r1_style_nochat": (
+        "Solve the following math problem step by step. "
+        "Reason about the question in   tags "
+        "then provide the final answer in   tags "
+        "so the full response is  reasoning process here  "
+        " answer here ."
+        "\n\n"
+        "{% for message in messages %}"
+        "{{ '\n\n' if not loop.first else '' }}"
+        "{{ message['content'] + '\n' }}"
+        "{% if loop.last and add_generation_prompt %}"
+        "{{ 'Solving step by step\n' }}"
+        "{% endif %}"
+        "{% endfor %}"
+    ),
+    "olmo_thinker_dapo": (
+        "Solve the following math problem step by step. "
+        "The last line of your response should be the answer to the problem in form Answer: $Answer (without quotes) where $Answer is the answer to the problem."
+        "\n\n"
+        "{% for message in messages %}"
+        "{{ '\n\n' if not loop.first else '' }}"
+        "{{ message['content'] + '\n' }}"
+        "{% if loop.last and add_generation_prompt %}"
+        "{{ '\nRemember to put your answer on its own line after \"Answer:\"' }}"
+        "{% endif %}"
+        "{% endfor %}"
+    ),
     # template is taken from https://arxiv.org/abs/2501.12948.
     "r1_simple_chat": (
         "A conversation between User and Assistant. "
diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py
index 7a816ad68..e9c7ac758 100644
--- a/open_instruct/grpo_fast.py
+++ b/open_instruct/grpo_fast.py
@@ -282,6 +282,13 @@ class Args:
     use_vllm_logprobs: bool = False
     """whether to use vLLM's logprobs for training instead of calculating them via forward pass"""
 
+    active_fill_completions: bool = False
+    """Whether to refill the batch with *new prompts/completions* after filtering."""
+    active_fill_max_attempts: int = 3
+    """How many times to attempt to fill"""
+    no_resample_solve_rate: float = None
+    """If a prompt is solved at more than this rate across K completions, don't resample it for next epoch"""
+
     # Reward
     # -- r1 style format reward
     apply_r1_style_format_reward: bool = False
@@ -411,6 +418,8 @@ class Args:
     """the max generation length for evaluation for oe-eval"""
     oe_eval_beaker_image: Optional[str] = None
     """the docker image for evaluation for oe-eval"""
+    oe_eval_gpu_multiplier: Optional[int] = 1
+    """gpu mulitplier for eval jobs"""
     eval_priority: Literal["low", "normal", "high", "urgent"] = "normal"
     """the priority of auto-launched evaluation jobs"""
 
@@ -574,16 +583,21 @@ def __init__(self, data: np.ndarray, batch_size: int, seed: Optional[int] = None
         self.epoch_number = 0
         self.rng = np.random.default_rng(seed)
         self.rng.shuffle(self.data)
+        self.exclude_list = []
 
         # Ensure the effective dataset size is divisible by batch_size
-        self.effective_size = len(self.data) - (len(self.data) % batch_size)
+        self._update_effective_size()
 
     def __iter__(self) -> Iterator[List[int]]:
         return self
 
     def __next__(self) -> List[int]:
+        if self.effective_size == 0:
+            raise StopIteration("No data available to sample from the iterator.")
+
         if self.index >= self.effective_size:
             self.index = 0
+            self._update_effective_size()
             self.epoch_number += 1
             self.rng.shuffle(self.data)
 
@@ -593,6 +607,10 @@ def __next__(self) -> List[int]:
 
         return batch
 
+    def exclude_indices(self, exclude_list: List) -> None:
+        """Exclude provided data points from future sampling."""
+        self.exclude_list.extend(exclude_list)
+
     def get_state(self) -> Dict[str, Any]:
         """Get the current state of the iterator for checkpointing."""
         return {
@@ -600,6 +618,7 @@ def get_state(self) -> Dict[str, Any]:
             "epoch_number": self.epoch_number,
             "data": self.data.copy(),
             "rng_state": self.rng.bit_generator.state,
+            "exclude_list": self.exclude_list.copy(),
         }
 
     def set_state(self, state: Dict[str, Any]) -> None:
@@ -608,6 +627,18 @@ def set_state(self, state: Dict[str, Any]) -> None:
         self.epoch_number = state["epoch_number"]
         self.data = state["data"].copy()
         self.rng.bit_generator.state = state["rng_state"]
+        self.exclude_list = state.get("exclude_list", [])
+        self._update_effective_size()
+
+    def _update_effective_size(self) -> None:
+        if self.exclude_list:
+            mask = ~np.isin(self.data, self.exclude_list)
+            if mask.all():
+                return
+
+            self.data = self.data[mask]
+            self.exclude_list = []
+        self.effective_size = len(self.data) - (len(self.data) % self.batch_size)
 
 
 @ray.remote(num_gpus=1)
@@ -1356,6 +1387,7 @@ def launch_ai2_evals_on_weka_wrapper(self, step_dir, leaderboard_name, wandb_url
                 args.gs_bucket_path,
                 args.eval_priority,
                 args.oe_eval_beaker_image,
+                args.oe_eval_gpu_multiplier,
             )
 
 
@@ -1579,6 +1611,7 @@ def accumulate_inference_batches(
     all_ground_truths = []
     all_datasets = []
     all_raw_queries = []
+    all_indices = []
     for i in tqdm(
         range(num_prompts),
         total=num_prompts,
@@ -1605,6 +1638,7 @@ def accumulate_inference_batches(
         all_ground_truths.append(ground_truth)
         all_datasets.append(dataset)
         all_raw_queries.append(raw_query)
+        all_indices.append(result.dataset_index)
 
     # Combine all results into a single GenerationResult
     combined_responses = []
@@ -1686,13 +1720,13 @@ def accumulate_inference_batches(
     if actor_manager is not None:
         ray.get(actor_manager.report_token_statistics.remote(accumulated_stats))
 
-    # Note: We don't have dataset_indices here, but they're not needed for the returned batch
+    # Preserve dataset indices so downstream filtering keeps track of original prompts
     batch = Batch(
         queries=all_queries,
         ground_truths=all_ground_truths,
         datasets=all_datasets,
         raw_queries=all_raw_queries,
-        indices=None,  # Not meaningful for combined results
+        indices=all_indices,
     )
     return combined_result, batch, prompt_lengths, response_lengths
 
@@ -1709,181 +1743,351 @@ def data_preparation_thread(
     resume_training_step: int,
     actor_manager=None,
     model_dims: utils.ModelDims = None,
+    train_iterator: ShufflingIterator = None,
 ):
-    for training_step in range(resume_training_step, num_training_steps + 1):
-        # Streaming accumulation: collect results as they arrive
-        with Timer("🚀 [Data Preparation Thread] Getting response ids") as timer:
-            result, batch, prompt_lengths, response_lengths = accumulate_inference_batches(
-                inference_results_Q,
-                pending_queries_map,
-                args,
-                generation_config,
-                num_prompts=args.num_unique_prompts_rollout,
-                model_dims=model_dims,
-                actor_manager=actor_manager,
-            )
-            if isinstance(result, ShutdownSentinel):
-                logger.info("[Data Preparation Thread] Received shutdown sentinel, exiting")
-                return
-
-        getting_response_time = timer.duration
-
-        # ------------------------------------------------------------------------------------------------
-        # Pack sequences
-        if args.num_samples_per_prompt_rollout > 1:
-            batch = Batch(
-                queries=repeat_each(batch.queries, args.num_samples_per_prompt_rollout),
-                ground_truths=repeat_each(batch.ground_truths, args.num_samples_per_prompt_rollout),
-                datasets=repeat_each(batch.datasets, args.num_samples_per_prompt_rollout),
-                raw_queries=repeat_each(batch.raw_queries, args.num_samples_per_prompt_rollout),
-                indices=repeat_each(batch.indices, args.num_samples_per_prompt_rollout) if batch.indices else None,
-            )
-            good_outputs = [
-                len(result.request_info.tool_outputs[i]) > 0
-                and result.request_info.tool_calleds[i]
-                and not result.request_info.timeouts[i]
-                and not result.request_info.tool_errors[i]
-                for i in range(len(result.request_info.tool_outputs))
-            ]
-        for i in range(len(result.finish_reasons)):
-            if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0:
-                result.responses[i].append(tokenizer.eos_token_id)
-                result.masks[i].append(1)
-                result.logprobs[i].append(float("nan"))
-        with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
-            decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True)
-            decoded_queries = batch.raw_queries
-            stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len(
-                result.finish_reasons
-            )
+    def combine_reward_metrics(metric_records: list[tuple[Dict[str, Any], int]]) -> Dict[str, Any]:
+        if not metric_records:
+            return {}
+
+        buckets: Dict[str, list[tuple[Any, int]]] = {}
+        for metrics, weight in metric_records:
+            if not metrics:
+                continue
+            for key, value in metrics.items():
+                buckets.setdefault(key, []).append((value, weight))
+
+        combined: Dict[str, Any] = {}
+        for key, records in buckets.items():
+            sample_value = records[0][0]
+            if isinstance(sample_value, np.ndarray):
+                combined[key] = np.concatenate([np.asarray(value) for value, _ in records])
+            elif isinstance(sample_value, (list, tuple)):
+                concatenated: list[Any] = []
+                for value, _ in records:
+                    concatenated.extend(list(value))
+                combined[key] = concatenated
+            elif isinstance(sample_value, (int, float, bool, np.integer, np.floating)):
+                total_weight = sum(weight for _, weight in records)
+                if total_weight == 0:
+                    combined[key] = float(sample_value)
+                else:
+                    combined[key] = sum(float(value) * weight for value, weight in records) / total_weight
+            else:
+                # Fallback: keep the latest value if aggregation strategy is unclear.
+                combined[key] = records[-1][0]
+        return combined
 
-        with Timer("💰 [Data Preparation Thread] Calculating rewards and advantages"):
-            scores, reward_metrics = asyncio.run(
-                reward_fn(
-                    result.responses,
-                    decoded_responses,
-                    batch,
-                    result.finish_reasons,
-                    result.request_info,
-                    decoded_queries,
+    for training_step in range(resume_training_step, num_training_steps + 1):
+        per_prompt_rollout = args.num_samples_per_prompt_rollout
+        target_prompt_count = args.num_unique_prompts_rollout
+
+        aggregated_scores: list[np.ndarray] = []
+        aggregated_advantages: list[np.ndarray] = []
+        aggregated_reward_metrics: list[tuple[Dict[str, Any], int]] = []
+        aggregated_responses: list[List[int]] = []
+        aggregated_masks: list[List[int]] = []
+        aggregated_finish_reasons: list[str] = []
+        aggregated_vllm_logprobs: list[List[float]] = []
+        aggregated_batches: list[Batch] = []
+        aggregated_prompt_lengths: list[int] = []
+        aggregated_response_lengths: list[int] = []
+        aggregated_num_calls: list[int] = []
+        aggregated_timeouts: list[int] = []
+        aggregated_tool_errors: list[str] = []
+        aggregated_tool_outputs: list[str] = []
+        aggregated_tool_runtimes: list[float] = []
+        aggregated_tool_calleds: list[bool] = []
+
+        total_prompt_tokens = 0
+        total_response_tokens = 0
+        total_generation_time = 0.0
+        earliest_start_time: Optional[float] = None
+
+        total_stop_completions = 0
+        total_completions = 0
+
+        total_prompts_kept = 0
+        total_prompts_filtered = 0
+        solved_prompts_filtered = 0
+        zero_prompts_filtered = 0
+        noresample_prompts_filtered = 0
+        prompts_to_request = target_prompt_count
+        getting_response_time = 0.0
+
+        fill_iteration = 0
+        while prompts_to_request > 0:
+            fill_iteration += 1
+            # Streaming accumulation: collect results as they arrive
+            with Timer(f"🚀 [Data Preparation Thread] Getting {prompts_to_request} response ids") as timer:
+                result_step, batch_step, prompt_lengths_step, response_lengths_step = accumulate_inference_batches(
+                    inference_results_Q,
+                    pending_queries_map,
+                    args,
+                    generation_config,
+                    num_prompts=prompts_to_request,
+                    model_dims=model_dims,
+                    actor_manager=actor_manager,
                 )
-            )
-            scores = np.array(scores)
-            scores_per_prompt = scores.reshape(-1, args.num_samples_per_prompt_rollout)
-            mean_grouped_rewards = scores_per_prompt.mean(axis=-1)
-            mean_grouped_rewards = np.repeat(mean_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0)
-            std_grouped_rewards = scores_per_prompt.std(axis=-1)
-            std_grouped_rewards = np.repeat(std_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0)
-            if args.advantage_normalization_type == "standard":
-                advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8)
-            elif args.advantage_normalization_type == "centered":
-                advantages = scores - mean_grouped_rewards
-            else:
-                raise ValueError(f"Invalid advantage normalization type: {args.advantage_normalization_type}")
-
-        with Timer("📦 [Data Preparation Thread] Filtering sequences"):
-            # Here we get the max possible score for each prompt, and see how many prompts are unsolved
-            max_possible_score = 0
-            if args.apply_verifiable_reward:
-                max_possible_score += args.verification_reward
-            if args.apply_r1_style_format_reward and args.additive_format_reward:
-                max_possible_score += args.r1_style_format_reward
-            unsolved_batch_size_ratio = ((scores != max_possible_score) > 0).sum() / len(scores)
-            # In GRPO, if the std of grouped rewards is 0, then there is zero gradient for the batch
-            # of args.num_samples_per_prompt_rollout responses, so we need to filter out those batches
-            non_zero_std_mask = scores_per_prompt.std(axis=-1) != 0
-            real_batch_size_ratio = non_zero_std_mask.sum() * args.num_samples_per_prompt_rollout / len(scores)
-            expanded_mask = np.repeat(non_zero_std_mask, args.num_samples_per_prompt_rollout)
-            non_zero_gradient_index = np.where(expanded_mask)[0]
-
-            # Log zero-gradient filtering statistics
-            num_zero_std_prompts = (~non_zero_std_mask).sum()
-            num_filtered_responses = len(scores) - len(non_zero_gradient_index)
-            if num_filtered_responses > 0:
-                logger.info(
-                    f"[Zero-gradient filtering] Filtered {num_zero_std_prompts} prompts with zero std "
-                    f"({num_filtered_responses} responses). Retention rate: {len(non_zero_gradient_index) / len(scores):.2%}"
+                if isinstance(result_step, ShutdownSentinel):
+                    logger.info("[Data Preparation Thread] Received shutdown sentinel, exiting")
+                    return
+
+            getting_response_time += timer.duration
+
+            if args.num_samples_per_prompt_rollout > 1:
+                batch_step = Batch(
+                    queries=repeat_each(batch_step.queries, per_prompt_rollout),
+                    ground_truths=repeat_each(batch_step.ground_truths, per_prompt_rollout),
+                    datasets=repeat_each(batch_step.datasets, per_prompt_rollout),
+                    raw_queries=repeat_each(batch_step.raw_queries, per_prompt_rollout),
+                    indices=repeat_each(batch_step.indices, per_prompt_rollout) if batch_step.indices else None,
                 )
 
-            advantages = advantages[non_zero_gradient_index]
-            original_batch_size = len(scores)
-            scores = scores[non_zero_gradient_index]
-            responses = [result.responses[i] for i in non_zero_gradient_index]
-            masks = [result.masks[i] for i in non_zero_gradient_index]
-            batch = batch[non_zero_gradient_index.tolist()]
-            finish_reasons = [result.finish_reasons[i] for i in non_zero_gradient_index]
-            vllm_logprobs = [result.logprobs[i] for i in non_zero_gradient_index]
-            if args.mask_truncated_completions:
-                stop_idxes = torch.tensor([i for i in range(len(finish_reasons)) if finish_reasons[i] == "stop"])
-                num_truncated = len(finish_reasons) - len(stop_idxes)
-                if num_truncated > 0:
+            for i in range(len(result_step.finish_reasons)):
+                if result_step.finish_reasons[i] == "stop" and len(result_step.responses[i]) == 0:
+                    result_step.responses[i].append(tokenizer.eos_token_id)
+                    result_step.masks[i].append(1)
+                    result_step.logprobs[i].append(float("nan"))
+
+            with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
+                decoded_responses_step = tokenizer.batch_decode(result_step.responses, skip_special_tokens=True)
+                decoded_queries_step = batch_step.raw_queries
+
+            with Timer("💰 [Data Preparation Thread] Calculating rewards and advantages"):
+                scores_step, reward_metrics_step = asyncio.run(
+                    reward_fn(
+                        result_step.responses,
+                        decoded_responses_step,
+                        batch_step,
+                        result_step.finish_reasons,
+                        result_step.request_info,
+                        decoded_queries_step,
+                    )
+                )
+                scores_step = np.array(scores_step)
+                scores_per_prompt_step = scores_step.reshape(-1, per_prompt_rollout)
+                mean_grouped_rewards = scores_per_prompt_step.mean(axis=-1)
+                mean_grouped_rewards = np.repeat(mean_grouped_rewards, per_prompt_rollout, axis=0)
+                std_grouped_rewards = scores_per_prompt_step.std(axis=-1)
+                std_grouped_rewards = np.repeat(std_grouped_rewards, per_prompt_rollout, axis=0)
+                if args.advantage_normalization_type == "standard":
+                    advantages_step = (scores_step - mean_grouped_rewards) / (std_grouped_rewards + 1e-8)
+                elif args.advantage_normalization_type == "centered":
+                    advantages_step = scores_step - mean_grouped_rewards
+                else:
+                    raise ValueError(f"Invalid advantage normalization type: {args.advantage_normalization_type}")
+
+            # Filtering (zero-gradient, truncated completions)
+            with Timer("📦 [Data Preparation Thread] Filtering sequences"):
+                max_possible_score = 0
+                if args.apply_verifiable_reward:
+                    max_possible_score += args.verification_reward
+                if args.apply_r1_style_format_reward and args.additive_format_reward:
+                    max_possible_score += args.r1_style_format_reward
+
+                if args.no_resample_solve_rate is not None:
+                    percent_solved_prompt = scores_per_prompt_step.mean(axis=-1) / max_possible_score
+                    solved_mask = percent_solved_prompt >= args.no_resample_solve_rate
+                    solved_expanded_mask = np.repeat(solved_mask, per_prompt_rollout)
+                    solved_indices = np.where(solved_expanded_mask)[0]
+                    solved_prompt_indices = batch_step[solved_indices.tolist()].indices
+                    train_iterator.exclude_indices(list(set(solved_prompt_indices)))
+                    noresample_prompts_filtered += solved_mask.sum()
+
+                non_zero_std_mask = scores_per_prompt_step.std(axis=-1) != 0
+                expanded_mask = np.repeat(non_zero_std_mask, per_prompt_rollout)
+                non_zero_indices = np.where(expanded_mask)[0]
+
+                num_zero_std_prompts = (~non_zero_std_mask).sum()
+                num_filtered_responses = len(scores_step) - len(non_zero_indices)
+
+                # Count groups with all zero rewards
+                zero_prompts_filtered += (scores_per_prompt_step == 0).all(axis=-1).sum()
+                solved_prompts_filtered += (scores_per_prompt_step == max_possible_score).all(axis=-1).sum()
+                total_prompts_filtered += num_zero_std_prompts
+
+                if num_filtered_responses > 0:
                     logger.info(
-                        f"[Truncated completions filtering] Filtered {num_truncated} responses that didn't finish with 'stop'. "
-                        f"Retention rate: {len(stop_idxes) / len(finish_reasons):.2%}"
+                        f"[Zero-gradient filtering] Filtered {num_zero_std_prompts} prompts with zero std "
+                        f"({num_filtered_responses} responses). Retention rate: {len(non_zero_indices) / len(scores_step):.2%}"
                     )
-                scores = scores[stop_idxes]
-                advantages = advantages[stop_idxes]
-                responses = [responses[i] for i in stop_idxes]
-                masks = [masks[i] for i in stop_idxes]
-                batch = batch[stop_idxes.tolist()]
-                finish_reasons = [finish_reasons[i] for i in stop_idxes]
-                vllm_logprobs = [vllm_logprobs[i] for i in stop_idxes]
-
-            if args.fill_completions:
-                with Timer("⏱ [Data Preparation Thread] Refill completions"):
-                    current_batch_size = len(scores)
-                    original_prompt_cnt = original_batch_size // args.num_samples_per_prompt_rollout
-                    current_prompt_cnt = current_batch_size // args.num_samples_per_prompt_rollout
-                    need_to_fill_prompt = original_prompt_cnt - current_prompt_cnt
-                    k = args.num_samples_per_prompt_rollout
-
-                    if need_to_fill_prompt > 0 and current_prompt_cnt > 0:
-                        scores_matrix = scores.reshape(current_prompt_cnt, k)
-                        stds = scores_matrix.std(axis=1) + 1e-8
-                        probs = stds / stds.sum()
 
+                advantages_step = advantages_step[non_zero_indices]
+                scores_step = scores_step[non_zero_indices]
+                responses_step = [result_step.responses[i] for i in non_zero_indices]
+                masks_step = [result_step.masks[i] for i in non_zero_indices]
+                batch_step = batch_step[non_zero_indices.tolist()]
+                finish_reasons_step = [result_step.finish_reasons[i] for i in non_zero_indices]
+                vllm_logprobs_step = [result_step.logprobs[i] for i in non_zero_indices]
+
+                prompt_indices_kept = np.where(non_zero_std_mask)[0]
+                prompt_lengths_kept = [prompt_lengths_step[i] for i in prompt_indices_kept]
+                response_lengths_kept = [response_lengths_step[i] for i in non_zero_indices]
+
+                num_calls_kept = [result_step.request_info.num_calls[i] for i in non_zero_indices]
+                timeouts_kept = [result_step.request_info.timeouts[i] for i in non_zero_indices]
+                tool_errors_kept = [result_step.request_info.tool_errors[i] for i in non_zero_indices]
+                tool_outputs_kept = [result_step.request_info.tool_outputs[i] for i in non_zero_indices]
+                tool_runtimes_kept = [result_step.request_info.tool_runtimes[i] for i in non_zero_indices]
+                tool_calleds_kept = [result_step.request_info.tool_calleds[i] for i in non_zero_indices]
+
+                if args.mask_truncated_completions:
+                    truncated_mask = np.array([reason != "stop" for reason in finish_reasons_step], dtype=bool)
+                    num_truncated = int(truncated_mask.sum())
+                    if num_truncated > 0:
+                        scores_step = scores_step.copy()
+                        scores_step[truncated_mask] = 0.0
+                        advantages_step = advantages_step.copy()
+                        advantages_step[truncated_mask] = 0.0
+                        for idx in np.where(truncated_mask)[0]:
+                            masks_step[idx] = [0 for _ in masks_step[idx]]
                         logger.info(
-                            f"[Refill completions] Need to fill {need_to_fill_prompt} prompts to maintain batch size. "
-                            f"Original: {original_prompt_cnt}, Current: {current_prompt_cnt}"
+                            f"[Truncated completions masking] Zeroed scores/advantages and response masks for {num_truncated} responses that didn't finish with 'stop'."
                         )
-
-                        sampled_prompt_ids = np.random.choice(
-                            current_prompt_cnt, size=need_to_fill_prompt, replace=True, p=probs
+            aggregated_scores.append(scores_step)
+            aggregated_advantages.append(advantages_step)
+            aggregated_reward_metrics.append((reward_metrics_step, len(scores_step)))
+            aggregated_responses.extend(responses_step)
+            aggregated_masks.extend(masks_step)
+            aggregated_finish_reasons.extend(finish_reasons_step)
+            aggregated_vllm_logprobs.extend(vllm_logprobs_step)
+            aggregated_batches.append(batch_step)
+            aggregated_prompt_lengths.extend(prompt_lengths_kept)
+            aggregated_response_lengths.extend(response_lengths_kept)
+            aggregated_num_calls.extend(num_calls_kept)
+            aggregated_timeouts.extend(timeouts_kept)
+            aggregated_tool_errors.extend(tool_errors_kept)
+            aggregated_tool_outputs.extend(tool_outputs_kept)
+            aggregated_tool_runtimes.extend(tool_runtimes_kept)
+            aggregated_tool_calleds.extend(tool_calleds_kept)
+
+            total_stop_completions += sum(int(reason == "stop") for reason in finish_reasons_step)
+            total_completions += len(finish_reasons_step)
+
+            if result_step.token_statistics is not None:
+                total_prompt_tokens += result_step.token_statistics.num_prompt_tokens
+                total_response_tokens += result_step.token_statistics.num_response_tokens
+                total_generation_time += result_step.token_statistics.generation_time
+                if result_step.token_statistics.earliest_start_time is not None:
+                    if earliest_start_time is None:
+                        earliest_start_time = result_step.token_statistics.earliest_start_time
+                    else:
+                        earliest_start_time = min(
+                            earliest_start_time, result_step.token_statistics.earliest_start_time
                         )
 
-                        sampled_indices = []
-                        for pid in sampled_prompt_ids:
-                            start = pid * k
-                            sampled_indices.extend(range(start, start + k))
-
-                        advantages = np.concatenate([advantages, advantages[sampled_indices]])
-                        scores = np.concatenate([scores, scores[sampled_indices]])
-                        responses += [responses[i] for i in sampled_indices]
-                        masks += [masks[i] for i in sampled_indices]
+            prompts_kept_this_iter = len(scores_step) // per_prompt_rollout if per_prompt_rollout > 0 else 0
+            total_prompts_kept += prompts_kept_this_iter
 
-                        sampled_batch = batch[sampled_indices]
+            if not args.active_fill_completions:
+                break
 
-                        batch = Batch(
-                            queries=batch.queries + sampled_batch.queries,
-                            ground_truths=batch.ground_truths + sampled_batch.ground_truths,
-                            datasets=batch.datasets + sampled_batch.datasets,
-                            indices=batch.indices + sampled_batch.indices if batch.indices is not None else None,
-                        )
+            prompts_remaining = target_prompt_count - total_prompts_kept
+            if prompts_remaining <= 0:
+                break
 
-                        finish_reasons += [finish_reasons[i] for i in sampled_indices]
-                        vllm_logprobs += [vllm_logprobs[i] for i in sampled_indices]
+            if prompts_kept_this_iter == 0 and fill_iteration > args.active_fill_max_attempts:
+                logger.warning(
+                    "[Active fill completions] Unable to collect non-zero advantage prompts in iteration %d; "
+                    "set as max by args.active_fill_max_attempts, proceeding with existing batch of size %d.",
+                    fill_iteration,
+                    len(aggregated_responses),
+                )
+                break
 
-                        logger.info(
-                            f"📊 Duplicated {need_to_fill_prompt} prompts from {len(sampled_indices)} total responses"
-                        )
+            prompts_to_request = prompts_remaining
+
+        # Build aggregated containers from collected data
+        if aggregated_batches:
+            queries: list[List[int]] = []
+            ground_truths: list[List[int]] = []
+            datasets: list[str] = []
+            raw_queries_list: Optional[list[str]] = None
+            indices_list: Optional[list[int]] = None
+
+            raw_queries_present = aggregated_batches[0].raw_queries is not None
+            indices_present = aggregated_batches[0].indices is not None
+            if raw_queries_present:
+                raw_queries_list = []
+            if indices_present:
+                indices_list = []
+
+            for batch_step in aggregated_batches:
+                queries.extend(batch_step.queries)
+                ground_truths.extend(batch_step.ground_truths)
+                datasets.extend(batch_step.datasets)
+                if raw_queries_present:
+                    assert batch_step.raw_queries is not None
+                    raw_queries_list.extend(batch_step.raw_queries)
+                if indices_present:
+                    assert batch_step.indices is not None
+                    indices_list.extend(batch_step.indices)
 
-            # Count groups with all zero rewards
-            all_zero_groups = (scores_per_prompt == 0).all(axis=-1).sum()
-            total_groups = len(scores_per_prompt)
-            logger.info(
-                f"[Reward Summary] Groups with all zero rewards: {all_zero_groups}/{total_groups} "
-                f"({all_zero_groups / total_groups:.1%})"
+            batch = Batch(
+                queries=queries,
+                ground_truths=ground_truths,
+                datasets=datasets,
+                raw_queries=raw_queries_list if raw_queries_present else None,
+                indices=indices_list if indices_present else None,
             )
+        else:
+            batch = Batch(queries=[], ground_truths=[], datasets=[], raw_queries=[], indices=[])
+
+        prompt_lengths = aggregated_prompt_lengths
+        response_lengths = aggregated_response_lengths
+
+        token_statistics = TokenStatistics(
+            num_prompt_tokens=total_prompt_tokens,
+            num_response_tokens=total_response_tokens,
+            generation_time=total_generation_time,
+            earliest_start_time=earliest_start_time,
+        )
+        result = GenerationResult(
+            responses=aggregated_responses,
+            finish_reasons=aggregated_finish_reasons,
+            masks=aggregated_masks,
+            request_info=RequestInfo(
+                num_calls=aggregated_num_calls,
+                timeouts=aggregated_timeouts,
+                tool_errors=aggregated_tool_errors,
+                tool_outputs=aggregated_tool_outputs,
+                tool_runtimes=aggregated_tool_runtimes,
+                tool_calleds=aggregated_tool_calleds,
+            ),
+            token_statistics=token_statistics,
+            logprobs=aggregated_vllm_logprobs,
+        )
+
+        good_outputs = [
+            len(result.request_info.tool_outputs[i]) > 0
+            and result.request_info.tool_calleds[i]
+            and not result.request_info.timeouts[i]
+            and not result.request_info.tool_errors[i]
+            for i in range(len(result.request_info.tool_outputs))
+        ]
+        stop_rate = (total_stop_completions / total_completions) if total_completions > 0 else 0
+
+        scores = np.concatenate(aggregated_scores) if aggregated_scores else np.array([])
+        advantages = np.concatenate(aggregated_advantages) if aggregated_advantages else np.array([])
+        reward_metrics = combine_reward_metrics(aggregated_reward_metrics)
+
+        if len(scores) == 0:
+            logger.warning(f"No responses with non-zero advantages in batch {training_step}.")
+
+        max_possible_score = 0
+        if args.apply_verifiable_reward:
+            max_possible_score += args.verification_reward
+        if args.apply_r1_style_format_reward and args.additive_format_reward:
+            max_possible_score += args.r1_style_format_reward
+
+        unsolved_batch_size_ratio = ((scores != max_possible_score) > 0).sum() / len(scores)
+        original_batch_size = target_prompt_count * per_prompt_rollout
+        real_batch_size_ratio = len(scores) / original_batch_size if original_batch_size > 0 else 0
+
+        responses = aggregated_responses
+        masks = aggregated_masks
+        finish_reasons = aggregated_finish_reasons
+        vllm_logprobs = aggregated_vllm_logprobs
 
         with Timer("📦 [Data Preparation Thread] Packing sequences"):
             packed_sequences = pack_sequences(
@@ -2004,17 +2208,15 @@ def data_preparation_thread(
                 np.array([]) if np.all(scores == max_possible_score) else np.array(sequence_lengths[scores == 0])
             )
 
-            # Use the already calculated reward summary metrics for wandb
-            all_zero_groups_ratio = all_zero_groups / total_groups if total_groups > 0 else 0
-
             metrics = {
                 "scores": np.array(scores).mean(),
                 "real_batch_size_ratio": real_batch_size_ratio,
                 "unsolved_batch_size_ratio": unsolved_batch_size_ratio,
                 "packed_ratio": len(packed_sequences.query_responses) / len(responses) if len(responses) > 0 else 0,
-                "val/all_zero_reward_groups": all_zero_groups,
-                "val/all_zero_reward_groups_ratio": all_zero_groups_ratio,
-                "val/total_reward_groups": total_groups,
+                "val/zero_filtered_groups": zero_prompts_filtered,
+                "val/solved_filtered_groups": solved_prompts_filtered,
+                "val/total_filtered_groups": total_prompts_filtered,
+                "val/noresample_filtered_groups": noresample_prompts_filtered,
                 "val/sequence_lengths": sequence_lengths.mean(),
                 "val/sequence_lengths_min": sequence_lengths.min(),
                 "val/sequence_lengths_max": sequence_lengths.max(),
@@ -2473,7 +2675,7 @@ def one_training_step(
             )
 
     save_time = 0
-    if args.save_freq > 0 and training_step % args.save_freq == 0 and (args.eval_on_step_0 or training_step > 1):
+    if args.save_freq > 0 and (training_step % args.save_freq == 0 or (training_step == 1 and args.eval_on_step_0)):
         with Timer("[Main Thread] 🗡️ Saving model") as timer:
             checkpoint_dir = f"{args.output_dir}_checkpoints"
             step_dir = os.path.join(checkpoint_dir, f"step_{training_step}")
@@ -2883,6 +3085,7 @@ def run_training(
         resume_training_step,
         actor_manager,
         model_dims,
+        iter_dataloader,
     )
 
     def health_check_fn():
@@ -2912,6 +3115,7 @@ def health_check_fn():
     else:
         num_total_tokens = 0
 
+    filtered_prompts_count = 0
     training_start_time = time.perf_counter()  # Track overall training start time
     for training_step in range(resume_training_step, args.num_training_steps + 1):
         start_time = time.perf_counter()
@@ -2947,6 +3151,23 @@ def health_check_fn():
             generation_configs["train"],
             is_eval=False,
         )
+        i = 0.1
+        while filtered_prompts_count > args.num_unique_prompts_rollout:
+            another_batch = next_batch(next(iter_dataloader), train_dataset)
+            # little hack to make sure that we don't have the same training step, otherwise vllm requests can break
+            # putting both batches in the same training step might also break if we accidentally sample the same dataset index twice
+            split_and_insert_batch(
+                another_batch,
+                iter_dataloader.epoch_number,
+                training_step + i,
+                pending_queries_map,
+                param_prompt_Q,
+                generation_configs["train"],
+                is_eval=False,
+            )
+            filtered_prompts_count -= args.num_unique_prompts_rollout
+            i += 0.1
+
         if (
             training_step % args.local_eval_every == 0
             and eval_batch is not None
@@ -2962,12 +3183,14 @@ def health_check_fn():
                 is_eval=True,
             )
 
-        collated_data, data_thread_metrics, num_total_tokens, num_step_tokens, prompt_lengths, response_lengths = (
+        (collated_data, data_thread_metrics, num_total_tokens, num_step_tokens, prompt_lengths, response_lengths) = (
             load_data_from_packing_thread(packed_sequences_Q, num_total_tokens, stop_event, health_check_fn)
         )
         if collated_data is None:
             continue
 
+        filtered_prompts_count += data_thread_metrics["val/total_filtered_groups"]
+
         for metrics_Q in [generate_metrics_Q, weight_sync_metrics_Q]:
             try:
                 data_thread_metrics |= metrics_Q.get_nowait()
diff --git a/open_instruct/utils.py b/open_instruct/utils.py
index 38697c67e..d878170a1 100644
--- a/open_instruct/utils.py
+++ b/open_instruct/utils.py
@@ -1149,6 +1149,7 @@ def launch_ai2_evals_on_weka(
     gs_bucket_path: Optional[str] = None,
     eval_priority: Optional[str] = "normal",
     beaker_image: Optional[str] = None,
+    oe_eval_gpu_multiplier: Optional[int] = 1,
 ) -> None:
     weka_cluster = "ai2/saturn ai2/neptune"
     gcp_cluster = "ai2/augusta"
@@ -1178,6 +1179,7 @@ def launch_ai2_evals_on_weka(
 
     command = f"""\
 python scripts/submit_eval_jobs.py \
+--gpu_multiplier {oe_eval_gpu_multiplier} \
 --model_name {leaderboard_name} \
 --location {path} \
 --cluster {cluster} \
diff --git a/pyproject.toml b/pyproject.toml
index 8dca0b159..b8af0df55 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ dependencies = [
     "bitsandbytes>=0.44.1; platform_system != 'Darwin'",
     "datasets>=4.0.0",
     "debugpy>=1.8.13",
-    "deepspeed<0.17.6",  # version 17.6 and up break our training code right now
+    "deepspeed<0.17.3",  # version 17.6 and up break our training code right now
     "hf-transfer>=0.1.8",
     "litellm>=1.72.0,<1.75.2",  # avoid needing backoff https://github.com/BerriAI/litellm/issues/13827
     "matplotlib>=3.9.3",
@@ -95,6 +95,7 @@ target-version = ['py310']
 
 [tool.isort]
 known_first_party = ["open_instruct"]
+known-third-party = ["wandb"]
 profile = "black"
 src_paths = ["open_instruct"]
 
diff --git a/scripts/convert_deepspeed_checkpoint_to_hf.py b/scripts/convert_deepspeed_checkpoint_to_hf.py
new file mode 100644
index 000000000..4e83ced9d
--- /dev/null
+++ b/scripts/convert_deepspeed_checkpoint_to_hf.py
@@ -0,0 +1,263 @@
+#!/usr/bin/env python3
+"""
+Utility to turn a DeepSpeed ZeRO-2/3 checkpoint produced by grpo_fast.py into a Hugging Face model folder.
+
+Usage example:
+    python scripts/convert_deepspeed_checkpoint_to_hf.py \
+        --checkpoint-dir /path/to/checkpoint_state_dir \
+        --output-dir /path/to/exported_model \
+        --model-config /path/to/base_model_or_config \
+        --tokenizer /path/to/tokenizer \
+        --tag global_step16384
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import os
+from pathlib import Path
+from typing import Optional
+
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+try:
+    from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+except ImportError as exc:  # pragma: no cover - makes failure mode obvious to user
+    raise SystemExit(
+        "Could not import DeepSpeed. Install deepspeed in the current environment before running this script."
+    ) from exc
+
+try:
+    from open_instruct.dataset_transformation import CHAT_TEMPLATES
+except Exception:  # pragma: no cover - script is still useful without templates
+    CHAT_TEMPLATES = {}
+
+try:
+    from open_instruct.model_utils import get_olmo3_generation_config
+except Exception:  # pragma: no cover - optional helper
+    get_olmo3_generation_config = None
+
+
+def _resolve_checkpoint_leaf(checkpoint_dir: Path, explicit_tag: Optional[str]) -> tuple[Path, Optional[str], Path]:
+    """
+    Determine which folder actually holds the DeepSpeed shards.
+
+    Returns a tuple of (root_dir, tag, leaf_dir).
+    """
+    checkpoint_dir = checkpoint_dir.expanduser().resolve()
+    if not checkpoint_dir.exists():
+        raise FileNotFoundError(f"Checkpoint directory {checkpoint_dir} does not exist.")
+
+    tag = explicit_tag
+    if tag is None:
+        latest_file = checkpoint_dir / "latest"
+        if latest_file.is_file():
+            tag = latest_file.read_text().strip()
+
+    # If a tag is set, assume that `checkpoint_dir/tag` holds the shard files.
+    if tag:
+        leaf_dir = checkpoint_dir / tag
+    else:
+        leaf_dir = checkpoint_dir
+
+    if not leaf_dir.exists():
+        raise FileNotFoundError(f"Derived checkpoint leaf {leaf_dir} does not exist.")
+
+    shard = leaf_dir / "mp_rank_00_model_states.pt"
+    if not shard.exists():
+        # Some jobs nest the shards a level deeper, so search for a unique match.
+        matches = list(leaf_dir.glob("**/mp_rank_00_model_states.pt"))
+        if len(matches) == 1:
+            leaf_dir = matches[0].parent
+        elif len(matches) == 0:
+            raise FileNotFoundError(
+                f"Could not find mp_rank_00_model_states.pt under {leaf_dir}. "
+                "Make sure you are pointing at a DeepSpeed ZeRO checkpoint."
+            )
+        else:
+            raise RuntimeError(
+                f"Found multiple shard folders in {leaf_dir}. "
+                "Specify --tag to disambiguate which checkpoint to convert."
+            )
+
+    return checkpoint_dir, tag, leaf_dir
+
+
+def _parse_dtype(dtype_str: Optional[str]) -> Optional[torch.dtype]:
+    if dtype_str is None:
+        return None
+
+    normalized = dtype_str.lower()
+    mapping = {
+        "float16": torch.float16,
+        "fp16": torch.float16,
+        "half": torch.float16,
+        "float32": torch.float32,
+        "fp32": torch.float32,
+        "bfloat16": torch.bfloat16,
+        "bf16": torch.bfloat16,
+        "float64": torch.float64,
+        "fp64": torch.float64,
+    }
+    if normalized not in mapping:
+        valid = ", ".join(sorted(set(mapping.keys())))
+        raise ValueError(f"Unknown dtype '{dtype_str}'. Expected one of: {valid}")
+    return mapping[normalized]
+
+
+def _load_state_dict(leaf_dir: Path, tag: Optional[str]) -> dict[str, torch.Tensor]:
+    """
+    Load the aggregated state dict from the DeepSpeed checkpoint.
+    """
+    # The DeepSpeed utility accepts either the leaf directory that holds the shards
+    # (when tag=None) or the checkpoint root plus an explicit tag.
+    if tag is None:
+        state_dict = load_state_dict_from_zero_checkpoint(str(leaf_dir))
+    else:
+        state_dict = load_state_dict_from_zero_checkpoint(str(leaf_dir.parent), tag=tag)
+    return state_dict
+
+
+def _maybe_set_chat_template(tokenizer, chat_template_name: Optional[str]) -> None:
+    if not chat_template_name:
+        return
+
+    template = None
+    if chat_template_name in CHAT_TEMPLATES:
+        template = CHAT_TEMPLATES[chat_template_name]
+    else:
+        path = Path(chat_template_name)
+        if path.is_file():
+            template = path.read_text()
+
+    if template is None:
+        raise ValueError(
+            f"Could not resolve chat template '{chat_template_name}'. "
+            "Pass a key from CHAT_TEMPLATES or a path to a .jinja template file."
+        )
+
+    tokenizer.chat_template = template
+
+
+def _maybe_add_generation_config(model, tokenizer, chat_template_name: Optional[str]) -> None:
+    if not chat_template_name or get_olmo3_generation_config is None:
+        return
+    normalized = chat_template_name.lower()
+    if "olmo" in normalized:
+        gen_config = get_olmo3_generation_config(tokenizer)
+        model.generation_config = gen_config
+
+
+def convert_checkpoint(
+    checkpoint_dir: Path,
+    output_dir: Path,
+    model_config_source: str,
+    tokenizer_source: Optional[str],
+    tag: Optional[str],
+    dtype_str: Optional[str],
+    trust_remote_code: bool,
+    chat_template_name: Optional[str],
+) -> None:
+    checkpoint_dir, tag, leaf_dir = _resolve_checkpoint_leaf(checkpoint_dir, tag)
+    output_dir = output_dir.expanduser().resolve()
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    dtype = _parse_dtype(dtype_str)
+
+    state_dict = _load_state_dict(leaf_dir, tag)
+
+    config = AutoConfig.from_pretrained(model_config_source, trust_remote_code=trust_remote_code)
+    if dtype is not None:
+        config.torch_dtype = dtype
+    model = AutoModelForCausalLM.from_config(config, trust_remote_code=trust_remote_code)
+    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+    if missing_keys:
+        print(f"[warning] Missing keys while loading state dict ({len(missing_keys)}): {missing_keys[:8]}")
+    if unexpected_keys:
+        print(f"[warning] Unexpected keys in state dict ({len(unexpected_keys)}): {unexpected_keys[:8]}")
+    model.tie_weights()
+
+    if tokenizer_source is None:
+        tokenizer_source = model_config_source
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, trust_remote_code=trust_remote_code)
+
+    _maybe_set_chat_template(tokenizer, chat_template_name)
+    _maybe_add_generation_config(model, tokenizer, chat_template_name)
+
+    model.save_pretrained(output_dir, safe_serialization=True)
+    tokenizer.save_pretrained(output_dir)
+
+    metadata = {
+        "source_checkpoint": str(checkpoint_dir),
+        "tag": tag,
+        "model_config_source": model_config_source,
+        "tokenizer_source": tokenizer_source,
+        "dtype": dtype_str,
+    }
+    with (output_dir / "conversion_metadata.json").open("w") as f:
+        json.dump(metadata, f, indent=2)
+
+    print(f"Converted checkpoint from {leaf_dir} into {output_dir}")
+
+
+def build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("--checkpoint-dir", required=True, help="Directory that holds DeepSpeed checkpoint state.")
+    parser.add_argument(
+        "--output-dir",
+        required=True,
+        help="Where to write the Hugging Face model (created if it does not exist).",
+    )
+    parser.add_argument(
+        "--model-config",
+        required=True,
+        help="Path or identifier for AutoConfig.from_pretrained to define the model architecture.",
+    )
+    parser.add_argument(
+        "--tokenizer",
+        default=None,
+        help="Path or identifier for the tokenizer. Defaults to the same value as --model-config.",
+    )
+    parser.add_argument(
+        "--tag",
+        default=None,
+        help="Optional DeepSpeed tag (e.g., global_step1234). If omitted, uses the 'latest' file when present.",
+    )
+    parser.add_argument(
+        "--torch-dtype",
+        default=None,
+        help="Optional torch dtype to record in the config (e.g., bf16, fp16).",
+    )
+    parser.add_argument(
+        "--chat-template-name",
+        default=None,
+        help="Tokenizer chat template key (from CHAT_TEMPLATES) or path to a template file to attach.",
+    )
+    parser.add_argument(
+        "--trust-remote-code",
+        action="store_true",
+        help="Forward trust_remote_code=True when loading config/tokenizer/model.",
+    )
+    return parser
+
+
+def main() -> None:
+    parser = build_parser()
+    args = parser.parse_args()
+    convert_checkpoint(
+        checkpoint_dir=Path(args.checkpoint_dir),
+        output_dir=Path(args.output_dir),
+        model_config_source=args.model_config,
+        tokenizer_source=args.tokenizer,
+        tag=args.tag,
+        dtype_str=args.torch_dtype,
+        trust_remote_code=args.trust_remote_code,
+        chat_template_name=args.chat_template_name,
+    )
+
+
+if __name__ == "__main__":
+    main()
+
diff --git a/scripts/data/rlvr/filtering_vllm.py b/scripts/data/rlvr/filtering_vllm.py
index 160dca91f..bf8c0d700 100644
--- a/scripts/data/rlvr/filtering_vllm.py
+++ b/scripts/data/rlvr/filtering_vllm.py
@@ -1,4 +1,4 @@
-'''
+"""
 python mason.py \
   --cluster ai2/jupiter --image nathanl/open_instruct_auto \
   --workspace ai2/tulu-thinker \
@@ -15,7 +15,8 @@
   --size 100000 \
   --output-file filtered_datasets/qwen2_5_openthoughts2/orz.jsonl \
   --number_samples 8
-'''
+"""
+
 import argparse
 import json
 
@@ -27,65 +28,20 @@
 
 
 def main():
-    parser = argparse.ArgumentParser(
-        description="Bulk-generate N samples per HF dataset record using vLLM."
-    )
-    parser.add_argument(
-        "--model",
-        required=True,
-        help="vLLM model ID (e.g. facebook/opt-125m)"
-    )
-    parser.add_argument(
-        "--dataset",
-        required=True,
-        help="HF dataset name (e.g. squad)"
-    )
+    parser = argparse.ArgumentParser(description="Bulk-generate N samples per HF dataset record using vLLM.")
+    parser.add_argument("--model", required=True, help="vLLM model ID (e.g. facebook/opt-125m)")
+    parser.add_argument("--dataset", required=True, help="HF dataset name (e.g. squad)")
+    parser.add_argument("--split", default="train", help="Which split to load")
+    parser.add_argument("--offset", type=int, required=True, help="Start index into the split")
+    parser.add_argument("--size", type=int, required=True, help="Number of records to process")
+    parser.add_argument("--output-file", default=None, help="Path for output JSONL")
     parser.add_argument(
-        "--split",
-        default="train",
-        help="Which split to load"
-    )
-    parser.add_argument(
-        "--offset",
-        type=int,
-        required=True,
-        help="Start index into the split"
-    )
-    parser.add_argument(
-        "--size",
-        type=int,
-        required=True,
-        help="Number of records to process"
-    )
-    parser.add_argument(
-        "--output-file",
-        default=None,
-        help="Path for output JSONL"
-    )
-    parser.add_argument(
-        "--push_to_hub",
-        default=None,
-        type=str,
-        help="Give a dataset name to push this data to the hub."
-    )
-    parser.add_argument(
-        "--chat_template",
-        type=str,
-        default=None,
-        help="Chat template name"
-    )
-    parser.add_argument(
-        "--number_samples",
-        type=int,
-        default=8,
-        help="Number of samples to generate per record"
-    )
-    parser.add_argument(
-        "--temperature",
-        type=float,
-        default=1.0,
-        help="Sampling temperature"
+        "--push_to_hub", default=None, type=str, help="Give a dataset name to push this data to the hub."
     )
+    parser.add_argument("--chat_template", type=str, default=None, help="Chat template name")
+    parser.add_argument("--number_samples", type=int, default=8, help="Number of samples to generate per record")
+    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
+    parser.add_argument("--top_p", type=float, default=1.0, help="Sampling temperature")
     args = parser.parse_args()
 
     # 1. Load and slice dataset
@@ -106,20 +62,14 @@ def main():
         tokenizer.apply_chat_template(
             sample["messages"][:-1] if len(sample["messages"]) > 1 else sample["messages"],
             add_generation_prompt=True,
-            tokenize=False
+            tokenize=False,
         )
         for sample in subset
     ]
     # 4. vLLM bulk generate
-    llm = LLM(
-        model=args.model,
-        dtype="bfloat16",
-        enable_prefix_caching=True
-    )
+    llm = LLM(model=args.model, dtype="bfloat16", enable_prefix_caching=True)
     sampling_params = SamplingParams(
-        temperature=args.temperature,
-        n=args.number_samples,
-        max_tokens=32768,
+        temperature=args.temperature, top_p=args.top_p, n=args.number_samples, max_tokens=32768
     )
     outputs = llm.generate(prompts, sampling_params)
 
diff --git a/scripts/train/debug/grpo_fast.sh b/scripts/train/debug/grpo_fast.sh
index 5bf6f791e..13b59000a 100755
--- a/scripts/train/debug/grpo_fast.sh
+++ b/scripts/train/debug/grpo_fast.sh
@@ -1,11 +1,31 @@
+#!/bin/bash
+
+python mason.py \
+    --task_name grpo_debug_small_active \
+    --cluster ai2/jupiter \
+    --workspace ai2/olmo-instruct \
+    --priority urgent \
+    --pure_docker_mode \
+    --image michaeln/open_instruct_2.5-rl0 \
+    --preemptible \
+    --num_nodes 1 \
+    --env VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \
+    --env VLLM_ATTENTION_BACKEND="FLASH_ATTN" \
+    --gpus 2 \
+    --budget ai2/oe-adapt \
+    -- \
 uv run python open_instruct/grpo_fast.py \
     --dataset_mixer_list ai2-adapt-dev/rlvr_gsm8k_zs 64 \
     --dataset_mixer_list_splits train \
     --dataset_mixer_eval_list ai2-adapt-dev/rlvr_gsm8k_zs 16 \
     --dataset_mixer_eval_list_splits train \
     --max_prompt_token_length 512 \
-    --response_length 512 \
-    --pack_length 1024 \
+    --response_length 2048 \
+    --pack_length 4096 \
+    --async_steps 4 \
+    --inflight_updates \
+    --active_fill_completions \
+    --truncated_importance_sampling_ratio_cap 2.0 \
     --per_device_train_batch_size 1 \
     --num_unique_prompts_rollout 8 \
     --num_samples_per_prompt_rollout 4 \
@@ -17,20 +37,16 @@ uv run python open_instruct/grpo_fast.py \
     --ground_truths_key ground_truth \
     --chat_template_name r1_simple_chat_postpend_think \
     --learning_rate 3e-7 \
-    --total_episodes 200 \
+    --total_episodes 1600 \
     --deepspeed_stage 2 \
     --num_epochs 1 \
     --num_learners_per_node 1 \
+    --vllm_num_engines 1 \
     --vllm_tensor_parallel_size 1 \
-    --beta 0.01 \
+    --beta 0. \
     --seed 3 \
-    --local_eval_every 1 \
-    --vllm_sync_backend gloo \
-    --vllm_gpu_memory_utilization 0.3 \
-    --save_traces \
-    --vllm_enforce_eager \
+    --local_eval_every 100 \
     --gradient_checkpointing \
-    --single_gpu_mode \
     --push_to_hub false \
     --system_prompt_override_file scripts/train/debug/cute_debug_system_prompt.txt \
     # --with_tracking
diff --git a/scripts/train/rlvr/grpo_rlzero.sh b/scripts/train/rlvr/grpo_rlzero.sh
new file mode 100755
index 000000000..680bcb658
--- /dev/null
+++ b/scripts/train/rlvr/grpo_rlzero.sh
@@ -0,0 +1,90 @@
+#!/bin/bash
+
+# OLMo 3 model
+MODEL_NAME_OR_PATH="/weka/oe-adapt-default/michaeln/checkpoints/olmo3-7b-base"
+GS_MODEL_NAME="olmo3_7b_base"
+
+# english only DAPO
+# DATASETS="mnoukhov/DAPO-Math-14k-Processed-RLVR 1.0 TTTXXX01/MATH_3000_Filtered 1.0"
+DATASETS="saurabh5/DAPO-Math-17k-Processed_filtered_olmo_completions_new_template_filtered 1.0 saurabh5/MATH_3000_Filtered_olmo_completions_new_template_filtered 1.0"
+# DATASETS="mnoukhov/deepscaler_20k_medhard_nolatex_rlvr 1.0"
+# DATASETS=""
+
+# math evals
+# EVALS="minerva_math_500::hamish_zs_reasoning_deepseek"
+EVALS="aime:zs_cot_r1::pass_at_32_2024_dapo,aime:zs_cot_r1::pass_at_32_2025_dapo"
+
+# AIME 2024, 2025 local evals
+LOCAL_EVALS="mnoukhov/aime2024-25-rlvr 1.0 mnoukhov/aime2024-25-rlvr 1.0"
+LOCAL_EVAL_SPLITS="test_2024 test_2024 test_2025 test_2025"
+# tengmath3k
+# EXP_NAME="grpo_deepscaler20k_k8_${GS_MODEL_NAME}"
+EXP_NAME="grpo_17kfilter_${GS_MODEL_NAME}"
+
+cluster=ai2/jupiter
+
+python mason.py \
+    --task_name ${EXP_NAME} \
+    --cluster ${cluster} \
+    --workspace ai2/olmo-instruct \
+    --priority urgent \
+    --pure_docker_mode \
+    --image michaeln/open_instruct_rlzero \
+    --preemptible \
+    --num_nodes 9 \
+    --env VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \
+    --env VLLM_ATTENTION_BACKEND="FLASH_ATTN" \
+    --gs_model_name $GS_MODEL_NAME \
+    --gpus 8 \
+    --budget ai2/oe-adapt \
+    -- \
+source configs/beaker_configs/ray_node_setup.sh \&\& \
+source configs/beaker_configs/code_api_setup.sh \&\& \
+python open_instruct/grpo_fast.py \
+    --exp_name ${EXP_NAME} \
+    --beta 0.0 \
+    --async_steps 4 \
+    --inflight_updates \
+    --truncated_importance_sampling_ratio_cap 2.0 \
+    --advantage_normalization_type centered \
+    --active_fill_completions \
+    --no_resample_solve_rate 0.9 \
+    --num_samples_per_prompt_rollout 16 \
+    --num_unique_prompts_rollout 16 \
+    --num_mini_batches 1 \
+    --learning_rate 1e-6 \
+    --per_device_train_batch_size 1 \
+    --kl_estimator kl3 \
+    --dataset_mixer_list $DATASETS \
+    --dataset_mixer_list_splits train \
+    --dataset_mixer_eval_list $LOCAL_EVALS \
+    --dataset_mixer_eval_list_splits $LOCAL_EVAL_SPLITS \
+    --max_prompt_token_length 2048 \
+    --response_length 12000 \
+    --pack_length 32768 \
+    --model_name_or_path ${MODEL_NAME_OR_PATH} \
+    --chat_template_name olmo_thinker_dapo \
+    --non_stop_penalty False \
+    --temperature 1.0 \
+    --total_episodes 512000 \
+    --deepspeed_stage 3 \
+    --num_learners_per_node 8 \
+    --vllm_num_engines 64 \
+    --vllm_tensor_parallel_size 1 \
+    --lr_scheduler_type constant \
+    --apply_verifiable_reward true \
+    --seed 1 \
+    --local_eval_every 100 \
+    --save_freq 100 \
+    --checkpoint_state_freq 100 \
+    --gradient_checkpointing \
+    --with_tracking \
+    --vllm_enable_prefix_caching \
+    --clip_higher 0.272 \
+    --mask_truncated_completions True \
+    --oe_eval_max_length 32768 \
+    --try_launch_beaker_eval_jobs_on_weka True \
+    --eval_priority high \
+    --oe_eval_tasks $EVALS \
+    --oe_eval_gpu_multiplier 4 \
+    --oe_eval_beaker_image michaeln/oe_eval_rlzero
diff --git a/uv.lock b/uv.lock
index 5b839989f..a6c81c00b 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1804,7 +1804,7 @@ requires-dist = [
     { name = "bitsandbytes", marker = "sys_platform != 'darwin'", specifier = ">=0.44.1" },
     { name = "datasets", specifier = ">=4.0.0" },
     { name = "debugpy", specifier = ">=1.8.13" },
-    { name = "deepspeed", specifier = "<0.17.6" },
+    { name = "deepspeed", specifier = "<0.17.3" },
     { name = "fastapi", marker = "extra == 'code'", specifier = ">=0.100.0" },
     { name = "flash-attn", marker = "sys_platform != 'darwin'", specifier = ">=2.8.3" },
     { name = "hf-transfer", specifier = ">=0.1.8" },