PT: also use complete_frac for progress reporting #1698

Merged (6 commits), Mar 4, 2025
91 changes: 29 additions & 62 deletions returnn/torch/engine.py
@@ -3,7 +3,7 @@
"""

from __future__ import annotations
from typing import Optional, Any, Union, Callable, Dict, Set, Tuple
from typing import Optional, Any, Union, Callable, Dict, Set
from contextlib import nullcontext, ExitStack, contextmanager

import gc
@@ -365,8 +365,6 @@ def train_epoch(self):
zero_grad_next_step = True
cur_count_grad_accum = 0
extern_data = None
num_seqs = None
last_seq_idx = 0

total_data_size_packed = NumbersDict()
total_data_size_padded = NumbersDict()
@@ -400,20 +398,8 @@ def train_epoch(self):
)

complete_frac = float(extern_data_raw["complete_frac"])
num_seqs, last_seq_idx = _get_num_seqs_last_seq_idx(
report_prefix=report_prefix,
extern_data_raw=extern_data_raw,
step_idx=step_idx,
prev_num_seqs=num_seqs,
prev_last_seq_idx=last_seq_idx,
)
epoch_continuous = (
self.epoch - 1 + complete_frac
if complete_frac >= 0.0
else (self.epoch - 1 + (last_seq_idx + 1) / num_seqs)
if num_seqs is not None
else None
)
epoch_continuous = self.epoch - 1 + complete_frac if complete_frac >= 0.0 else None
num_seqs = int(extern_data_raw["num_seqs"])

# clear the gradients when every gradient accumulation loop starts
if zero_grad_next_step:
@@ -490,7 +476,7 @@ def train_epoch(self):
eval_info=dict(eval_info),
step_duration=step_duration,
start_elapsed=step_end_time - epoch_start_time,
seq_idx=last_seq_idx,
complete_frac=complete_frac,
num_seqs=num_seqs,
batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
log_memory_usage_device=self._device if self._log_memory_usage else None,
@@ -629,13 +615,18 @@ def eval_model(self, *, skip_already_evaluated: bool = False):
accumulated_losses_dict = NumbersDict()
accumulated_inv_norm_factors_dict = NumbersDict()
step_idx = 0
eval_start_time = time.monotonic()

report_prefix = f"ep {self.epoch} {dataset_name} eval"
with torch.no_grad():
for extern_data_raw in data_loader:
if self._torch_distributed_ctx and step_idx % 100 == 0:
_has_data = torch.tensor([True], device="cpu", dtype=torch.int8)
torch.distributed.broadcast(_has_data, src=0)

complete_frac = float(extern_data_raw["complete_frac"])
num_seqs = int(extern_data_raw["num_seqs"])

extern_data = extern_data_util.raw_dict_to_extern_data(
extern_data_raw,
extern_data_template=self.extern_data,
@@ -644,6 +635,8 @@
)

self._run_step(extern_data, train_func=True)
step_end_time = time.monotonic()

train_ctx = rf.get_run_ctx()

losses_dict = NumbersDict(
@@ -664,9 +657,12 @@
accumulated_inv_norm_factors_dict += inv_norm_factors_dict
eval_info = self._maybe_extend_losses_info(losses_dict / inv_norm_factors_dict)
_print_process(
f"ep {self.epoch} {dataset_name} eval",
report_prefix,
step=step_idx,
eval_info=dict(eval_info),
complete_frac=complete_frac,
num_seqs=num_seqs,
start_elapsed=step_end_time - eval_start_time,
log_memory_usage_device=self._device if self._log_memory_usage else None,
)
step_idx += 1
@@ -1290,8 +1286,6 @@ def _get_dim_tag_wo_batch(dim: Dim) -> Dim:
new_dim.dyn_size_ext = _get_tensor_wo_batch_numpy(dim.dyn_size_ext)
return new_dim

num_seqs = None
last_seq_idx = 0
report_prefix = f"ep {self.epoch} {dataset.name} forward"
with torch.no_grad():
callback.init(model=self._orig_model)
@@ -1300,13 +1294,8 @@
for extern_data_raw in data_loader:
step_begin_time = time.monotonic()

num_seqs, last_seq_idx = _get_num_seqs_last_seq_idx(
report_prefix=report_prefix,
extern_data_raw=extern_data_raw,
step_idx=step_idx,
prev_num_seqs=num_seqs,
prev_last_seq_idx=last_seq_idx,
)
complete_frac = float(extern_data_raw["complete_frac"])
num_seqs = int(extern_data_raw["num_seqs"])

if self._forward_step_expected_outputs:
# Also resets any dyn dims, which might have been set in the prev step.
@@ -1354,7 +1343,7 @@ def _get_dim_tag_wo_batch(dim: Dim) -> Dim:
eval_info=None,
step_duration=step_duration,
start_elapsed=step_end_time - epoch_start_time,
seq_idx=last_seq_idx,
complete_frac=complete_frac,
num_seqs=num_seqs,
batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
log_memory_usage_device=self._device if self._log_memory_usage else None,
@@ -1442,7 +1431,7 @@ def _print_process(
batch_size_info: Optional[Dict[str, Any]] = None,
step_duration: Optional[float] = None,
start_elapsed: Optional[float] = None,
seq_idx: Optional[int] = None,
complete_frac: Optional[float] = None,
num_seqs: Optional[int] = None,
log_memory_usage_device: Optional[str] = None,
):
@@ -1455,11 +1444,14 @@
:param batch_size_info:
:param step_duration: time elapsed for this step (secs)
:param start_elapsed: time elapsed since epoch start (secs)
:param num_seqs: total number of sequences for this epoch
:param complete_frac: how much of the current epoch is already consumed
:param num_seqs: total number of seqs this epoch
:param log_memory_usage_device: if given, will log memory usage (peak allocated memory)
:return: nothing, will be printed to log
"""
if log.verbose[5]: # report every minibatch
if step == 0 and num_seqs is not None and num_seqs >= 0:
print(f"{report_prefix} num_seqs: {num_seqs}", file=log.v5)
info = [report_prefix, "step %i" % step]
if eval_info: # Such as score.
info += ["%s %s" % (k, _format_score_value(v)) for k, v in eval_info.items()]
@@ -1475,17 +1467,16 @@
info += ["%.3f sec/step" % step_duration]
if start_elapsed is not None:
info += ["elapsed %s" % hms(start_elapsed)]
if num_seqs is not None:
assert seq_idx is not None and start_elapsed is not None # unexpected combination...
complete = (seq_idx + 1) / num_seqs
assert 1 >= complete > 0, f"{step} step, {num_seqs} num_seqs"
total_time_estimated = start_elapsed / complete
if complete_frac is not None:
assert 1 >= complete_frac > 0, f"{step} step, {complete_frac} complete_frac"
assert start_elapsed is not None
total_time_estimated = start_elapsed / complete_frac
remaining_estimated = total_time_estimated - start_elapsed
info += [
"exp. remaining %s" % hms(remaining_estimated),
"complete %.02f%%" % (complete * 100),
"complete %.02f%%" % (complete_frac * 100),
]
if start_elapsed is not None and num_seqs is None:
if start_elapsed is not None and complete_frac is None:
info += ["(unk epoch len)"]
print(", ".join(filter(None, info)), file=log.v5)

@@ -1634,27 +1625,3 @@ def _get_total_grad_norm(model: torch.nn.Module, p: float) -> float:
p=p,
).item()
)


def _get_num_seqs_last_seq_idx(
*,
report_prefix: str,
extern_data_raw: Dict[str, Any],
step_idx: int,
prev_num_seqs: Optional[int],
prev_last_seq_idx: int,
) -> Tuple[Optional[int], int]:
num_seqs = prev_num_seqs
num_seqs_ = int(extern_data_raw["num_seqs"]) if extern_data_raw.get("num_seqs", None) is not None else -1
# Note: The batches might have been shuffled,
# thus we cannot really assert that the seq_idx is always increasing.
last_seq_idx = max(int(extern_data_raw["seq_idx"].max()), prev_last_seq_idx)
if step_idx == 0:
if num_seqs_ >= 0:
print(f"{report_prefix} num_seqs: {num_seqs_}", file=log.v5)
num_seqs = num_seqs_
elif num_seqs_ >= 0:
assert num_seqs_ == num_seqs
if num_seqs is not None:
assert last_seq_idx < num_seqs
return num_seqs, last_seq_idx
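
For context on what replaces the removed helper above: the new code derives both the continuous epoch position and the remaining-time estimate directly from complete_frac, as reported by the data loader. Below is a minimal, self-contained sketch of that arithmetic, assuming the formulas shown in the diff; estimate_progress and this hms stand-in are illustrative names, not part of the RETURNN API.

import datetime
from typing import Optional


def hms(secs: float) -> str:
    """Format seconds as H:MM:SS (stand-in for RETURNN's hms helper)."""
    return str(datetime.timedelta(seconds=int(secs)))


def estimate_progress(epoch: int, complete_frac: float, start_elapsed: float) -> Optional[str]:
    """
    Mirror the arithmetic used after this change:
    epoch_continuous counts finished epochs plus the consumed fraction of the
    current one; the remaining time scales the elapsed time by 1 / complete_frac.
    """
    if complete_frac < 0.0:
        return None  # dataset does not report its progress
    assert 1 >= complete_frac > 0
    epoch_continuous = epoch - 1 + complete_frac
    total_time_estimated = start_elapsed / complete_frac
    remaining_estimated = total_time_estimated - start_elapsed
    return (
        f"epoch pos {epoch_continuous:.3f}, "
        f"complete {complete_frac * 100:.02f}%, "
        f"exp. remaining {hms(remaining_estimated)}"
    )


if __name__ == "__main__":
    # 40% through epoch 3 after 120 s elapsed -> roughly 180 s remaining.
    print(estimate_progress(epoch=3, complete_frac=0.4, start_elapsed=120.0))

Compared to the old seq_idx-based path, this needs no running maximum over shuffled sequence indices and works even when the dataset cannot report num_seqs, which is why the PR can drop _get_num_seqs_last_seq_idx entirely.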