Problem: running the code below fails with the following error:

Exception has occurred: UnpicklingError
Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint. (1) In PyTorch 2.6, we changed the default value of the weights_only argument in torch.load from False to True. Re-running torch.load with weights_only set to False will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source. (2) Alternatively, to load with weights_only=True please check the recommended steps in the following error message. WeightsUnpickler error: Unsupported global: GLOBAL omegaconf.base.ContainerMetadata was not an allowed global by default. Please use torch.serialization.add_safe_globals([ContainerMetadata]) or the torch.serialization.safe_globals([ContainerMetadata]) context manager to allowlist this global if you trust this class/function.
Check the documentation of torch.load to learn more about types accepted by default with weights_only: https://pytorch.org/docs/stable/generated/torch.load.html
_pickle.UnpicklingError: Unsupported global: GLOBAL omegaconf.base.ContainerMetadata was not an allowed global by default. Please use torch.serialization.add_safe_globals([ContainerMetadata]) or the torch.serialization.safe_globals([ContainerMetadata]) context manager to allowlist this global if you trust this class/function.

Code:
```python
from sklearn.model_selection import train_test_split
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import FTTransformerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig

# df_train, df_test, target_col, features, num_cols, cat_cols and random_seed
# are assumed to be defined earlier (not shown)

# Convert target column to int32
df_train[target_col] = df_train[target_col].astype('int32')
df_test[target_col] = df_test[target_col].astype('int32')

# Split data into train and validation sets
train, val = train_test_split(df_train, test_size=0.1, random_state=random_seed)
x_test = df_test[features]
y_test = df_test[target_col]

# Config
data_config = DataConfig(
    target=target_col,  # target should always be a list.
    continuous_cols=num_cols,
    categorical_cols=cat_cols,
)
trainer_config = TrainerConfig(
    # auto_lr_find=True,  # Runs the LRFinder to automatically derive a learning rate
    batch_size=32,
    max_epochs=30,
    early_stopping="valid_loss",  # Monitor valid_loss for early stopping
    early_stopping_mode="min",  # Lower valid_loss is better, so minimize it
    early_stopping_patience=5,  # Number of epochs without improvement before training stops
    checkpoints="valid_loss",  # Save the best checkpoint, monitoring valid_loss
    load_best=True,  # After training, load the best checkpoint
)
optimizer_config = OptimizerConfig()
head_config = LinearHeadConfig(
    layers="",  # No additional layer in head, just a mapping layer to output_dim
    dropout=0.1,
    initialization="kaiming",
).__dict__

# FT Tabular Transformer
model_config = FTTransformerConfig(
    task="classification",
    learning_rate=1e-3,
    head="LinearHead",  # Linear Head
    head_config=head_config,  # Linear Head Config
)
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
tabular_model.fit(train=train, validation=val)
```
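The error above comes from the weights-only unpickler refusing a global (`omegaconf.base.ContainerMetadata`) that is not on its allowlist. As a rough illustration of that mechanism, here is a stdlib-only sketch built on `pickle.Unpickler.find_class`. It is not PyTorch's implementation, and `AllowlistUnpickler` and `restricted_loads` are hypothetical names:

```python
import io
import pickle
from collections import OrderedDict


class AllowlistUnpickler(pickle.Unpickler):
    """Resolve only globals that were explicitly allowlisted.

    A simplified illustration of the idea behind torch.load(weights_only=True);
    PyTorch ships its own, stricter unpickler.
    """

    def __init__(self, file, allowed=()):
        super().__init__(file)
        # Key each allowed class or function by (module, qualified name),
        # which is how pickle refers to globals in the byte stream.
        self._allowed = {(obj.__module__, obj.__qualname__) for obj in allowed}

    def find_class(self, module, name):
        # Called for every global reference (GLOBAL / STACK_GLOBAL opcodes).
        if (module, name) not in self._allowed:
            raise pickle.UnpicklingError(
                f"Unsupported global: GLOBAL {module}.{name} was not allowlisted"
            )
        return super().find_class(module, name)


def restricted_loads(data: bytes, allowed=()):
    """Unpickle `data`, permitting only the globals listed in `allowed`."""
    return AllowlistUnpickler(io.BytesIO(data), allowed).load()


# Plain dicts, lists and ints need no globals, so they load without an allowlist.
print(restricted_loads(pickle.dumps({"a": [1, 2]})))

# OrderedDict is referenced as the global collections.OrderedDict,
# so it is rejected until it is allowlisted.
payload = pickle.dumps(OrderedDict(a=1))
try:
    restricted_loads(payload)
except pickle.UnpicklingError as exc:
    print(exc)
print(restricted_loads(payload, allowed=[OrderedDict]))
```

This is why allowlisting `ContainerMetadata` works: it adds one more entry the loader may resolve, whereas `weights_only=False` skips the check entirely.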
Hey @john2408! The error you're encountering is related to the change in PyTorch 2.6, where the default value of the `weights_only` argument in `torch.load` changed from `False` to `True`. To resolve this issue, you can either allowlist `ContainerMetadata` with `torch.serialization.add_safe_globals`, or, if you trust the source of the checkpoint, load it with `weights_only=False`.
Here's how you can implement these changes:

```python
import torch
from omegaconf.base import ContainerMetadata

# Option 1: keep weights_only=True and allowlist ContainerMetadata.
# Register it before any torch.load call that reads the checkpoint.
# If a later run reports another unsupported global, allowlist that one too.
torch.serialization.add_safe_globals([ContainerMetadata])
model = torch.load('path_to_model_checkpoint', weights_only=True)

# Option 2: only if you trust the source, disable the weights-only check.
# The allowlist above is not needed in this case.
model = torch.load('path_to_model_checkpoint', weights_only=False)
```

Make sure to replace `'path_to_model_checkpoint'` with the actual path to your checkpoint file. These steps should help you load the model without encountering the `UnpicklingError`.
Hello @john2408,
This is an unpickling error caused by the change to torch.load in PyTorch 2.6. You can either downgrade to an older PyTorch version or wait until PR #543 is merged.
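If you go the downgrade route, a quick check of whether an installed version is affected may help. This is a minimal sketch that assumes only what the error message states, namely that the default flipped in PyTorch 2.6; `defaults_to_weights_only` is a hypothetical helper name:

```python
import re


def defaults_to_weights_only(version: str) -> bool:
    """Return True if torch.load defaults to weights_only=True for this version.

    Assumes the default flipped in PyTorch 2.6, as the error message states.
    Local build tags such as "+cu124" are ignored because only the leading
    major.minor numbers are parsed.
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    # Compare numerically so that "2.10" counts as newer than "2.6".
    return (major, minor) >= (2, 6)
```

For example, call `defaults_to_weights_only(torch.__version__)`; if it returns `True`, either pin `torch<2.6` or apply one of the loading fixes above.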