+import argparse
+import json
 import os
+import random
+import re
+
 import numpy as np
 import torch
-import os
-import re
-import json
-import argparse
-import random
-from transformers import T5Tokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, T5ForConditionalGeneration
-from model import T5ForConditionalGeneration, T5ForMultimodalGeneration
-from utils_data import img_shape, load_data_std, load_data_img, ScienceQADatasetStd, ScienceQADatasetImg
-from utils_prompt import *
-from utils_evaluate import get_scores
-from rich.table import Column, Table
 from rich import box
 from rich.console import Console
+from rich.table import Column, Table
+from transformers import (DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments, T5ForConditionalGeneration, T5Tokenizer)
+
+from model import T5ForConditionalGeneration, T5ForMultimodalGeneration
+from utils_data import (ScienceQADatasetImg, ScienceQADatasetStd, img_shape, load_data_img, load_data_std)
+from utils_evaluate import get_scores
+from utils_prompt import *
+
 console = Console(record=True)
-from torch import cuda
-import nltk
 import evaluate
+import nltk
+from torch import cuda


 def parse_args():
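Note on the regrouped imports: the line from model rebinds T5ForConditionalGeneration, the name the transformers import bound a few lines earlier, so the local class (presumably the project's patched variant) is the one used below. A minimal sketch of the shadowing rule, assuming only standard Python import semantics:

    # At module scope, the most recent binding of a name wins.
    from transformers import T5ForConditionalGeneration  # bound first
    from model import T5ForConditionalGeneration         # rebinds the name; this one is used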
@@ -36,7 +38,7 @@ def parse_args():
     parser.add_argument('--train_split', type=str, default='train', choices=['train', 'trainval', 'minitrain'])
     parser.add_argument('--val_split', type=str, default='val', choices=['test', 'val', 'minival'])
     parser.add_argument('--test_split', type=str, default='test', choices=['test', 'minitest'])
-
+
     parser.add_argument('--use_generate', action='store_true', help='only for baseline to improve inference speed')
     parser.add_argument('--final_eval', action='store_true', help='only evaluate the model at the final epoch')
     parser.add_argument('--user_msg', type=str, default="baseline", help='experiment type in the save_dir')
@@ -50,16 +52,15 @@ def parse_args():
                         choices=['QCM-A', 'QCM-LE', 'QCMG-A', 'QCM-LEA', 'QCM-ALE'])
     parser.add_argument('--seed', type=int, default=42, help='random seed')

-    args = parser.parse_args()
-    return args
+    return parser.parse_args()

 def T5Trainer(
     dataframe, args,
 ):
     torch.manual_seed(args.seed)  # pytorch random seed
     np.random.seed(args.seed)  # numpy random seed
     torch.backends.cudnn.deterministic = True
-
+
     if args.evaluate_dir is not None:
         args.model = args.evaluate_dir

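The function seeds torch and numpy here but not Python's random module, which the FAILED-answer fallback later in this file relies on. A sketch of fuller seeding, with set_seed as a hypothetical helper that is not part of this diff:

    import random
    import numpy as np
    import torch

    def set_seed(seed: int) -> None:
        # Hypothetical helper, not in this commit.
        random.seed(seed)                 # covers the random.choice fallback below
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # multi-GPU runs
        torch.backends.cudnn.deterministic = True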
@@ -72,7 +73,7 @@ def T5Trainer(
     train_qids = qids['train']
     test_qids = qids['test']
     val_qids = qids['val']
-
+
     if args.evaluate_dir is not None:
         save_dir = args.evaluate_dir
     else:
@@ -139,7 +140,7 @@ def T5Trainer(
             args,
             args.eval_le,
         )
-
+
         test_set = ScienceQADatasetStd(
             problems,
             test_qids,
@@ -155,11 +156,8 @@ def T5Trainer(
     def extract_ans(ans):
         pattern = re.compile(r'The answer is \(([A-Z])\)')
         res = pattern.findall(ans)
-
-        if len(res) == 1:
-            answer = res[0]  # 'A', 'B', ...
-        else:
-            answer = "FAILED"
+
+        answer = res[0] if len(res) == 1 else "FAILED"
         return answer

     # accuracy for answer inference
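The new ternary keeps the original contract of extract_ans: exactly one regex match yields the option letter, anything else yields "FAILED". A self-contained check with hypothetical inputs:

    import re

    pattern = re.compile(r'The answer is \(([A-Z])\)')

    def extract_ans(ans):
        res = pattern.findall(ans)
        return res[0] if len(res) == 1 else "FAILED"

    assert extract_ans("The answer is (B).") == "B"
    assert extract_ans("The answer is B.") == "FAILED"  # no parentheses, no match
    assert extract_ans("The answer is (A). The answer is (C).") == "FAILED"  # two matches, ambiguous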
@@ -184,7 +182,7 @@ def compute_metrics_acc(eval_preds):
             if reference == best_option:
                 correct += 1
         return {'accuracy': 1.0 * correct / len(targets)}
-
+
     # rougel for rationale generation
     metric = evaluate.load("rouge")
     def postprocess_text(preds, labels):
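For context, a minimal sketch of how the rouge metric loaded here is typically used with the evaluate library (the example strings are hypothetical):

    import evaluate

    metric = evaluate.load("rouge")
    result = metric.compute(predictions=["the cat sat"],
                            references=["the cat sat on the mat"],
                            use_stemmer=True)
    # result is a dict containing a "rougeL" score, which compute_metrics_rougel reports.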
@@ -218,13 +216,13 @@ def compute_metrics_rougel(eval_preds):
     if args.final_eval:
         training_args = Seq2SeqTrainingArguments(
             save_dir,
-            do_train=True if args.evaluate_dir is None else False,
+            do_train=args.evaluate_dir is None,
             do_eval=False,
             evaluation_strategy="no",
             logging_strategy="steps",
             save_strategy="epoch",
-            save_total_limit=2,
-            learning_rate=args.lr,
+            save_total_limit=2,
+            learning_rate=args.lr,
             eval_accumulation_steps=args.eval_acc,
             per_device_train_batch_size=args.bs,
             per_device_eval_batch_size=args.eval_bs,
@@ -233,23 +231,24 @@ def compute_metrics_rougel(eval_preds):
             predict_with_generate=args.use_generate,
             report_to="none",
         )
-    # evaluate at each epoch
     else:
         training_args = Seq2SeqTrainingArguments(
             save_dir,
-            do_train=True if args.evaluate_dir is None else False,
+            do_train=args.evaluate_dir is None,
             do_eval=True,
             evaluation_strategy="epoch",
             logging_strategy="steps",
             save_strategy="epoch",
-            save_total_limit=2,
-            learning_rate=args.lr,
+            save_total_limit=2,
+            learning_rate=args.lr,
             eval_accumulation_steps=args.eval_acc,
             per_device_train_batch_size=args.bs,
             per_device_eval_batch_size=args.eval_bs,
             weight_decay=0.01,
             num_train_epochs=args.epoch,
-            metric_for_best_model="accuracy" if args.prompt_format != "QCM-LE" else "rougeL",
+            metric_for_best_model="accuracy"
+            if args.prompt_format != "QCM-LE"
+            else "rougeL",
             predict_with_generate=args.use_generate,
             load_best_model_at_end=True,
             report_to="none",
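Both branches apply the same simplification: args.evaluate_dir is None already evaluates to a bool, so wrapping it in True if ... else False was redundant. The equivalence, as a one-line sketch:

    do_train = True if args.evaluate_dir is None else False  # before
    do_train = args.evaluate_dir is None                     # after: same value, no branch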
@@ -268,12 +267,12 @@ def compute_metrics_rougel(eval_preds):
     if args.evaluate_dir is None:
         trainer.train()
         trainer.save_model(save_dir)
-
+
     metrics = trainer.evaluate(eval_dataset=test_set)
     trainer.log_metrics("test", metrics)
     trainer.save_metrics("test", metrics)

-    predict_results = trainer.predict(test_dataset=test_set, max_length=args.output_len)
+    predict_results = trainer.predict(test_dataset=test_set, max_length=args.output_len)
     if trainer.is_world_process_zero():
         if args.use_generate:
             preds, targets = predict_results.predictions, predict_results.label_ids
@@ -292,7 +291,7 @@ def compute_metrics_rougel(eval_preds):
         results_ans = {}
         results_rationale = {}
         results_reference = {}
-
+
         num_fail = 0
         for idx, qid in enumerate(test_qids):
             pred = preds[int(idx)]
@@ -302,7 +301,7 @@ def compute_metrics_rougel(eval_preds):
                 if extract_pred in args.options:
                     extract_pred = args.options.index(extract_pred)
                 else:
-                    extract_pred = random.choice(range(0, len(args.options)))
+                    extract_pred = random.choice(range(len(args.options)))
             else:
                 num_fail += 1
                 extract_pred = random.choice(range(len(args.options)))  # randomly choose one option
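range(0, n) and range(n) are the same sequence, so both fallback branches now read identically. A quick check, plus an equivalent one-call form (a possible further cleanup, not part of this commit):

    import random

    options = ['A', 'B', 'C', 'D']  # hypothetical
    assert list(range(0, len(options))) == list(range(len(options)))
    idx = random.randrange(len(options))  # equivalent to random.choice(range(...))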
@@ -320,7 +319,7 @@ def compute_metrics_rougel(eval_preds):
         output_prediction_file = os.path.join(save_dir, "predictions_ans_test.json")
         with open(output_prediction_file, "w") as writer:
             writer.write(json.dumps(output_data, indent=4))
-
+
         # generate the rationale for the eval set
         if args.prompt_format == "QCM-LE":
             torch.cuda.empty_cache()