diff --git a/scripts/triviaqa.py b/scripts/triviaqa.py index 5602f73..6277088 100644 --- a/scripts/triviaqa.py +++ b/scripts/triviaqa.py @@ -265,7 +265,7 @@ def __init__(self, args): self.train_dataloader_object = self.val_dataloader_object = self.test_dataloader_object = None def load_model(self): - model = Longformer.from_pretrained(args.model_path) + model = Longformer.from_pretrained(self.args.model_path) for layer in model.encoder.layer: layer.attention.self.attention_mode = self.args.attention_mode self.args.attention_window = layer.attention.self.attention_window