1111"""
1212
1313import asyncio
14-
14+ import contextlib
1515import logging
1616import math
1717import os
2727from forge .data .datasets .packed import PackedDataset , TextPacker
2828from forge .data .datasets .sft_dataset import AlpacaToMessages , sft_iterable_dataset
2929from forge .data .tokenizer import HuggingFaceModelTokenizer
30+ from forge .data .utils import StopAfterOneEpoch
3031from forge .observability import get_or_create_metric_logger , record_metric , Reduce
3132from forge .util .config import parse
3233
@@ -96,28 +97,72 @@ def record_batch_metrics(self, data_metrics: list):
 
     @endpoint
     async def setup(self):
-        self.train_dataloader = self.setup_data()
+
+        # All ranks record loss, except when PP is enabled; then only the last stage records it.
+        self.rank_should_record_loss = True
+        if hasattr(self, "pp_has_last_stage") and not self.pp_has_last_stage:
+            self.rank_should_record_loss = False
+
+        # metric logger
         self.mlogger = await self.setup_metric_logger()
 
-        # self.train_dataloader = self.setup_data(
-        #     self.train_config.train_dataset_config,
-        #     self.train_config.train_dataloader_config,
-        #     self.train_config.packing_config,
-        # )
-        # self.val_dataloader = self.setup_data(
-        #     self.train_config.val_dataset_config,
-        #     self.train_config.val_dataloader_config,
-        #     self.train_config.packing_config,
-        # )
+        # Load training datasets
+        logger.info("Setting up training datasets")
+        train_datasets_config = self.job_config.training.datasets
+        self.train_dataloader = self.setup_data(train_datasets_config)
+
+        # Load eval datasets
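+        # Assumed shape of the eval config (values are illustrative; only the keys read
+        # below are relied on):
+        #   eval:
+        #     eval_every_n_steps: 500   # None or <= 0 disables validation
+        #     max_eval_steps: 0         # None or <= 0 means "run until the epoch ends"
+        #     datasets:
+        #       - dataset_name: ...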
+        eval_config = self.job_config["eval"]
+        self.val_dataloaders = {}
+        self.eval_every_n_steps = eval_config["eval_every_n_steps"]
+        max_eval_steps = eval_config["max_eval_steps"]
+        self.max_eval_steps = (
+            max_eval_steps if max_eval_steps and max_eval_steps > 0 else None
+        )
+        self.validation_enabled = (
+            self.eval_every_n_steps is not None and self.eval_every_n_steps > 0
+        )
+        if self.validation_enabled:
+            logger.info("Setting up eval datasets")
+            self.eval_datasets_config = eval_config.datasets
+
+            for i, dataset_config in enumerate(self.eval_datasets_config):
+                ds_name = dataset_config.get("dataset_name", i)
+
+                # TODO: Support a separate eval batch size from config (eval.local_batch_size)
+                dataloader = self.setup_data([dataset_config])
+                self.val_dataloaders[ds_name] = dataloader
 
         # TODO: confirm that this is working properly
         # Should also use load, not dcp_load
         self.checkpointer.load(step=self.current_step)
+
         # self.profiler = self.setup_profiler(self.train_config.profiler_config)
         # self.logger = self.setup_logger(self.train_config.logger_config)
 
-    def setup_data(self):
-        print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
+    def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader:
+        """Instantiates datasets and returns a StatefulDataLoader.
+
+        Args:
+            dataset_configs (list[dict]): List of dataset config dicts; each is passed as
+                `sft_iterable_dataset(**dataset_configs[i])`.
+
+        Returns:
+            StatefulDataLoader
+
+        Raises:
+            ValueError: If multiple datasets are provided (not yet supported).
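+
+        Example (illustrative; the keys are forwarded verbatim to `sft_iterable_dataset`,
+        so the exact schema depends on that function's signature):
+            >>> self.setup_data([{"path": "yahma/alpaca-cleaned", "split": "train"}])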
154+ """
155+ # TODO felipemello: Currently only support single dataset
156+ if len (dataset_configs ) > 1 :
157+ raise ValueError (
158+ f"Multiple training datasets not supported yet. "
159+ f"Got { len (dataset_configs )} datasets. "
160+ )
161+
162+ dataset_config = dataset_configs [0 ]
163+
164+ # TODO: Evaluate if tokenizers should be created once and shared for every dataset
165+ # Load tokenizer
         tokenizer = HuggingFaceModelTokenizer(
             tokenizer_json_path=os.path.join(
                 self.job_config.model.hf_assets_path, "tokenizer.json"
@@ -139,18 +184,26 @@ def setup_data(self):
             ),
         )
 
+        # Get DP mesh for data sharding
+        dp_mesh = None
+        if self.parallel_dims is not None and self.parallel_dims.dp_enabled:
+            dp_mesh = self.parallel_dims.world_mesh.get_group("dp")
+
+        # Pass config directly to dataset constructor
         dataset = sft_iterable_dataset(
             model_transform=tokenizer,
             message_transform=AlpacaToMessages(),
-            path="yahma/alpaca-cleaned",
-            split="train",
+            dp_mesh=dp_mesh,
+            **dataset_config,
         )
+
         packer = TextPacker(padding_idx=0)
         dataset = PackedDataset(
             dataset=dataset,
             packer=packer,
             target_tokens_per_pack=self.job_config.training.seq_len,  # TODO: get this from model
         )
+
         dataloader = StatefulDataLoader(
             dataset=dataset,
             batch_size=self.job_config.training.local_batch_size,
@@ -166,7 +219,10 @@ def setup_data(self):
         return dataloader
 
     def forward_backward(
-        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+        self,
+        input_dict: dict[str, torch.Tensor],
+        labels: torch.Tensor,
+        skip_backward: bool = False,
     ) -> torch.Tensor:
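+        # skip_backward=True is used by evaluate(): the non-PP branch skips loss.backward(),
+        # while the PP branch (which cannot skip backward) detaches the returned loss instead.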
         model_parts = self.model_parts
         parallel_dims = self.parallel_dims
@@ -193,21 +249,22 @@ def forward_backward(
                 (labels, []) if self.pp_has_last_stage else (None, None)
             )
             if self.pp_has_first_stage:
-                self.pp_schedule.step(
-                    inputs, target=targets, losses=losses, input_batch=inputs
-                )
+                self.pp_schedule.step(inputs, target=targets, losses=losses)
             else:
-                self.pp_schedule.step(
-                    target=targets, losses=losses, input_batch=inputs
-                )
+                self.pp_schedule.step(target=targets, losses=losses)
 
             # accumulate losses across pipeline microbatches
             # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
             loss = (
-                torch.mean(torch.stack(losses)).to(self.device)
+                torch.sum(torch.stack(losses)).to(self.device)
                 if self.pp_has_last_stage
-                else torch.tensor([-1.0], device=self.device)
+                else torch.tensor(-1.0, device=self.device)
             )
+
+            # TODO: PP requires gradients enabled and can't be disabled with no_grad
+            if skip_backward:
+                loss = loss.detach()
+
         else:
             # Non-PP forward / backward
             with self.train_context(optional_context_parallel_ctx):
@@ -217,7 +274,10 @@ def forward_backward(
                 loss = self.loss_fn(pred, labels)
                 # need to free pred before bwd to avoid peak memory
                 del pred
-                loss.backward()
+
+                # Only run backward if requested. Useful for eval.
+                if not skip_backward:
+                    loss.backward()
 
         return loss
 
@@ -230,15 +290,138 @@ def train_step(self, batch) -> None:
         # ) as grad_acc:
         labels = batch.pop("labels")
         loss = self.forward_backward(batch, labels)
-        loss = loss.item()
 
-        record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN)
-        logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
+        if self.rank_should_record_loss:
+            loss_val = loss.item()
+            record_metric("ForgeSFTRecipe/train_step/loss", loss_val, Reduce.MEAN)
+            logger.info(
+                f"step {self.current_step} / {self.num_training_steps} | Loss: {loss_val}"
+            )
+
         # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
         # self.pbar.update(1)
         self.optimizers.step()
         self.lr_schedulers.step()
 
+    async def evaluate(self) -> None:
+        """Run evaluation on multiple datasets, one at a time.
+
+        1. Set models to eval mode
+        2. For each eval dataset:
+            - Create a fresh iterator (starts from epoch 0)
+            - Use StopAfterOneEpoch to iterate until the epoch boundary. This utility
+              is necessary for infinite iterable datasets, since epoch boundaries are not known in advance.
+            - Respect the max_eval_steps cap if configured
+            - Record loss and step metrics (on dp rank only)
+        3. Restore models to train mode
+        """
+
+        # Set models to eval mode
+        for model_part in self.model_parts:
+            model_part.eval()
+
+        # Get DP process group for epoch synchronization
+        dp_mesh = None
+        if self.parallel_dims is not None and self.parallel_dims.dp_enabled:
+            dp_mesh = self.parallel_dims.world_mesh.get_group("dp")
+
+        # For non-PP: disable gradients to save memory
+        # TODO: For PP, disabling gradients raises an error
+        maybe_no_grad = (
+            contextlib.nullcontext()
+            if self.parallel_dims.pp_enabled
+            else torch.no_grad()
+        )
+
+        # Evaluate each dataset sequentially
+        all_dataset_losses = []
+        all_dataset_steps = []
+        for dataset_name, val_dataloader in self.val_dataloaders.items():
+            logger.info(f"=====Evaluating dataset: {dataset_name}=====")
+
+            # Evaluation loop for this dataset
+            total_loss = torch.tensor(0.0, device=self.device)
+            num_steps = 0
+
+            # NOTE: Assumes each batch has a "metrics" field containing "num_epochs"
+            batch_iter = StopAfterOneEpoch(
+                iter=iter(val_dataloader),  # fresh iterator, starts from epoch 0
+                device=self.device,
+                dp_mesh=dp_mesh,
+            )
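+            # StopAfterOneEpoch (from forge.data.utils) is assumed to yield batches until the
+            # wrapped infinite iterator crosses an epoch boundary, with the stop decision
+            # synchronized across dp_mesh so all data-parallel ranks stop together.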
+
+            with maybe_no_grad:
+                for batch in batch_iter:
+                    # If max_eval_steps > len(dataset), StopAfterOneEpoch stops the loop first.
+                    if (
+                        self.max_eval_steps is not None
+                        and num_steps >= self.max_eval_steps
+                    ):
+                        logger.info(
+                            f"[{dataset_name}] Reached max_eval_steps cap of {self.max_eval_steps}"
+                        )
+                        break
+
+                    # Move tensors to device
+                    for key, value in batch.items():
+                        if isinstance(value, torch.Tensor):
+                            batch[key] = value.to(self.device)
+
+                    # Process batch
+                    labels = batch.pop("labels")
+                    loss = self.forward_backward(batch, labels, skip_backward=True)
+                    total_loss += loss
+                    num_steps += 1
+
+                    # Log progress
+                    if self.rank_should_record_loss:
+                        loss_val = loss.item()
+                        logger.info(
+                            f"[dataset {dataset_name}] Step {num_steps} | Loss: {loss_val:.4f}"
+                        )
+
+            # Record per-dataset average loss
+            avg_loss = (total_loss / max(num_steps, 1)).item()
+            all_dataset_losses.append(avg_loss)
+            all_dataset_steps.append(num_steps)
+            logger.info(
+                f"[dataset {dataset_name}] Final Step {num_steps} | Avg Loss: {avg_loss:.4f}"
+            )
+            if self.rank_should_record_loss:
+                record_metric(
+                    f"evaluate/dataset_{dataset_name}_avg_loss",
+                    avg_loss,
+                    Reduce.MEAN,
+                )
+
+        # Record macro and micro average losses across datasets (only if multiple datasets)
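+        # Worked example (illustrative numbers): per-dataset losses [2.0, 4.0] over [1, 3]
+        # eval steps give macro = (2.0 + 4.0) / 2 = 3.0 and micro = (2.0*1 + 4.0*3) / 4 = 3.5.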
+        if self.rank_should_record_loss and len(all_dataset_losses) > 1:
+            # Macro: same weight for all datasets
+            macro_avg_loss = sum(all_dataset_losses) / len(all_dataset_losses)
+            record_metric("evaluate/macro_avg_loss", macro_avg_loss, Reduce.MEAN)
+
+            # Micro: weighted mean by dataset size
+            total_steps = sum(all_dataset_steps)
+            micro_avg_loss = (
+                sum(
+                    loss * steps
+                    for loss, steps in zip(all_dataset_losses, all_dataset_steps)
+                )
+                / total_steps
+            )
+            record_metric("evaluate/micro_avg_loss", micro_avg_loss, Reduce.MEAN)
+
+            logger.info(
+                f"Macro avg loss (unweighted): {macro_avg_loss:.4f}, "
+                f"Micro avg loss (weighted): {micro_avg_loss:.4f}"
+            )
+
+        # Restore train mode
+        for model_part in self.model_parts:
+            model_part.train()
+
+        logger.info("==Evaluation complete==")
+
     @endpoint
     async def train(self) -> None:
         dataloader = iter(self.train_dataloader)
@@ -263,18 +446,28 @@ async def train(self) -> None:
             # self.profiler.step()
             self.current_step += 1
 
-            # Flush metrics
-            if self._rank == 0:
-                logger.debug(f"Flushing metrics at step {self.current_step}")
-                await self.mlogger.flush.call_one(global_step=self.current_step)
+            # Run evaluation periodically if enabled
+            if (
+                self.validation_enabled
+                and self.current_step % self.eval_every_n_steps == 0
+            ):
+                await self.evaluate()
 
             self.checkpointer.save(
                 curr_step=self.current_step,
                 last_step=self.current_step == self.num_training_steps,
             )
 
+            # Flush metrics
+            if self._rank == 0:
+                await self.mlogger.flush.call_one(global_step=self.current_step)
+
         # self.pbar.close()
 
+        if self.validation_enabled:
+            logger.info("Running final evaluation at end of training...")
+            await self.evaluate()
+
     @endpoint
     async def cleanup(self) -> None:
         if self.checkpointer: