From f13b86173a674ca6fd6035902455c2f3c5675784 Mon Sep 17 00:00:00 2001
From: Charlene Yang
Date: Mon, 24 Feb 2025 16:55:44 -0800
Subject: [PATCH] fix sequences_pre

Signed-off-by: Charlene Yang
---
 transformer_engine/pytorch/inference.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py
index 43ac09dbf6..67dec8c376 100644
--- a/transformer_engine/pytorch/inference.py
+++ b/transformer_engine/pytorch/inference.py
@@ -320,7 +320,8 @@ def pre_step(
         self.batch_size = len(step_dict)

         self.sequences = self.cache_manager.pre_step(step_dict)
-        for k, v in enumerate(self.sequences):
+        self.sequences_pre = OrderedDict()
+        for k, v in self.sequences.items():
             self.sequences_pre[k] = v - self.step_dict[k]

         actual_batch_size = len(step_dict)
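
Note: the following is not part of the patch. It is a minimal standalone sketch of why the change matters, using invented dictionary contents. In the removed line, enumerate(self.sequences) iterates over an OrderedDict's keys and binds k to positional indices 0, 1, ... rather than to the sequence IDs, so the later step_dict[k] lookup keys on the wrong values; .items() yields (seq_id, value) pairs instead, and re-creating sequences_pre on every call keeps stale entries from earlier steps out of it.

from collections import OrderedDict

# Hypothetical data: sequence id -> total length, and sequence id -> tokens
# added in this step (values chosen only for illustration).
sequences = OrderedDict([(101, 7), (205, 12)])
step_dict = OrderedDict([(101, 1), (205, 1)])

# Buggy pattern from the removed line: enumerate() over a dict yields
# (position, key) pairs, so k is 0, 1, ... and v is the sequence id itself;
# a step_dict[k] lookup would then key on the wrong values (or KeyError).
for k, v in enumerate(sequences):
    print(k, v)          # prints "0 101" and "1 205"

# Fixed pattern from the added lines: rebuild sequences_pre from scratch and
# iterate (key, value) pairs with .items().
sequences_pre = OrderedDict()
for k, v in sequences.items():
    sequences_pre[k] = v - step_dict[k]

print(sequences_pre)     # OrderedDict([(101, 6), (205, 11)])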