From 5b598ea285a2024690c1753bbe3ffdbee6c0c8cc Mon Sep 17 00:00:00 2001 From: Han Zhu <1106766460@qq.com> Date: Mon, 30 Dec 2024 13:58:25 +0800 Subject: [PATCH 1/2] warn instead of raising exceptions in inf-check --- icefall/hooks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/icefall/hooks.py b/icefall/hooks.py index 1c5bd2ae68..9ae0184aae 100644 --- a/icefall/hooks.py +++ b/icefall/hooks.py @@ -40,8 +40,8 @@ def register_inf_check_hooks(model: nn.Module) -> None: def forward_hook(_module, _input, _output, _name=name): if isinstance(_output, Tensor): if not torch.isfinite(_output.to(torch.float32).sum()): - raise ValueError( - f"The sum of {_name}.output is not finite: {_output}" + logging.warning( + f"The sum of {_name}.output is not finite" ) elif isinstance(_output, tuple): for i, o in enumerate(_output): @@ -50,8 +50,8 @@ def forward_hook(_module, _input, _output, _name=name): if not isinstance(o, Tensor): continue if not torch.isfinite(o.to(torch.float32).sum()): - raise ValueError( - f"The sum of {_name}.output[{i}] is not finite: {_output}" + logging.warning( + f"The sum of {_name}.output[{i}] is not finite" ) # default param _name is a way to capture the current value of the variable "name". From 35c8de75673d835693e57cddebbe9a99eb728a8a Mon Sep 17 00:00:00 2001 From: Han Zhu <1106766460@qq.com> Date: Mon, 30 Dec 2024 16:04:58 +0800 Subject: [PATCH 2/2] Fix the style issue --- icefall/hooks.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/icefall/hooks.py b/icefall/hooks.py index 9ae0184aae..83f2750faf 100644 --- a/icefall/hooks.py +++ b/icefall/hooks.py @@ -40,9 +40,7 @@ def register_inf_check_hooks(model: nn.Module) -> None: def forward_hook(_module, _input, _output, _name=name): if isinstance(_output, Tensor): if not torch.isfinite(_output.to(torch.float32).sum()): - logging.warning( - f"The sum of {_name}.output is not finite" - ) + logging.warning(f"The sum of {_name}.output is not finite") elif isinstance(_output, tuple): for i, o in enumerate(_output): if isinstance(o, tuple): @@ -50,9 +48,7 @@ def forward_hook(_module, _input, _output, _name=name): if not isinstance(o, Tensor): continue if not torch.isfinite(o.to(torch.float32).sum()): - logging.warning( - f"The sum of {_name}.output[{i}] is not finite" - ) + logging.warning(f"The sum of {_name}.output[{i}] is not finite") # default param _name is a way to capture the current value of the variable "name". def backward_hook(_module, _input, _output, _name=name):