From 87059ea9f239ece2e957daf8800ac118b16324b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Doreau?= Date: Wed, 22 Jan 2025 19:17:23 +0200 Subject: [PATCH] Fix: Correct input type for pyhpc_equation_of_state in Torch_XLA2 Converts the example input list to a tuple before passing it to torch.func.functional_call to resolve argument type mismatch. --- models/pyhpc_equation_of_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/pyhpc_equation_of_state.py b/models/pyhpc_equation_of_state.py index cab4ca3..a836c49 100644 --- a/models/pyhpc_equation_of_state.py +++ b/models/pyhpc_equation_of_state.py @@ -45,7 +45,7 @@ def main(): def func_call(state, example): with env: - return torch.func.functional_call(model, state, example, tie_weights=False) + return torch.func.functional_call(model, state, tuple(example), tie_weights=False) # doing it jitted jitted = torch_xla2.interop.jax_jit(func_call)