From c6d4084ded155074a19875ffe03b13e47981b6bd Mon Sep 17 00:00:00 2001
From: Vaibhav Singh
Date: Tue, 7 Nov 2023 16:28:35 +0000
Subject: [PATCH 1/3] added the utility function to print model flops

---
 src/transformers/trainer.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 18b569b8b8f5..c201ea5923aa 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1804,6 +1804,9 @@ def _inner_training_loop(
             profile_epoch = int(os.environ.get('PROFILE_EPOCH', -1))
             profile_duration = int(os.environ.get('PROFILE_DURATION_MS', 20000))
             profile_logdir = os.environ.get('PROFILE_LOGDIR', None)
+
+            self.num_compilations = 0
+            self.last_time_stamp = time.time()
             for step, inputs in enumerate(epoch_iterator):
                 if step == 0 and epoch == 0:
                     print('input sharding', {k: (v.shape, torch_xla._XLAC._get_xla_sharding_spec(v)) for k, v in inputs.items()})
@@ -2694,8 +2697,28 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
         else:
             self.accelerator.backward(loss)

+
+        xm.mark_step()
+        if self.num_compilations != met.metric_data('CompileTime')[:1]:
+            self.num_compilations = met.metric_data('CompileTime')[:1]
+        else:
+            step_time = time.time() - self.last_time_stamp
+            num_tokens = inputs["input_ids"].numel()
+            xm.master_print(f"Step time: {step_time}: Model TFLOPS: {self.model_flops(step_time, num_tokens)}")
+            xm.master_print(f"Memory Info: {xm.get_memory_info(xm.xla_device())}")
+        self.last_time_stamp = time.time()
+
+
         return loss.detach() / self.args.gradient_accumulation_steps

+    def model_flops(self, step_time, num_tokens):
+        num_trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+        model_flops = 6 * num_trainable_params * num_tokens
+        model_tflops_per_second = model_flops / step_time / 1e12
+        return model_tflops_per_second
+
+
+
     def compute_loss(self, model, inputs, return_outputs=False):
         """
         How the loss is computed by Trainer. By default, all models return the loss in the first element.
@@ -3169,6 +3192,9 @@ def evaluation_loop(
             if is_torch_tpu_available():
                 xm.mark_step()

+
+
+
             # Update containers on host
             if loss is not None:
                 losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size)))

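The model_flops helper introduced in this patch uses the common 6 * N * T approximation for decoder training compute: roughly 2 * N * T FLOPs for the forward pass and 4 * N * T for the backward pass, where N is the number of trainable parameters and T the number of tokens processed in the step (attention FLOPs are ignored, so this is only an estimate). A minimal standalone sketch of the same calculation, with purely illustrative numbers:

def estimate_tflops(num_trainable_params, num_tokens, step_time_s):
    """Approximate achieved model TFLOPS for one training step (6*N*T heuristic)."""
    flops = 6 * num_trainable_params * num_tokens  # ~2*N*T forward + ~4*N*T backward
    return flops / step_time_s / 1e12              # FLOP/s -> TFLOP/s

# Hypothetical example: 2B trainable parameters, 131,072 tokens per step, 5 s per step:
# 6 * 2e9 * 131072 ~= 1.573e15 FLOPs, so roughly 314.6 TFLOPS.
print(estimate_tflops(2_000_000_000, 131_072, 5.0))
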
From d4e8e8efaddbdc605eb16a896d4a3a7073849420 Mon Sep 17 00:00:00 2001
From: Vaibhav Singh
Date: Wed, 8 Nov 2023 05:25:31 +0000
Subject: [PATCH 2/3] refining step time calc

---
 src/transformers/trainer.py | 25 +++++++++++++++----------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index c201ea5923aa..36aec2b91530 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1806,8 +1806,8 @@ def _inner_training_loop(
             profile_logdir = os.environ.get('PROFILE_LOGDIR', None)

             self.num_compilations = 0
-            self.last_time_stamp = time.time()
             for step, inputs in enumerate(epoch_iterator):
+                self.last_time_stamp = time.time()
                 if step == 0 and epoch == 0:
                     print('input sharding', {k: (v.shape, torch_xla._XLAC._get_xla_sharding_spec(v)) for k, v in inputs.items()})
                 total_batched_samples += 1
@@ -1899,6 +1899,17 @@ def _inner_training_loop(
                         if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                             self.lr_scheduler.step()

+
+                    xm.mark_step()
+                    if self.num_compilations != met.metric_data('CompileTime')[:1]:
+                        self.num_compilations = met.metric_data('CompileTime')[:1]
+                    else:
+                        xm.wait_device_ops()
+                        step_time = time.time() - self.last_time_stamp
+                        num_tokens = inputs["input_ids"].numel() / self.args.spmd_mesh.ici_mesh_shape[1]
+                        xm.master_print(f"Num Tokens:{num_tokens},ICI mesh shape: {self.args.spmd_mesh.ici_mesh_shape}")
+                        xm.master_print(f"Step time: {step_time}: Model TFLOPS: {self.model_flops(step_time, num_tokens)}")
+
                     model.zero_grad()
                     self.state.global_step += 1
                     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
@@ -1908,6 +1919,7 @@ def _inner_training_loop(
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

+
                 if self.control.should_epoch_stop or self.control.should_training_stop:
                     break

@@ -2698,15 +2710,8 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             self.accelerator.backward(loss)


-        xm.mark_step()
-        if self.num_compilations != met.metric_data('CompileTime')[:1]:
-            self.num_compilations = met.metric_data('CompileTime')[:1]
-        else:
-            step_time = time.time() - self.last_time_stamp
-            num_tokens = inputs["input_ids"].numel()
-            xm.master_print(f"Step time: {step_time}: Model TFLOPS: {self.model_flops(step_time, num_tokens)}")
-            xm.master_print(f"Memory Info: {xm.get_memory_info(xm.xla_device())}")
-        self.last_time_stamp = time.time()
+        # TODO: implement memory info for PJRT
+        #xm.master_print(f"Memory Info: {xm.get_memory_info(xm.xla_device())}")


         return loss.detach() / self.args.gradient_accumulation_steps

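Patch 2 moves the measurement out of training_step and into the optimizer-step branch of _inner_training_loop: last_time_stamp is reset at the top of every step, and a step is only reported when the CompileTime metric count from torch_xla is unchanged, i.e. when the step did not trigger a fresh XLA compilation whose wall time would swamp the execution time; xm.wait_device_ops() makes the host block until the dispatched device work has finished before the clock is read. A rough sketch of that gating pattern outside the Trainer, assuming a hypothetical train_step_fn and data loader:

import time

import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

def timed_loop(loader, train_step_fn, num_steps):
    """Time XLA training steps, skipping any step that triggered a recompilation."""
    num_compilations = None
    for step, batch in zip(range(num_steps), loader):
        start = time.time()
        train_step_fn(batch)        # traces the lazy graph for this step
        xm.mark_step()              # cut the graph and dispatch it to the device
        # As in the patch, treat the first field of the CompileTime metric as a compile counter.
        data = met.metric_data('CompileTime')
        compile_count = data[:1] if data is not None else None
        if compile_count != num_compilations:
            num_compilations = compile_count   # a compile happened; this step is not representative
        else:
            xm.wait_device_ops()    # wait for device execution before reading the clock
            xm.master_print(f"step {step}: {time.time() - start:.3f}s")
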
From ccfc005d48ec006648796c7ccc93b5afda0cebeb Mon Sep 17 00:00:00 2001
From: Vaibhav Singh
Date: Thu, 9 Nov 2023 03:05:56 +0000
Subject: [PATCH 3/3] clean up

---
 src/transformers/trainer.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 36aec2b91530..01951f823850 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1904,10 +1904,11 @@ def _inner_training_loop(
                     if self.num_compilations != met.metric_data('CompileTime')[:1]:
                         self.num_compilations = met.metric_data('CompileTime')[:1]
                     else:
-                        xm.wait_device_ops()
+                        xm.rendezvous('step')
                         step_time = time.time() - self.last_time_stamp
-                        num_tokens = inputs["input_ids"].numel() / self.args.spmd_mesh.ici_mesh_shape[1]
-                        xm.master_print(f"Num Tokens:{num_tokens},ICI mesh shape: {self.args.spmd_mesh.ici_mesh_shape}")
+                        data, fsdp, mdl = self.args.spmd_mesh.ici_mesh_shape
+                        num_devices = data * fsdp * mdl
+                        num_tokens = inputs["input_ids"].numel() / num_devices
                         xm.master_print(f"Step time: {step_time}: Model TFLOPS: {self.model_flops(step_time, num_tokens)}")

                     model.zero_grad()

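Patch 3 replaces xm.wait_device_ops() with xm.rendezvous('step'), synchronizing the hosts at a common point before the step time is read, and normalizes the token count by the product of all three ICI mesh axes rather than only the second (fsdp) axis. Since inputs["input_ids"] holds the global batch under SPMD, dividing by the total device count presumably makes the reported figure a per-device TFLOPS that can be compared against a single chip's peak. A small numeric sketch of that normalization, with made-up mesh and batch sizes:

# Purely illustrative numbers: an ICI mesh of (data, fsdp, mdl) = (1, 64, 1), i.e. 64 devices,
# and a global batch of 64 sequences of 2048 tokens.
ici_mesh_shape = (1, 64, 1)
data, fsdp, mdl = ici_mesh_shape
num_devices = data * fsdp * mdl                   # 1 * 64 * 1 = 64 devices run the SPMD program

global_tokens = 64 * 2048                         # what input_ids.numel() sees: the global batch
tokens_per_device = global_tokens / num_devices   # 131072 / 64 = 2048 tokens per device

# Feeding tokens_per_device into the 6*N*T estimate above yields per-device TFLOPS.
print(num_devices, tokens_per_device)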