Commit 52bbde6

Authored by felipemello1, Hossein Kavianihamedani, and Felipe Mello
[SFT Eval ] Add eval to SFT script (#536)
Co-authored-by: Hossein Kavianihamedani <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
1 parent 5bfcfae commit 52bbde6

File tree: 10 files changed (+816, −54 lines)

apps/sft/llama3_8b.yaml

Lines changed: 11 additions & 1 deletion

@@ -32,7 +32,16 @@ training:
   max_norm: 1.0
   steps: 1000
   compile: false
-  dataset: "c4"
+  datasets:
+    - path: "yahma/alpaca-cleaned"
+      split: "train[:95%]"
+
+eval:
+  eval_every_n_steps: 50  # null = disabled
+  max_eval_steps: null  # null = run until epoch completes
+  datasets:
+    - path: "yahma/alpaca-cleaned"
+      split: "train[95%:]"

 parallelism:
   data_parallel_replicate_degree: 1

@@ -62,6 +71,7 @@ metric_logging:
   group: sft_exp_${oc.env:USER}
   logging_mode: global_reduce  # global_reduce, per_rank_reduce, per_rank_no_reduce

+
 # profiling:
 #   enable_profiling: false

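Note on the new config: the split strings above use the Hugging Face datasets slicing syntax, so the same Alpaca dataset is partitioned into a 95% training slice and a disjoint 5% eval slice. A minimal standalone sketch of that partitioning (assumes the `datasets` library; not part of this commit):

    # Sketch only: shows how "train[:95%]" and "train[95%:]" split one dataset.
    from datasets import load_dataset

    train_split = load_dataset("yahma/alpaca-cleaned", split="train[:95%]")
    eval_split = load_dataset("yahma/alpaca-cleaned", split="train[95%:]")

    # The two slices are disjoint and together cover the full "train" split.
    print(len(train_split), len(eval_split))
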
apps/sft/main.py

Lines changed: 226 additions & 33 deletions

@@ -11,7 +11,7 @@
 """

 import asyncio
-
+import contextlib
 import logging
 import math
 import os
@@ -27,6 +27,7 @@
 from forge.data.datasets.packed import PackedDataset, TextPacker
 from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
 from forge.data.tokenizer import HuggingFaceModelTokenizer
+from forge.data.utils import StopAfterOneEpoch
 from forge.observability import get_or_create_metric_logger, record_metric, Reduce
 from forge.util.config import parse

@@ -96,28 +97,72 @@ def record_batch_metrics(self, data_metrics: list):

     @endpoint
     async def setup(self):
-        self.train_dataloader = self.setup_data()
+
+        # all ranks should record loss, except when PP=True. Then, only the last stage should record loss.
+        self.rank_should_record_loss = True
+        if hasattr(self, "pp_has_last_stage") and not self.pp_has_last_stage:
+            self.rank_should_record_loss = False
+
+        # metric logger
         self.mlogger = await self.setup_metric_logger()

-        # self.train_dataloader = self.setup_data(
-        #     self.train_config.train_dataset_config,
-        #     self.train_config.train_dataloader_config,
-        #     self.train_config.packing_config,
-        # )
-        # self.val_dataloader = self.setup_data(
-        #     self.train_config.val_dataset_config,
-        #     self.train_config.val_dataloader_config,
-        #     self.train_config.packing_config,
-        # )
+        # Load training datasets
+        logger.info("Setting training datasets")
+        train_datasets_config = self.job_config.training.datasets
+        self.train_dataloader = self.setup_data(train_datasets_config)
+
+        # Load eval datasets
+        eval_config = self.job_config["eval"]
+        self.val_dataloaders = {}
+        self.eval_every_n_steps = eval_config["eval_every_n_steps"]
+        max_eval_steps = eval_config["max_eval_steps"]
+        self.max_eval_steps = (
+            max_eval_steps if max_eval_steps and max_eval_steps > 0 else None
+        )
+        self.validation_enabled = (
+            self.eval_every_n_steps is not None and self.eval_every_n_steps > 0
+        )
+        if self.validation_enabled:
+            logger.info("Setting eval datasets")
+            self.eval_datasets_config = eval_config.datasets
+
+            for i, dataset_config in enumerate(self.eval_datasets_config):
+                ds_name = dataset_config.get("dataset_name", i)
+
+                # TODO: Support separate eval batch size from config (eval.local_batch_size)
+                dataloader = self.setup_data([dataset_config])
+                self.val_dataloaders[ds_name] = dataloader

         # TODO: confirm that this is working properly
         # Should also use load, not dcp_load
         self.checkpointer.load(step=self.current_step)
+
         # self.profiler = self.setup_profiler(self.train_config.profiler_config)
         # self.logger = self.setup_logger(self.train_config.logger_config)

-    def setup_data(self):
-        print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
+    def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader:
+        """Instantiates datasets and returns a StatefulDataLoader.
+
+        Args:
+            dataset_configs (list[dict]): List of dataset config dicts used as `sft_iterable_dataset(**dataset_configs[i])`.
+
+        Returns:
+            StatefulDataLoader
+
+        Raises:
+            ValueError: If multiple datasets provided (not yet supported)
+        """
+        # TODO felipemello: Currently only support single dataset
+        if len(dataset_configs) > 1:
+            raise ValueError(
+                f"Multiple training datasets not supported yet. "
+                f"Got {len(dataset_configs)} datasets. "
+            )
+
+        dataset_config = dataset_configs[0]
+
+        # TODO: Evaluate if tokenizers should be created once and shared for every dataset
+        # Load tokenizer
         tokenizer = HuggingFaceModelTokenizer(
             tokenizer_json_path=os.path.join(
                 self.job_config.model.hf_assets_path, "tokenizer.json"
@@ -139,18 +184,26 @@ def setup_data(self):
             ),
         )

+        # Get DP mesh for data sharding
+        dp_mesh = None
+        if self.parallel_dims is not None and self.parallel_dims.dp_enabled:
+            dp_mesh = self.parallel_dims.world_mesh.get_group("dp")
+
+        # Pass config directly to dataset constructor
         dataset = sft_iterable_dataset(
             model_transform=tokenizer,
             message_transform=AlpacaToMessages(),
-            path="yahma/alpaca-cleaned",
-            split="train",
+            dp_mesh=dp_mesh,
+            **dataset_config,
         )
+
         packer = TextPacker(padding_idx=0)
         dataset = PackedDataset(
             dataset=dataset,
             packer=packer,
             target_tokens_per_pack=self.job_config.training.seq_len,  # TODO: get this from model
         )
+
         dataloader = StatefulDataLoader(
             dataset=dataset,
             batch_size=self.job_config.training.local_batch_size,
@@ -166,7 +219,10 @@ def setup_data(self):
         return dataloader

     def forward_backward(
-        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+        self,
+        input_dict: dict[str, torch.Tensor],
+        labels: torch.Tensor,
+        skip_backward: bool = False,
     ) -> torch.Tensor:
         model_parts = self.model_parts
         parallel_dims = self.parallel_dims
@@ -193,21 +249,22 @@ def forward_backward(
                 (labels, []) if self.pp_has_last_stage else (None, None)
             )
             if self.pp_has_first_stage:
-                self.pp_schedule.step(
-                    inputs, target=targets, losses=losses, input_batch=inputs
-                )
+                self.pp_schedule.step(inputs, target=targets, losses=losses)
             else:
-                self.pp_schedule.step(
-                    target=targets, losses=losses, input_batch=inputs
-                )
+                self.pp_schedule.step(target=targets, losses=losses)

             # accumulate losses across pipeline microbatches
             # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
             loss = (
-                torch.mean(torch.stack(losses)).to(self.device)
+                torch.sum(torch.stack(losses)).to(self.device)
                 if self.pp_has_last_stage
-                else torch.tensor([-1.0], device=self.device)
+                else torch.tensor(-1.0, device=self.device)
             )
+
+            # TODO: PP requires gradients enabled and cant deactive with no_grad
+            if skip_backward:
+                loss = loss.detach()
+
         else:
             # Non-PP forward / backward
             with self.train_context(optional_context_parallel_ctx):
@@ -217,7 +274,10 @@ def forward_backward(
                 loss = self.loss_fn(pred, labels)
                 # need to free to before bwd to avoid peaking memory
                 del pred
-                loss.backward()
+
+                # Only run backward if requested. Useful for eval.
+                if not skip_backward:
+                    loss.backward()

         return loss

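The new skip_backward flag above lets evaluation reuse the same forward_backward path: in the non-PP branch the backward call is simply skipped, while the PP branch keeps autograd enabled (the schedule requires it) and detaches the returned loss instead. Condensed calling pattern, as used later in this file:

    # Training step: forward + backward (default).
    loss = self.forward_backward(batch, labels)

    # Eval step: forward only; no .backward(), and the PP branch detaches the loss.
    eval_loss = self.forward_backward(batch, labels, skip_backward=True)
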
@@ -230,15 +290,138 @@ def train_step(self, batch) -> None:
         # ) as grad_acc:
         labels = batch.pop("labels")
         loss = self.forward_backward(batch, labels)
-        loss = loss.item()

-        record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN)
-        logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
+        if self.rank_should_record_loss:
+            loss_val = loss.item()
+            record_metric("ForgeSFTRecipe/train_step/loss", loss_val, Reduce.MEAN)
+            logger.info(
+                f"step {self.current_step} / {self.num_training_steps} | Loss: {loss_val}"
+            )
+
         # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
         # self.pbar.update(1)
         self.optimizers.step()
         self.lr_schedulers.step()

+    async def evaluate(self) -> None:
+        """Run evaluation on multiple datasets, one at a time.
+
+        1. Set models to eval mode
+        2. For each eval dataset:
+            - Create fresh iterator (starts from epoch 0)
+            - Use StopAfterOneEpoch to iterate until epoch boundary. This utility
+              is necessary for infinite iterable dataset, since epoch boundaries are not known.
+            - Respect max_eval_steps cap if configured
+            - Record loss and step metrics (on dp rank only)
+        3. Restore models to train mode
+        """
+
+        # Set models to eval mode
+        for model_part in self.model_parts:
+            model_part.eval()
+
+        # Get DP process group for epoch synchronization
+        dp_mesh = None
+        if self.parallel_dims is not None and self.parallel_dims.dp_enabled:
+            dp_mesh = self.parallel_dims.world_mesh.get_group("dp")
+
+        # For non-PP: disable gradients to save memory
+        # TODO: For PP, if disabling gradients, throws error
+        maybe_no_grad = (
+            contextlib.nullcontext()
+            if self.parallel_dims.pp_enabled
+            else torch.no_grad()
+        )
+
+        # Evaluate each dataset sequentially
+        all_dataset_losses = []
+        all_dataset_steps = []
+        for dataset_name, val_dataloader in self.val_dataloaders.items():
+            logger.info(f"=====Evaluating dataset: {dataset_name}=====")
+
+            # Evaluation loop for this dataset
+            total_loss = torch.tensor(0.0, device=self.device)
+            num_steps = 0
+
+            # NOTE: Assumes batch contains field "metrics" containing "num_epochs"
+            batch_iter = StopAfterOneEpoch(
+                iter=iter(val_dataloader),  # Fresh iterator from epoch 0,
+                device=self.device,
+                dp_mesh=dp_mesh,
+            )
+
+            with maybe_no_grad:
+                for batch in batch_iter:
+                    # if max_eval_steps>len(dataset), it will be stopped earlier by StopAfterOneEpoch.
+                    if (
+                        self.max_eval_steps is not None
+                        and num_steps >= self.max_eval_steps
+                    ):
+                        logger.info(
+                            f"[{dataset_name}] Reached max_eval_steps cap of {self.max_eval_steps}"
+                        )
+                        break
+
+                    # Move tensors to device
+                    for key, value in batch.items():
+                        if isinstance(value, torch.Tensor):
+                            batch[key] = value.to(self.device)
+
+                    # Process batch
+                    labels = batch.pop("labels")
+                    loss = self.forward_backward(batch, labels, skip_backward=True)
+                    total_loss += loss
+                    num_steps += 1
+
+                    # Log progress
+                    if self.rank_should_record_loss:
+                        loss_val = loss.item()
+                        logger.info(
+                            f"[dataset {dataset_name}] Step {num_steps} | Loss: {loss_val:.4f}"
+                        )
+
+            # log loss
+            avg_loss = (total_loss / max(num_steps, 1)).item()
+            all_dataset_losses.append(avg_loss)
+            all_dataset_steps.append(num_steps)
+            logger.info(
+                f"[dataset {dataset_name}] Final Step {num_steps} | Avg Loss: {avg_loss:.4f}"
+            )
+            if self.rank_should_record_loss:
+                record_metric(
+                    f"evaluate/dataset_{dataset_name}_avg_loss",
+                    avg_loss,
+                    Reduce.MEAN,
+                )
+
+        # Record macro and micro average losses across datasets (only if multiple datasets)
+        if self.rank_should_record_loss and len(all_dataset_losses) > 1:
+            # Macro: same weight for all datasets
+            macro_avg_loss = sum(all_dataset_losses) / len(all_dataset_losses)
+            record_metric("evaluate/macro_avg_loss", macro_avg_loss, Reduce.MEAN)
+
+            # Micro: weighted mean by dataset size
+            total_steps = sum(all_dataset_steps)
+            micro_avg_loss = (
+                sum(
+                    loss * steps
+                    for loss, steps in zip(all_dataset_losses, all_dataset_steps)
+                )
+                / total_steps
+            )
+            record_metric("evaluate/micro_avg_loss", micro_avg_loss, Reduce.MEAN)
+
+            logger.info(
+                f"Macro avg loss (unweighted): {macro_avg_loss:.4f}, "
+                f"Micro avg loss (weighted): {micro_avg_loss:.4f}"
+            )
+
+        # Restore train mode
+        for model_part in self.model_parts:
+            model_part.train()
+
+        logger.info("==Evaluation complete==")
+
     @endpoint
     async def train(self) -> None:
         dataloader = iter(self.train_dataloader)
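To make the macro vs. micro averaging in evaluate() concrete: the macro average weights every dataset equally, while the micro average weights each dataset by its number of eval steps. A small numeric sketch with hypothetical losses and step counts, mirroring the formulas above:

    # Hypothetical numbers; same formulas as evaluate().
    all_dataset_losses = [2.0, 1.0]
    all_dataset_steps = [100, 300]

    macro_avg_loss = sum(all_dataset_losses) / len(all_dataset_losses)
    # (2.0 + 1.0) / 2 = 1.5

    micro_avg_loss = sum(
        loss * steps for loss, steps in zip(all_dataset_losses, all_dataset_steps)
    ) / sum(all_dataset_steps)
    # (2.0 * 100 + 1.0 * 300) / 400 = 1.25
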
@@ -263,18 +446,28 @@ async def train(self) -> None:
             # self.profiler.step()
             self.current_step += 1

-            # Flush metrics
-            if self._rank == 0:
-                logger.debug(f"Flushing metrics at step {self.current_step}")
-                await self.mlogger.flush.call_one(global_step=self.current_step)
+            # Run evaluation periodically if enabled
+            if (
+                self.validation_enabled
+                and self.current_step % self.eval_every_n_steps == 0
+            ):
+                await self.evaluate()

             self.checkpointer.save(
                 curr_step=self.current_step,
                 last_step=self.current_step == self.num_training_steps,
             )

+            # Flush metrics
+            if self._rank == 0:
+                await self.mlogger.flush.call_one(global_step=self.current_step)
+
         # self.pbar.close()

+        if self.validation_enabled:
+            logger.info("Running final evaluation at end of training...")
+            await self.evaluate()
+
     @endpoint
     async def cleanup(self) -> None:
         if self.checkpointer:
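StopAfterOneEpoch is imported from forge.data.utils, but its implementation is not part of this diff. From how evaluate() uses it (it wraps an infinite dataloader iterator, relies on a "num_epochs" entry in each batch's "metrics" field, and takes a dp_mesh so that all data-parallel ranks stop together), the idea can be sketched roughly as below. This is an illustration under those assumptions, not the forge implementation:

    # Rough sketch; the real forge.data.utils.StopAfterOneEpoch may differ.
    import torch
    import torch.distributed as dist


    class StopAfterOneEpochSketch:
        """Yield batches until the epoch counter advances on any DP rank."""

        def __init__(self, iter, device, dp_mesh=None):
            self._iter = iter
            self._device = device
            self._dp_mesh = dp_mesh  # process group, or None for single rank

        def __iter__(self):
            start_epoch = None
            for batch in self._iter:
                # Assumption: the batch's "metrics" field exposes "num_epochs".
                epoch = self._extract_epoch(batch)
                if start_epoch is None:
                    start_epoch = epoch
                done = torch.tensor(float(epoch > start_epoch), device=self._device)
                # All DP ranks must agree on stopping, otherwise collectives
                # issued inside the eval step could hang.
                if self._dp_mesh is not None:
                    dist.all_reduce(done, op=dist.ReduceOp.MAX, group=self._dp_mesh)
                if done.item() > 0:
                    return
                yield batch

        @staticmethod
        def _extract_epoch(batch) -> int:
            # Assumption about the metrics layout; adjust to the real batch format.
            return int(batch.get("metrics", {}).get("num_epochs", 0))
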

0 commit comments
