diff --git a/models/llama_v2_7b_16h.py b/models/llama_v2_7b_16h.py
index ee2ff92..9757289 100644
--- a/models/llama_v2_7b_16h.py
+++ b/models/llama_v2_7b_16h.py
@@ -22,7 +22,7 @@ def main():
   model, example = benchmark.get_module()

   # Run the model once in torch
-  expected = model(*example)
+  expected = model(**example)

   env = torch_xla2.default_env()

@@ -36,16 +36,15 @@ def main():
   example = env.to_xla(example)
   with env:
     start = time.perf_counter()
-    xla2_ans = model(*example)
+    xla2_ans = model(**example)
     print(example)
     end = time.perf_counter()
     print('Eager mode time', end - start)
-
-    print('Eager max abs vs expected', (torch_xla2.tensor.j2t(xla2_ans._elem) - expected).abs().max())
+  print('Eager max abs vs expected', (torch_xla2.tensor.j2t(xla2_ans._elem) - expected).abs().max())

   def func_call(state, example):
     with env:
-      return torch.func.functional_call(model, state, example, tie_weights=False)
+      return torch.func.functional_call(model, state, None, example, tie_weights=False)

   # doing it jitted
   jitted = torch_xla2.interop.jax_jit(func_call)
diff --git a/models/llava.py b/models/llava.py
index d596096..e3e57e5 100644
--- a/models/llava.py
+++ b/models/llava.py
@@ -22,7 +22,7 @@ def main():
   model, example = benchmark.get_module()

   # Run the model once in torch
-  expected = model(*example)
+  expected = model(**example)

   env = torch_xla2.default_env()

@@ -36,7 +36,7 @@ def main():
   example = env.to_xla(example)
   with env:
     start = time.perf_counter()
-    xla2_ans = model(*example)
+    xla2_ans = model(**example)
     print(example)
     end = time.perf_counter()
     print('Eager mode time', end - start)
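
Note (not part of the diff): the changes above switch the benchmark call sites from positional unpacking, `model(*example)`, to keyword unpacking, `model(**example)`, which implies `example` is now a dict of keyword arguments rather than a tuple. Below is a minimal, self-contained sketch of why the direct calls and the `torch.func.functional_call` invocation have to change together; `DummyModel`, its `input_ids` argument, and the tensor shapes are hypothetical stand-ins, not code from this repo.

```python
import torch


class DummyModel(torch.nn.Module):
  """Hypothetical stand-in for the benchmark models in the diff."""

  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(4, 4)

  def forward(self, input_ids):
    return self.linear(input_ids)


model = DummyModel()
# Kwargs-style example inputs: a dict mapping argument names to tensors.
example = {'input_ids': torch.randn(2, 4)}

# `model(*example)` would unpack the dict's *keys* (strings) as positional
# arguments and fail; a dict of inputs must be splatted with `**`.
expected = model(**example)

# torch.func.functional_call takes positional and keyword arguments in two
# separate slots: functional_call(module, state, args, kwargs, ...).
# With kwargs-only inputs, args is passed as None and the dict fills kwargs.
state = dict(model.named_parameters())
out = torch.func.functional_call(model, state, None, example, tie_weights=False)

assert torch.allclose(expected, out)
```

Since `torch.func.functional_call`'s signature is `functional_call(module, parameters_and_buffers, args=None, kwargs=None, *, tie_weights=True, strict=False)`, inserting `None` in the diff shifts `example` from the positional-args slot into the kwargs slot, matching the `model(**example)` call sites.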