diff --git a/models/pyhpc_equation_of_state.py b/models/pyhpc_equation_of_state.py index cab4ca3..a836c49 100644 --- a/models/pyhpc_equation_of_state.py +++ b/models/pyhpc_equation_of_state.py @@ -45,7 +45,7 @@ def main(): def func_call(state, example): with env: - return torch.func.functional_call(model, state, example, tie_weights=False) + return torch.func.functional_call(model, state, tuple(example), tie_weights=False) # doing it jitted jitted = torch_xla2.interop.jax_jit(func_call)