Skip to content

Commit

Permalink
[PyTorch] Fix ONNX export bug with operation-based API (#1320)
Browse files Browse the repository at this point in the history
Debug ONNX export with te.Sequential

ONNX export assumes that all state dict objects are tensors, including the extra state.

Signed-off-by: Tim Moon <[email protected]>
  • Loading branch information
timmoon10 authored Nov 13, 2024
1 parent 943f1e0 commit c0a539c
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/ops/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def forward(
basic_op_kwargs=[kwargs],
)

def get_extra_state(self) -> Optional[torch.Tensor]:
def get_extra_state(self) -> torch.Tensor:
"""Serialize extra state
Contains metadata for FP8 casting.
Expand Down Expand Up @@ -534,7 +534,7 @@ def get_extra_state(self) -> Optional[torch.Tensor]:
self.num_fp8_scales(mode) > 0 for mode in ("input", "param", "grad_output")
)
if not has_fp8_state:
return None
return torch.Tensor()

def to_cpu(src: torch.Tensor) -> torch.Tensor:
"""Helper function to make CPU copy of tensor
Expand Down Expand Up @@ -588,7 +588,7 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor:

def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
"""Load extra state"""
if state is None:
if state is None or state.numel() == 0:
return

# Deserialize state from byte tensor
Expand Down

0 comments on commit c0a539c

Please sign in to comment.