From 83dac8cf30d8abe2af421eb82ffd1c5a4fc859cb Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 17 Dec 2024 18:15:37 -0800 Subject: [PATCH] [PyTorch] Add weights_only=False for torch.load (#1374) add weights_only=False for torch.load Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_float8tensor.py | 2 +- tests/pytorch/test_sanity.py | 2 +- tests/pytorch/test_torch_save_load.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 51f4c695dc..a25ffa773c 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -339,7 +339,7 @@ def test_serialization( del x_fp8, byte_stream # Deserialize tensor - x_fp8 = torch.load(io.BytesIO(x_bytes)) + x_fp8 = torch.load(io.BytesIO(x_bytes), weights_only=False) del x_bytes # Check results diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 4f057c12fe..32d517460a 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -1101,7 +1101,7 @@ def get_model(dtype, config): del block block = get_model(dtype, config) - block.load_state_dict(torch.load(path)) + block.load_state_dict(torch.load(path, weights_only=False)) torch.set_rng_state(_cpu_rng_state_new) torch.cuda.set_rng_state(_cuda_rng_state_new) diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py index 7bf8fb99d5..be77109cb7 100644 --- a/tests/pytorch/test_torch_save_load.py +++ b/tests/pytorch/test_torch_save_load.py @@ -124,7 +124,7 @@ def forward(self, inp, weight): torch.save(model_in.state_dict(), tmp_filename) model_out = Test_TE_Export(precision, True) - model_out.load_state_dict(torch.load(tmp_filename)) + model_out.load_state_dict(torch.load(tmp_filename, weights_only=False)) model_out.eval() # scaling fwd @@ -263,7 +263,7 @@ def test_fp8_model_checkpoint( # to load the fp8 metadata before loading tensors. # # Load checkpoint - model.load_state_dict(torch.load(io.BytesIO(model_bytes))) + model.load_state_dict(torch.load(io.BytesIO(model_bytes), weights_only=False)) del model_bytes # Check that loaded model matches saved model @@ -450,7 +450,7 @@ def train_step( torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols) # Load checkpoint - model.load_state_dict(torch.load(io.BytesIO(model_bytes))) + model.load_state_dict(torch.load(io.BytesIO(model_bytes), weights_only=False)) del model_bytes # Check that new model's FP8 metadata matches saved model