@@ -35,21 +35,30 @@ def extract_epoch_from_batch(batch: dict) -> int | None:
         Epoch number from metrics, or None if not found
     """
     if "metrics" in batch:
+        # Look for num_epochs in metric keys
+        for metric in batch["metrics"]:
+            # Metrics have a 'key' attribute with paths like:
+            # 'dataset/yahma_alpaca-cleaned_train[:1%]/num_epochs'
+            if hasattr(metric, "key") and "num_epochs" in metric.key:
+                return int(metric.value)
+
+        # Fallback: check for old-style metric_name attribute
         for metric in batch["metrics"]:
             if hasattr(metric, "metric_name") and metric.metric_name == "num_epochs":
-                return metric.value
+                return int(metric.value)
+
     return None


 def start_epoch_sync(
-    epoch_increment: int,
+    epoch_changed: bool,
     device: torch.device,
     dp_process_group: Any = None,
 ) -> tuple[torch.Tensor | None, Any]:
     """Start async all_reduce for epoch synchronization across ranks.

     Args:
-        epoch_increment: Difference between current and starting epoch
+        epoch_changed: Whether the epoch changed on this rank
         device: Device for tensor
         dp_process_group: Data parallel process group (None = default group)

@@ -59,7 +68,8 @@ def start_epoch_sync(
     if not torch.distributed.is_initialized():
         return None, None

-    epoch_tensor = torch.tensor([epoch_increment], dtype=torch.long, device=device)
+    # Convert bool to tensor: 1 if epoch changed, 0 otherwise
+    epoch_tensor = torch.tensor([int(epoch_changed)], dtype=torch.long, device=device)
     pending_work = torch.distributed.all_reduce(
         epoch_tensor,
         op=torch.distributed.ReduceOp.MAX,
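
For context on the collective used above: converting the per-rank boolean to a 0/1 tensor and reducing with ReduceOp.MAX yields 1 if any rank observed an epoch change. Below is a minimal synchronous sketch of that idea; the helper name and setup are illustrative only, and the real start_epoch_sync issues the all_reduce asynchronously and returns the pending work handle so it can overlap with batch processing.

```python
import torch
import torch.distributed as dist


def any_rank_epoch_changed(epoch_changed: bool, device: torch.device) -> bool:
    # Each rank contributes 0 or 1; MAX across ranks is 1 if any rank saw a change.
    flag = torch.tensor([int(epoch_changed)], dtype=torch.long, device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())
```
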
@@ -117,7 +127,7 @@ def eval_loop(
         Tuple of (avg_loss, num_batches)
     """
     total_loss = torch.tensor(0.0, device=device)
-    num_batches, starting_epoch = 0, None
+    num_batches = 0

     # Prefetch first batch
     next_batch = next(dataloader_iter)
@@ -142,26 +152,41 @@ def eval_loop(

         batch = next_batch

-        # Track starting epoch
+        # Get current batch epoch
         current_epoch = extract_epoch_fn(batch)
-        if starting_epoch is None:
-            starting_epoch = current_epoch

-        # Prefetch next batch and start async epoch check
+        # Prefetch next batch and check for epoch change
         try:
             next_batch = next(dataloader_iter)
             next_epoch = extract_epoch_fn(next_batch)

-            # Only check epochs if both are available
-            if next_epoch is not None and starting_epoch is not None:
-                epoch_increment = next_epoch - starting_epoch
+            # Simple check: did epoch change between consecutive batches?
+            if next_epoch is not None and current_epoch is not None:
+                epoch_changed = next_epoch > current_epoch
+
+                if epoch_changed:
+                    logger.info(
+                        f"[{dataset_name}] Epoch change detected: "
+                        f"{current_epoch} → {next_epoch}"
+                    )
+
                 if torch.distributed.is_initialized():
+                    # All-reduce: if ANY rank's epoch changed, all ranks should stop
                     epoch_tensor, pending_work = start_epoch_sync(
-                        epoch_increment, device, dp_process_group
+                        epoch_changed, device, dp_process_group
                     )
                 else:
-                    should_break = epoch_increment > 0
+                    # Single process: stop immediately if epoch changed
+                    should_break = epoch_changed
+            else:
+                # No epoch tracking available - rely on eval_steps
+                if num_batches == 0:
+                    logger.info(
+                        f"[{dataset_name}] No epoch tracking available "
+                        f"(current={current_epoch}, next={next_epoch})"
+                    )
         except StopIteration:
+            logger.info(f"[{dataset_name}] StopIteration - dataloader exhausted")
             should_break = True

         # Process current batch (overlaps with async all_reduce)
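
Taken together, the diff replaces the old starting-epoch bookkeeping with a comparison between consecutive batches. A self-contained sketch of that check, assuming a hypothetical MetricStub carrying the key/value attributes the extractor looks for (the toy batches below are illustrative, not from the codebase):

```python
from dataclasses import dataclass


@dataclass
class MetricStub:
    key: str
    value: float


def extract_epoch_from_batch(batch: dict) -> int | None:
    if "metrics" in batch:
        for metric in batch["metrics"]:
            if hasattr(metric, "key") and "num_epochs" in metric.key:
                return int(metric.value)
    return None


batches = [
    {"metrics": [MetricStub("dataset/toy_train/num_epochs", 0)]},
    {"metrics": [MetricStub("dataset/toy_train/num_epochs", 1)]},
]
current_epoch = extract_epoch_from_batch(batches[0])
next_epoch = extract_epoch_from_batch(batches[1])
epoch_changed = (
    current_epoch is not None and next_epoch is not None and next_epoch > current_epoch
)
print(epoch_changed)  # True -> the eval loop would stop after the current batch
```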