diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml index e9ddc625a..960c96e6d 100644 --- a/apps/sft/llama3_8b.yaml +++ b/apps/sft/llama3_8b.yaml @@ -32,14 +32,24 @@ training: max_norm: 1.0 steps: 1000 compile: false - dataset: "c4" + datasets: + - path: "yahma/alpaca-cleaned" + split: "train[:95%]" + +eval: + eval_every_n_steps: 5 # (null = disabled) + max_eval_steps: 0 # Max batches per eval dataset (null = run until epoch completes) + batch_size: ${training.local_batch_size} # Batch size for evaluation + datasets: + - path: "yahma/alpaca-cleaned" + split: "train[95%:]" parallelism: data_parallel_replicate_degree: 1 data_parallel_shard_degree: -1 - tensor_parallel_degree: 1 + tensor_parallel_degree: 2 pipeline_parallel_degree: 1 - context_parallel_degree: 1 + context_parallel_degree: 2 expert_parallel_degree: 1 disable_loss_parallel: false @@ -62,6 +72,7 @@ metric_logging: group: sft_exp_${oc.env:USER} logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce + # profiling: # enable_profiling: false diff --git a/apps/sft/main.py b/apps/sft/main.py index edda0b49d..b7a07237a 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -27,8 +27,10 @@ from forge.data.datasets.packed import PackedDataset, TextPacker from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset from forge.data.tokenizer import HuggingFaceModelTokenizer +from forge.data.utils import StopAfterOneEpoch from forge.observability import get_or_create_metric_logger, record_metric, Reduce from forge.util.config import parse +from forge.util.logging import log_rank_zero from monarch.actor import current_rank, current_size, endpoint from omegaconf import DictConfig, OmegaConf @@ -81,6 +83,7 @@ def __init__(self, config: DictConfig): self.gradient_accumulation_steps = 1 # Example value, adjust as needed self._rank = current_rank().rank self._size = math.prod(current_size().values()) + self._init_dist() super().__init__(job_config) @@ -122,27 +125,67 @@ def record_batch_metrics(self, data_metrics: list): @endpoint async def setup(self): - self.train_dataloader = self.setup_data() - self.mlogger = await self.setup_metric_logger() - - # self.train_dataloader = self.setup_data( - # self.train_config.train_dataset_config, - # self.train_config.train_dataloader_config, - # self.train_config.packing_config, - # ) - # self.val_dataloader = self.setup_data( - # self.train_config.val_dataset_config, - # self.train_config.val_dataloader_config, - # self.train_config.packing_config, - # ) - - # TODO: confirm that this is working properly - # Should also use load, not dcp_load + """Setup datasets from config. + + Loads training and evaluation datasets based on config structure. 
+ """ + # Load training datasets + logger.info("Setting training datasets") + train_datasets_config = self.job_config.training.datasets + self.train_dataloader = self.setup_data(train_datasets_config) + + # Load eval config (might be None) + eval_config = self.job_config.get("eval", {}) + self.val_dataloaders = {} + self.validation_enabled = False + self.eval_every_n_steps = eval_config.get("eval_every_n_steps", None) + max_eval_steps = eval_config.get("max_eval_steps", None) + self.max_eval_steps = ( + max_eval_steps if max_eval_steps and max_eval_steps > 0 else None + ) + self.validation_enabled = ( + self.eval_every_n_steps is not None and self.eval_every_n_steps > 0 + ) + if self.validation_enabled: + logger.info("Setting eval datasets") + self.eval_datasets_config = eval_config.datasets + + for i, dataset_config in enumerate(self.eval_datasets_config): + ds_name = dataset_config.get("dataset_name", i) + + dataloader = self.setup_data([dataset_config]) + self.val_dataloaders[ds_name] = dataloader + + # Load checkpoint if resuming self.checkpointer.load(step=self.current_step) - # self.profiler = self.setup_profiler(self.train_config.profiler_config) - # self.logger = self.setup_logger(self.train_config.logger_config) - def setup_data(self): + def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader: + """Setup data from dataset configs. + + Currently only supports single dataset (first in list). + Multi-dataset training with InterleavedDataset is future work. + + Args: + dataset_configs: List of dataset config dicts with keys like 'path', 'split', etc. + + Returns: + StatefulDataLoader for the dataset + + Raises: + ValueError: If multiple datasets provided (not yet supported) + """ + # Currently only support single dataset + if len(dataset_configs) > 1: + raise ValueError( + f"Multiple training datasets not supported yet. " + f"Got {len(dataset_configs)} datasets. " + f"For dataset mixing, use InterleavedDataset (coming soon)." + ) + + dataset_config = dataset_configs[0] + + # TODO: Evaluate if tokenizers should be created once and shared for every dataset + # Load tokenizer print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json")) tokenizer = HuggingFaceModelTokenizer( tokenizer_json_path=os.path.join( @@ -156,35 +199,55 @@ def setup_data(self): ), ) + # Store tokenizer for later use (e.g., decoding in debug logs) + self.tokenizer = tokenizer + + # Get DP mesh for data sharding + dp_mesh = None + if self.parallel_dims is not None and self.parallel_dims.dp_enabled: + dp_mesh = self.parallel_dims.world_mesh.get_group("dp") + + # Pass config directly to dataset constructor dataset = sft_iterable_dataset( model_transform=tokenizer, message_transform=AlpacaToMessages(), - path="yahma/alpaca-cleaned", - split="train", + dp_mesh=dp_mesh, + **dataset_config, # Unpack config (path, split, etc.) 
) + packer = TextPacker(padding_idx=0) dataset = PackedDataset( dataset=dataset, packer=packer, - target_tokens_per_pack=self.job_config.training.seq_len, # TODO: get this from model + target_tokens_per_pack=self.job_config.training.seq_len, ) - dataloader = StatefulDataLoader( + + return StatefulDataLoader( dataset=dataset, batch_size=self.job_config.training.local_batch_size, collate_fn=partial( collate_packed, mask_fn=packer.create_block_mask, device=self.device ), + drop_last=True, ) - # Ultimately we probably want something like this - # packer = build_packing_strategy(packing_config) - # dataset = build_dataset(dataset_config) - # dataloader = build_dataloader(dataloader_config, dataset, packer) - return dataloader - - def forward_backward( - self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor + def forward( + self, + input_dict: dict[str, torch.Tensor], + labels: torch.Tensor, + compute_gradients: bool = True, ) -> torch.Tensor: + """Forward pass with optional gradient computation. + + Args: + input_dict: Input dictionary containing tokens + labels: Target labels + compute_gradients: If True, compute gradients (training mode). + If False, skip backward pass (evaluation mode). + + Returns: + Loss tensor + """ model_parts = self.model_parts parallel_dims = self.parallel_dims @@ -204,7 +267,7 @@ def forward_backward( ) if parallel_dims.pp_enabled: - # Pipeline Parallel forward / backward inside step() call + # Pipeline Parallel forward (with optional backward) with self.train_context(optional_context_parallel_ctx): targets, losses = ( (labels, []) if self.pp_has_last_stage else (None, None) @@ -226,7 +289,7 @@ def forward_backward( else torch.tensor([-1.0], device=self.device) ) else: - # Non-PP forward / backward + # Non-PP forward (with optional backward) with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: @@ -234,7 +297,8 @@ def forward_backward( loss = self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred - loss.backward() + if compute_gradients: + loss.backward() return loss @@ -246,7 +310,7 @@ def train_step(self, batch) -> None: # self.data_parallel_size, # ) as grad_acc: labels = batch.pop("labels") - loss = self.forward_backward(batch, labels) + loss = self.forward(batch, labels, compute_gradients=True) loss = loss.item() record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN) @@ -256,6 +320,107 @@ def train_step(self, batch) -> None: self.optimizers.step() self.lr_schedulers.step() + async def evaluate(self) -> None: + """Run evaluation on multiple datasets, one at a time. + + 1. Set models to eval mode + 2. For each eval dataset: + - Create fresh iterator (starts from epoch 0) + - Use StopAfterOneEpoch to iterate until epoch boundary. This utility + is necessary for infinite iterable dataset, since epoch boundaries are not known. + - Respect max_eval_steps cap if configured + - Record loss and step metrics (on dp rank only) + 3. 
Restore models to train mode + """ + logger.debug("==STARTING EVALUATION==") + + # Set models to eval mode + for model_part in self.model_parts: + model_part.eval() + + # Get DP process group for epoch synchronization + dp_process_group = None + if self.parallel_dims is not None and self.parallel_dims.dp_enabled: + dp_process_group = self.parallel_dims.world_mesh.get_group("dp") + + # Evaluate each dataset sequentially + for dataset_name, val_dataloader in self.val_dataloaders.items(): + logger.debug(f"=====Evaluating dataset: {dataset_name}=====") + + # Evaluation loop for this dataset + total_loss = torch.tensor(0.0, device=self.device) + num_steps = 0 + + # NOTE: Assumes batch contains batch["metrics"]["num_epochs"]: int + batch_iter = StopAfterOneEpoch( + iter(val_dataloader), # Fresh iterator from epoch 0 + self.device, + dataset_name, + dp_process_group, + ) + + with torch.no_grad(): + for batch in batch_iter: + # Check max_eval_steps limit + if ( + self.max_eval_steps is not None + and num_steps >= self.max_eval_steps + ): + log_rank_zero( + logger, + f"[{dataset_name}] Reached max_eval_steps cap of {self.max_eval_steps}", + ) + break + + # Move tensors to device + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.to(self.device) + + # Process batch + labels = batch.pop("labels") + loss = self.forward(batch, labels, compute_gradients=False) + total_loss += loss + num_steps += 1 + + # Log progress (rank 0 only) + if num_steps % 100 == 0: + loss_val = loss.item() + log_rank_zero( + logger, + f" [{dataset_name}] Step {num_steps} | Loss: {loss_val:.4f}", + ) + + # Compute average loss + avg_loss = (total_loss / max(num_steps, 1)).item() + log_rank_zero(logger, f" [{dataset_name}] avg_loss: {avg_loss:.4f}") + + # Record metrics only on DP rank 0 to avoid double counting + # record_metric aggregates across all processes via monarch + should_record = True + if dp_process_group is not None: + dp_rank = torch.distributed.get_rank(group=dp_process_group) + should_record = dp_rank == 0 + + if should_record: + record_metric( + f"ForgeSFTRecipe/evaluate/{dataset_name}_loss", + avg_loss, + Reduce.MEAN, + ) + record_metric( + f"ForgeSFTRecipe/evaluate/{dataset_name}_steps", + num_steps, + Reduce.MEAN, + ) + + # Restore train mode + for model_part in self.model_parts: + model_part.train() + + # Summary + logger.debug("==EVALUATION COMPLETE==") + @endpoint async def train(self) -> None: dataloader = iter(self.train_dataloader) @@ -280,10 +445,12 @@ async def train(self) -> None: # self.profiler.step() self.current_step += 1 - # Flush metrics - if self._rank == 0: - logger.debug(f"Flushing metrics at step {self.current_step}") - await self.mlogger.flush.call_one(global_step=self.current_step) + # Run evaluation periodically if enabled + if ( + self.validation_enabled + and self.current_step % self.eval_every_n_steps == 0 + ): + await self.evaluate() self.checkpointer.save( curr_step=self.current_step, @@ -292,6 +459,11 @@ async def train(self) -> None: # self.pbar.close() + # Run final evaluation at end of training + if self.validation_enabled: + logger.info("Running final evaluation at end of training...") + await self.evaluate() + @endpoint async def cleanup(self) -> None: if self.checkpointer: diff --git a/src/forge/data/datasets/hf_dataset.py b/src/forge/data/datasets/hf_dataset.py index d7b36fe68..8399d68e5 100644 --- a/src/forge/data/datasets/hf_dataset.py +++ b/src/forge/data/datasets/hf_dataset.py @@ -70,6 +70,7 @@ def __init__( dataset_name: str | 
None = None, filter_fn: Callable | None = None, filter_kwargs: dict[str, Any] | None = None, + dp_mesh: Any = None, **load_dataset_kwargs, ): # Store configuration @@ -79,6 +80,8 @@ def __init__( self._model_transform = model_transform self._output_transform = output_transform self._weight = weight if weight is not None else 1.0 + self._dp_mesh = dp_mesh + self._is_resumed = False # Create default transform if not provided self._metric_transform = metric_transform or DefaultDatasetMetricTransform() @@ -138,11 +141,22 @@ def _setup_hf_dataset( shuffle configuration, and filtering. Called once during __init__. """ - # Distributed setup + # Extract rank/world_size from DP mesh world_size, rank = 1, 0 - if dist.is_initialized(): + if self._dp_mesh is not None: + world_size = dist.get_world_size(group=self._dp_mesh) + rank = dist.get_rank(group=self._dp_mesh) + logger.info( + f"Using DP mesh for sharding: rank={rank}, world_size={world_size}" + ) + elif dist.is_initialized(): + # Fallback to global rank (may not respect TP/PP) world_size = dist.get_world_size() rank = dist.get_rank() + logger.warning( + f"Using global rank for sharding: rank={rank}, world_size={world_size}. " + f"If using TP/PP, pass dp_mesh for correct sharding." + ) # Load and shard dataset ds = load_dataset(**load_dataset_kwargs) @@ -218,6 +232,9 @@ def __iter__(self) -> Iterator[dict[str, Any]]: - Adds 'num_epochs' metric to track dataset progress - Yields samples indefinitely for continuous training """ + # Reset iter + if not self._is_resumed: + self._num_epochs = 0 while True: # Infinite iteration self._ds.set_epoch(self._num_epochs) @@ -282,3 +299,4 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: # HF is responsible for resuming the dataset state # where it last left off self._ds.load_state_dict(hf_state) + self._is_resumed = True diff --git a/src/forge/data/datasets/packed.py b/src/forge/data/datasets/packed.py index 93a21b85e..1a22352ef 100644 --- a/src/forge/data/datasets/packed.py +++ b/src/forge/data/datasets/packed.py @@ -343,9 +343,6 @@ def _reset_packer_state(self) -> None: # exhausted: whether the dataset is exhausted self._exhausted: bool = False - # resuming: whether the packer is resuming from a checkpoint - self._resuming: bool = False - def _fill_buffer(self, iterator: Iterator[SampleType]) -> None: """ Fills the buffer with samples from the dataset. @@ -449,18 +446,15 @@ def _build_one_pack(self, iterator: Iterator[SampleType]) -> SampleDict | None: return None def __iter__(self) -> Iterator[SampleDict]: + """Create a new iterator for the dataset. + + Always resets the packer state to ensure consistent iteration from the start. 
+ """ if not isinstance(self.dataset, Iterable): raise TypeError("Dataset is not an iterable") - if not self._resuming: - self._reset_packer_state() - self._iterator = iter(self.dataset) - - # If resuming, the iterator must be recreated from the loaded state - if self._iterator is None: - self._iterator = iter(self.dataset) - - self._resuming = False # Consume the resume flag + self._reset_packer_state() + self._iterator = iter(self.dataset) # Main packing loop while True: @@ -502,7 +496,6 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: raise ValueError("Dataset is not stateful.") self._reset_packer_state() - self._resuming = True class TextPacker(Packer[SampleDict]): diff --git a/src/forge/data/datasets/sft_dataset.py b/src/forge/data/datasets/sft_dataset.py index 00278c1e5..6264b13ea 100644 --- a/src/forge/data/datasets/sft_dataset.py +++ b/src/forge/data/datasets/sft_dataset.py @@ -162,6 +162,7 @@ def sft_iterable_dataset( dataset_name: str | None = None, filter_fn: Callable | None = None, filter_kwargs: dict[str, Any] | None = None, + dp_mesh: Any = None, **load_dataset_kwargs: dict[str, Any], ) -> HfIterableDataset: """ @@ -177,6 +178,7 @@ def sft_iterable_dataset( dataset_name (str | None): Name for metrics namespacing filter_fn (Callable | None): Filter function filter_kwargs (dict[str, Any] | None): Filter function kwargs + dp_mesh (Any): Data parallel mesh for sharding (None for single process) **load_dataset_kwargs (dict[str, Any]): Args passed to load_dataset Returns: @@ -206,5 +208,6 @@ def sft_iterable_dataset( dataset_name=dataset_name, filter_fn=filter_fn, filter_kwargs=filter_kwargs, + dp_mesh=dp_mesh, **load_dataset_kwargs, ) diff --git a/src/forge/data/utils.py b/src/forge/data/utils.py index be8c13857..498d3e419 100644 --- a/src/forge/data/utils.py +++ b/src/forge/data/utils.py @@ -4,13 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging from enum import Enum -from typing import Any, Literal, Union +from typing import Any, Iterator, Literal, Union import torch from torch.nn.attention.flex_attention import BlockMask +logger = logging.getLogger(__name__) + CROSS_ENTROPY_IGNORE_IDX = -100 Role = Literal[ @@ -213,3 +216,127 @@ def batch_to_device(batch: dict, device: torch.device) -> None: f"Tensor, or BlockMask with flexattention enabled. " f'Got key "{k}" with value of type {type(v)}' ) + + +class StopAfterOneEpoch: + """Iterator that wraps a dataloader and stops after one epoch completes. + + Handles epoch detection and synchronization across DP ranks using async + all_reduce. Assumes dataset inherits from InfiniteTuneIterableDataset + which provides 'metrics' with 'num_epochs' metric. + + When any rank detects an epoch change, all ranks stop (synchronized). 
+ + Args: + dataloader_iter: Iterator over dataloader batches + device: Device for computation + dataset_name: Name for logging + dp_process_group: Data parallel process group (None for single process) + """ + + def __init__( + self, + dataloader_iter: Iterator, + device: torch.device, + dataset_name: str, + dp_process_group: Any = None, + ): + self.dataloader_iter = dataloader_iter + self.device = device + self.dataset_name = dataset_name + self.dp_process_group = dp_process_group + + # Prefetch first batch for pipeline-style execution + self._next_batch = next(dataloader_iter) + + # Track pending async epoch sync + self._epoch_tensor: torch.Tensor | None = None + self._pending_work: Any = None + self._should_stop = False + + def __iter__(self): + return self + + def __next__(self) -> dict: + """Get next batch from current epoch. + + Returns: + Batch dict guaranteed to be from current epoch + + Raises: + StopIteration: When epoch completes across all ranks + """ + # Check if previous epoch sync completed + if self._pending_work is not None: + self._pending_work.wait() + if self._epoch_tensor.item() > 0: + self._should_stop = True + self._pending_work = None + self._epoch_tensor = None + + if self._should_stop: + logger.debug( + f"[{self.dataset_name}] Eval epoch completed. Stopping data iterator." + ) + raise StopIteration + + # Get current batch + current_batch = self._next_batch + current_epoch = extract_epoch_from_batch(current_batch) + + # Prefetch next batch and check for epoch change + self._next_batch = next(self.dataloader_iter) + next_epoch = extract_epoch_from_batch(self._next_batch) + epoch_changed = next_epoch > current_epoch + + # Start async epoch sync + if torch.distributed.is_initialized(): + self._epoch_tensor = torch.tensor([int(epoch_changed)], device=self.device) + self._pending_work = torch.distributed.all_reduce( + self._epoch_tensor, + op=torch.distributed.ReduceOp.MAX, + group=self.dp_process_group, + async_op=True, + ) + elif epoch_changed: + # if not distributed, just update the flag directly + self._should_stop = True + + return current_batch + + +def extract_epoch_from_batch(batch: dict | list) -> int: + """Extract epoch number from batch metrics. + + Assumes datasets inherit from InfiniteTuneIterableDataset which always + adds num_epochs metric. Raises clear error if assumption is violated. + + Args: + batch: Batch dictionary with 'metrics' field OR list of sample dicts + + Returns: + Epoch number from metrics + + Raises: + ValueError: If metrics missing or no num_epochs found + """ + # Handle list of samples (uncollated batches) + if isinstance(batch, list): + if not batch: + raise ValueError("Empty batch provided") + batch = batch[0] # Extract first sample + + if "metrics" not in batch: + raise ValueError( + "Batch missing 'metrics' field. Ensure dataset inherits from " + "InfiniteTuneIterableDataset which adds this automatically." + ) + + for metric in batch["metrics"]: + if "num_epochs" in metric.key: + return int(metric.value) + + raise ValueError( + f"No 'num_epochs' metric found in batch. 
Got metrics: " + f"{[m.key for m in batch['metrics']]}" + ) diff --git a/tests/unit_tests/datasets/test_hf.py b/tests/unit_tests/datasets/test_hf.py index 8298bf1a8..9fd2ce464 100644 --- a/tests/unit_tests/datasets/test_hf.py +++ b/tests/unit_tests/datasets/test_hf.py @@ -272,6 +272,8 @@ def test_epoch_tracking(self, dataset_factory, small_dataset_file): class TestDistributedHfIterableDataset(FSDPTest): + """Test HfIterableDataset with 2-GPU distributed setup.""" + @property def world_size(self) -> int: return 2 @@ -364,3 +366,124 @@ def create_loader(): finally: shutil.rmtree(temp_dir) + + +class TestDPShardingWithTP(FSDPTest): + """Test DP sharding with TP replication (4-GPU setup).""" + + @property + def world_size(self) -> int: + return 4 + + @gpu_test(gpu_count=4) + def test_dp_sharding_with_tp_replication(self): + """Verify DP sharding works correctly with TP/CP replication. + + This is a CRITICAL test that validates the core bug fix: + - Previously: Each rank got different batches (incorrect) + - Now: TP/CP ranks within same DP group get identical batches (correct) + + Setup: DP=2, TP=2 (4 GPUs total) + - DP group 0: ranks [0, 1] - should see SAME batches (TP replication) + - DP group 1: ranks [2, 3] - should see SAME batches (TP replication) + - DP group 0 vs 1: should see DIFFERENT batches (DP sharding) + + Mesh structure: + - TP rank 0 DP replicas: [0, 2] - shard across these + - TP rank 1 DP replicas: [1, 3] - shard across these + """ + import hashlib + + rank = dist.get_rank() + world_size = dist.get_world_size() + temp_dir = tempfile.mkdtemp(prefix=f"dp_tp_test_rank{rank}_") + + try: + data_file = Path(temp_dir) / "data.json" + # Create dataset with enough samples for clear sharding + # 40 samples / 2 DP groups = 20 samples per DP group + create_test_json_file(data_file, MEDIUM_DATASET_SIZE, offset=0) + + # Create DP mesh for sharding + # Key insight: Create groups across DP replicas for each TP rank + # TP rank = rank % 2, so: + # - TP rank 0: ranks [0, 2] (one from each DP group) + # - TP rank 1: ranks [1, 3] (one from each DP group) + tp_rank = rank % 2 + tp_world_size = 2 + dp_world_size = world_size // tp_world_size + + # Create DP groups for each TP rank + dp_groups = [] + for tp_r in range(tp_world_size): + # Ranks for this TP rank across DP groups + ranks = [tp_r + i * tp_world_size for i in range(dp_world_size)] + group = dist.new_group(ranks=ranks) + dp_groups.append(group) + + dp_mesh = dp_groups[tp_rank] + + # - Rank 0 (tp_rank=0) uses group [0, 2], gets rank=0 → shard 0 + # - Rank 1 (tp_rank=1) uses group [1, 3], gets rank=0 → shard 0 + # - Rank 2 (tp_rank=0) uses group [0, 2], gets rank=1 → shard 1 + # - Rank 3 (tp_rank=1) uses group [1, 3], gets rank=1 → shard 1 + + dataset = HfIterableDataset( + path="json", + data_files=str(data_file), + split="train", + dataset_name="dp_tp_test", + shuffle_buffer_size=0, + metric_transform=DefaultDatasetMetricTransform(), + num_shards_per_rank=2, + dp_mesh=dp_mesh, # CRITICAL: Pass dp_mesh for correct sharding + ) + + dataloader = StatefulDataLoader( + dataset, + batch_size=BATCH_SIZE, + collate_fn=collate_with_metrics, + num_workers=0, + ) + + # Collect batches and compute hashes + batches = list(islice(iter(dataloader), 5)) + batch_hashes = [] + for batch in batches: + # Hash the batch IDs to verify identity/difference + batch_ids = batch["id"].cpu().tolist() + batch_hash = hashlib.md5(str(batch_ids).encode()).hexdigest() + batch_hashes.append(batch_hash) + + # Gather hashes from all ranks for comparison + gathered_hashes = 
[None] * world_size + dist.all_gather_object(gathered_hashes, batch_hashes) + + if rank == 0: + # Verify TP replication within DP groups + # Ranks 0 and 1 should have identical hashes (same DP group) + assert gathered_hashes[0] == gathered_hashes[1], ( + f"Ranks 0 and 1 (same DP group) should see identical batches!\n" + f"Rank 0 hashes: {gathered_hashes[0][:3]}...\n" + f"Rank 1 hashes: {gathered_hashes[1][:3]}..." + ) + + # Ranks 2 and 3 should have identical hashes (same DP group) + assert gathered_hashes[2] == gathered_hashes[3], ( + f"Ranks 2 and 3 (same DP group) should see identical batches!\n" + f"Rank 2 hashes: {gathered_hashes[2][:3]}...\n" + f"Rank 3 hashes: {gathered_hashes[3][:3]}..." + ) + + # Verify DP sharding across groups + # Ranks 0/1 should see DIFFERENT batches from ranks 2/3 + assert gathered_hashes[0] != gathered_hashes[2], ( + f"Ranks 0 and 2 (different DP groups) should see different batches!\n" + f"DP group 0 hashes: {gathered_hashes[0][:3]}...\n" + f"DP group 1 hashes: {gathered_hashes[2][:3]}..." + ) + + dist.barrier() + + finally: + shutil.rmtree(temp_dir) diff --git a/tests/unit_tests/datasets/test_packed.py b/tests/unit_tests/datasets/test_packed.py index 56cd5ff02..1c6c4906f 100644 --- a/tests/unit_tests/datasets/test_packed.py +++ b/tests/unit_tests/datasets/test_packed.py @@ -949,3 +949,49 @@ def create_loader(): # Verify that checkpointing and resumption work assert len(result["post_checkpoint_batches"]) == steps_after_checkpoint assert len(result["resumed_batches"]) == steps_after_checkpoint + + def test_iter_restart_determinism(self, dataset_factory): + """Test that calling iter() multiple times produces deterministic results. + + This is critical for evaluation: each eval run should start from the + same state (epoch 0, step 0) regardless of previous iterations. + """ + samples = [ + {"tokens": [0] * 3}, + {"tokens": [1] * 2}, + {"tokens": [2] * 4}, + ] + target_tokens_per_pack = 6 + + # Create packed dataset + dataset = dataset_factory(samples) + packer = TextPacker(padding_idx=999, ignore_idx=-100) + packed_dataset = PackedDataset( + dataset=dataset, + packer=packer, + target_tokens_per_pack=target_tokens_per_pack, + buffer_size=1, + ) + + # First iteration - get first 2 packs + iter1 = iter(packed_dataset) + packs_iter1 = list(islice(iter1, 2)) + + # Second iteration - should get same first 2 packs + iter2 = iter(packed_dataset) + packs_iter2 = list(islice(iter2, 2)) + + # Verify both iterations produce identical packs + assert len(packs_iter1) == len(packs_iter2) == 2 + + for i, (pack1, pack2) in enumerate(zip(packs_iter1, packs_iter2)): + torch.testing.assert_close( + pack1["tokens"], + pack2["tokens"], + msg=f"Pack {i}: tokens mismatch between iterations", + ) + torch.testing.assert_close( + pack1["document_ids"], + pack2["document_ids"], + msg=f"Pack {i}: document_ids mismatch between iterations", + ) diff --git a/tests/unit_tests/datasets/test_stop_after_one_epoch.py b/tests/unit_tests/datasets/test_stop_after_one_epoch.py new file mode 100644 index 000000000..0b0399cda --- /dev/null +++ b/tests/unit_tests/datasets/test_stop_after_one_epoch.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Tests for StopAfterOneEpoch iterator and extract_epoch_from_batch helper.""" +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist +from forge.data.datasets import HfIterableDataset + +from forge.data.utils import extract_epoch_from_batch, StopAfterOneEpoch +from forge.observability.metrics import Metric, Reduce +from torch.testing._internal.common_fsdp import FSDPTest +from torchdata.stateful_dataloader import StatefulDataLoader + +from tests.test_utils import gpu_test + + +def create_test_json_file(path: Path, num_samples: int) -> None: + """Create test data file with simple samples.""" + with open(path, "w") as f: + for i in range(num_samples): + f.write(f'{{"id": {i}, "tokens": [{i}, {i+1}]}}\n') + + +class TestExtractEpochFromBatch: + """Test extract_epoch_from_batch helper function.""" + + def test_extract_epoch_from_batch_success(self): + """Test extracting epoch from valid batch with metrics.""" + batch = { + "tokens": torch.tensor([1, 2, 3]), + "metrics": [ + Metric(key="dataset/test/num_epochs", value=2, reduction=Reduce.MAX), + Metric( + key="dataset/test/other_metric", value=42, reduction=Reduce.MEAN + ), + ], + } + epoch = extract_epoch_from_batch(batch) + assert epoch == 2 + + def test_extract_epoch_missing_metrics_field(self): + """Test error when batch has no 'metrics' field.""" + batch = {"tokens": torch.tensor([1, 2, 3])} + with pytest.raises(ValueError, match="Batch missing 'metrics' field"): + extract_epoch_from_batch(batch) + + def test_extract_epoch_no_num_epochs_metric(self): + """Test error when no num_epochs metric found.""" + batch = { + "metrics": [ + Metric( + key="dataset/test/other_metric", value=42, reduction=Reduce.MEAN + ), + ] + } + with pytest.raises(ValueError, match="No 'num_epochs' metric found"): + extract_epoch_from_batch(batch) + + +class TestStopAfterOneEpochSingleProcess: + """Test StopAfterOneEpoch in single-process mode (no distributed).""" + + def test_stop_after_one_epoch(self, tmp_path): + """Verify iterator stops after exactly one epoch completes.""" + # Create small dataset (10 samples) + data_file = tmp_path / "data.json" + create_test_json_file(data_file, num_samples=10) + + dataset = HfIterableDataset( + path="json", + data_files=str(data_file), + split="train", + shuffle_buffer_size=0, + num_shards_per_rank=1, + ) + + dataloader = StatefulDataLoader(dataset, batch_size=2, collate_fn=lambda x: x) + + # Wrap with StopAfterOneEpoch + device = torch.device("cuda") + batch_iter = StopAfterOneEpoch( + iter(dataloader), device, "test_dataset", dp_process_group=None + ) + + # Collect all batches until StopIteration + batches = [] + for batch in batch_iter: + batches.append(batch) + # Verify all batches are from epoch 0 + epoch = extract_epoch_from_batch(batch) + assert epoch == 0, f"Expected epoch 0, got {epoch}" + + # Should have consumed exactly one epoch (5 batches of size 2) + assert len(batches) == 5 + + +class TestStopAfterOneEpochDistributed(FSDPTest): + """Test StopAfterOneEpoch with distributed synchronization.""" + + @property + def world_size(self) -> int: + return 2 + + @gpu_test(gpu_count=2) + def test_epoch_sync_across_ranks(self): + """Verify all ranks stop when any rank detects epoch change.""" + import shutil + import tempfile + + rank = dist.get_rank() + temp_dir = tempfile.mkdtemp(prefix=f"stop_epoch_test_rank{rank}_") + + try: + data_file = Path(temp_dir) / "data.json" + # Create dataset with 20 samples, split across 2 ranks (10 each) + create_test_json_file(data_file, 
num_samples=20) + + dataset = HfIterableDataset( + path="json", + data_files=str(data_file), + split="train", + shuffle_buffer_size=0, + num_shards_per_rank=1, + ) + + dataloader = StatefulDataLoader( + dataset, batch_size=2, collate_fn=lambda x: x + ) + + # Get DP process group (use global group for this test) + dp_process_group = dist.group.WORLD + + batch_iter = StopAfterOneEpoch( + iter(dataloader), + torch.device("cuda"), + f"test_rank{rank}", + dp_process_group, + ) + + # Collect batches + batches = [] + for batch in batch_iter: + batches.append(batch) + # All should be epoch 0 + assert extract_epoch_from_batch(batch) == 0 + + # All ranks should have processed exactly one epoch + # Since dataset is split across ranks, each rank gets 10 samples = 5 batches + assert ( + len(batches) == 5 + ), f"Rank {rank} expected 5 batches, got {len(batches)}" + + # Synchronize to ensure both ranks completed + dist.barrier() + + finally: + shutil.rmtree(temp_dir)
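
Below is a minimal, self-contained sketch of how the new StopAfterOneEpoch wrapper and the extract_epoch_from_batch helper behave outside the trainer, in single-process mode. The make_batches generator is a hypothetical stand-in for real data; actual usage wraps a StatefulDataLoader over an HfIterableDataset, as setup_data() does in apps/sft/main.py.

    import itertools

    import torch

    from forge.data.utils import extract_epoch_from_batch, StopAfterOneEpoch
    from forge.observability.metrics import Metric, Reduce


    def make_batches():
        """Hypothetical infinite batch stream: 5 batches per epoch, tagged with num_epochs."""
        for step in itertools.count():
            yield {
                "tokens": torch.tensor([step]),
                "metrics": [
                    Metric(
                        key="dataset/demo/num_epochs",
                        value=step // 5,
                        reduction=Reduce.MAX,
                    )
                ],
            }


    batch_iter = StopAfterOneEpoch(
        make_batches(),         # prefetches the first batch immediately
        torch.device("cpu"),    # device for the epoch-sync tensor (unused when not distributed)
        "demo",                 # dataset name, used only for logging
        dp_process_group=None,  # single process: no async all_reduce is issued
    )

    epochs_seen = [extract_epoch_from_batch(b) for b in batch_iter]
    assert epochs_seen == [0, 0, 0, 0, 0]  # iteration stops once epoch 1 is detected

A note on the config plumbing added in setup(): eval_every_n_steps set to null (or <= 0) disables evaluation entirely, and max_eval_steps set to null or 0 removes the per-dataset cap, so each eval runs until the epoch boundary that StopAfterOneEpoch detects. The prefetch-plus-async-all_reduce design lets the cross-rank epoch check overlap with the caller's processing of the current batch instead of blocking on every step.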