Skip to content

Commit 8fe6d3d

Browse files
committed
Major Refactoring
1 parent 1ff1f92 commit 8fe6d3d

File tree

6 files changed: 164 additions (+164) and 188 deletions (-188)

evaluations.py

Lines changed: 8 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -3,10 +3,12 @@
33
'''
44

55
import re
6-
from rouge import Rouge
6+
77
from nltk.translate.bleu_score import sentence_bleu
8+
from rouge import Rouge
89
from sentence_transformers import util
910

11+
1012
########################
1113
## BLEU
1214
########################
@@ -43,9 +45,7 @@ def caculate_bleu(results, data, gram):
4345
bleu = bleu_score(target, prediction, gram)
4446
bleus.append(bleu)
4547

46-
avg_bleu = sum(bleus) / len(bleus)
47-
48-
return avg_bleu
48+
return sum(bleus) / len(bleus)
4949

5050

5151
########################
@@ -54,8 +54,7 @@ def caculate_bleu(results, data, gram):
5454
def score_rouge(str1, str2):
5555
rouge = Rouge(metrics=["rouge-l"])
5656
scores = rouge.get_scores(str1, str2, avg=True)
57-
rouge_l = scores['rouge-l']['f']
58-
return rouge_l
57+
return scores['rouge-l']['f']
5958

6059

6160
def caculate_rouge(results, data):
@@ -71,8 +70,7 @@ def caculate_rouge(results, data):
7170
rouge = score_rouge(target, prediction)
7271
rouges.append(rouge)
7372

74-
avg_rouge = sum(rouges) / len(rouges)
75-
return avg_rouge
73+
return sum(rouges) / len(rouges)
7674

7775

7876
########################
@@ -82,8 +80,7 @@ def similariry_score(str1, str2, model):
8280
# compute embedding for both lists
8381
embedding_1 = model.encode(str1, convert_to_tensor=True)
8482
embedding_2 = model.encode(str2, convert_to_tensor=True)
85-
score = util.pytorch_cos_sim(embedding_1, embedding_2).item()
86-
return score
83+
return util.pytorch_cos_sim(embedding_1, embedding_2).item()
8784

8885

8986
def caculate_similariry(results, data, model):
@@ -96,5 +93,4 @@ def caculate_similariry(results, data, model):
9693
score = similariry_score(target, prediction, model)
9794
scores.append(score)
9895

99-
avg_score = sum(scores) / len(scores)
100-
return avg_score
96+
return sum(scores) / len(scores)

main.py

Lines changed: 37 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -1,23 +1,25 @@
1+
import argparse
2+
import json
13
import os
4+
import random
5+
import re
6+
27
import numpy as np
38
import torch
4-
import os
5-
import re
6-
import json
7-
import argparse
8-
import random
9-
from transformers import T5Tokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, T5ForConditionalGeneration
10-
from model import T5ForConditionalGeneration, T5ForMultimodalGeneration
11-
from utils_data import img_shape, load_data_std, load_data_img, ScienceQADatasetStd, ScienceQADatasetImg
12-
from utils_prompt import *
13-
from utils_evaluate import get_scores
14-
from rich.table import Column, Table
159
from rich import box
1610
from rich.console import Console
11+
from rich.table import Column, Table
12+
from transformers import (DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments, T5ForConditionalGeneration, T5Tokenizer)
13+
14+
from model import T5ForConditionalGeneration, T5ForMultimodalGeneration
15+
from utils_data import (ScienceQADatasetImg, ScienceQADatasetStd, img_shape, load_data_img, load_data_std)
16+
from utils_evaluate import get_scores
17+
from utils_prompt import *
18+
1719
console = Console(record=True)
18-
from torch import cuda
19-
import nltk
2020
import evaluate
21+
import nltk
22+
from torch import cuda
2123

2224

2325
def parse_args():
@@ -36,7 +38,7 @@ def parse_args():
3638
parser.add_argument('--train_split', type=str, default='train', choices=['train', 'trainval', 'minitrain'])
3739
parser.add_argument('--val_split', type=str, default='val', choices=['test', 'val', 'minival'])
3840
parser.add_argument('--test_split', type=str, default='test', choices=['test', 'minitest'])
39-
41+
4042
parser.add_argument('--use_generate', action='store_true', help='only for baseline to improve inference speed')
4143
parser.add_argument('--final_eval', action='store_true', help='only evaluate the model at the final epoch')
4244
parser.add_argument('--user_msg', type=str, default="baseline", help='experiment type in the save_dir')
@@ -50,16 +52,15 @@ def parse_args():
5052
choices=['QCM-A', 'QCM-LE', 'QCMG-A', 'QCM-LEA', 'QCM-ALE'])
5153
parser.add_argument('--seed', type=int, default=42, help='random seed')
5254

53-
args = parser.parse_args()
54-
return args
55+
return parser.parse_args()
5556

5657
def T5Trainer(
5758
dataframe, args,
5859
):
5960
torch.manual_seed(args.seed) # pytorch random seed
6061
np.random.seed(args.seed) # numpy random seed
6162
torch.backends.cudnn.deterministic = True
62-
63+
6364
if args.evaluate_dir is not None:
6465
args.model = args.evaluate_dir
6566

@@ -72,7 +73,7 @@ def T5Trainer(
7273
train_qids = qids['train']
7374
test_qids = qids['test']
7475
val_qids = qids['val']
75-
76+
7677
if args.evaluate_dir is not None:
7778
save_dir = args.evaluate_dir
7879
else:
@@ -139,7 +140,7 @@ def T5Trainer(
139140
args,
140141
args.eval_le,
141142
)
142-
143+
143144
test_set = ScienceQADatasetStd(
144145
problems,
145146
test_qids,
@@ -155,11 +156,8 @@ def T5Trainer(
155156
def extract_ans(ans):
156157
pattern = re.compile(r'The answer is \(([A-Z])\)')
157158
res = pattern.findall(ans)
158-
159-
if len(res) == 1:
160-
answer = res[0] # 'A', 'B', ...
161-
else:
162-
answer = "FAILED"
159+
160+
answer = res[0] if len(res) == 1 else "FAILED"
163161
return answer
164162

165163
# accuracy for answer inference
@@ -184,7 +182,7 @@ def compute_metrics_acc(eval_preds):
184182
if reference == best_option:
185183
correct +=1
186184
return {'accuracy': 1.0*correct/len(targets)}
187-
185+
188186
# rougel for rationale generation
189187
metric = evaluate.load("rouge")
190188
def postprocess_text(preds, labels):
@@ -218,13 +216,13 @@ def compute_metrics_rougel(eval_preds):
218216
if args.final_eval:
219217
training_args = Seq2SeqTrainingArguments(
220218
save_dir,
221-
do_train=True if args.evaluate_dir is None else False,
219+
do_train=args.evaluate_dir is None,
222220
do_eval=False,
223221
evaluation_strategy="no",
224222
logging_strategy="steps",
225223
save_strategy="epoch",
226-
save_total_limit = 2,
227-
learning_rate= args.lr,
224+
save_total_limit=2,
225+
learning_rate=args.lr,
228226
eval_accumulation_steps=args.eval_acc,
229227
per_device_train_batch_size=args.bs,
230228
per_device_eval_batch_size=args.eval_bs,
@@ -233,23 +231,24 @@ def compute_metrics_rougel(eval_preds):
233231
predict_with_generate=args.use_generate,
234232
report_to="none",
235233
)
236-
# evaluate at each epoch
237234
else:
238235
training_args = Seq2SeqTrainingArguments(
239236
save_dir,
240-
do_train=True if args.evaluate_dir is None else False,
237+
do_train=args.evaluate_dir is None,
241238
do_eval=True,
242239
evaluation_strategy="epoch",
243240
logging_strategy="steps",
244241
save_strategy="epoch",
245-
save_total_limit = 2,
246-
learning_rate= args.lr,
242+
save_total_limit=2,
243+
learning_rate=args.lr,
247244
eval_accumulation_steps=args.eval_acc,
248245
per_device_train_batch_size=args.bs,
249246
per_device_eval_batch_size=args.eval_bs,
250247
weight_decay=0.01,
251248
num_train_epochs=args.epoch,
252-
metric_for_best_model="accuracy" if args.prompt_format != "QCM-LE" else "rougeL",
249+
metric_for_best_model="accuracy"
250+
if args.prompt_format != "QCM-LE"
251+
else "rougeL",
253252
predict_with_generate=args.use_generate,
254253
load_best_model_at_end=True,
255254
report_to="none",
@@ -268,12 +267,12 @@ def compute_metrics_rougel(eval_preds):
268267
if args.evaluate_dir is None:
269268
trainer.train()
270269
trainer.save_model(save_dir)
271-
270+
272271
metrics = trainer.evaluate(eval_dataset = test_set)
273272
trainer.log_metrics("test", metrics)
274273
trainer.save_metrics("test", metrics)
275274

276-
predict_results = trainer.predict(test_dataset=test_set, max_length=args.output_len)
275+
predict_results = trainer.predict(test_dataset=test_set, max_length=args.output_len)
277276
if trainer.is_world_process_zero():
278277
if args.use_generate:
279278
preds, targets = predict_results.predictions, predict_results.label_ids
@@ -292,7 +291,7 @@ def compute_metrics_rougel(eval_preds):
292291
results_ans = {}
293292
results_rationale = {}
294293
results_reference = {}
295-
294+
296295
num_fail = 0
297296
for idx, qid in enumerate(test_qids):
298297
pred = preds[int(idx)]
@@ -302,7 +301,7 @@ def compute_metrics_rougel(eval_preds):
302301
if extract_pred in args.options:
303302
extract_pred = args.options.index(extract_pred)
304303
else:
305-
extract_pred = random.choice(range(0,len(args.options)))
304+
extract_pred = random.choice(range(len(args.options)))
306305
else:
307306
num_fail += 1
308307
extract_pred = random.choice(range(len(args.options))) # random choose one option
@@ -320,7 +319,7 @@ def compute_metrics_rougel(eval_preds):
320319
output_prediction_file = os.path.join(save_dir,"predictions_ans_test.json")
321320
with open(output_prediction_file, "w") as writer:
322321
writer.write(json.dumps(output_data, indent=4))
323-
322+
324323
# generate the rationale for the eval set
325324
if args.prompt_format == "QCM-LE":
326325
torch.cuda.empty_cache()

model.py

Lines changed: 13 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -2,20 +2,19 @@
22
Adapted from https://github.com/huggingface/transformers
33
'''
44

5-
from transformers import T5Config, T5ForConditionalGeneration
6-
from transformers.models.t5.modeling_t5 import T5Stack, __HEAD_MASK_WARNING_MSG, T5EncoderModel
75
import copy
86
import math
97
import os
108
import warnings
119
from typing import Optional, Tuple, Union
10+
1211
import torch
1312
from torch import nn
1413
from torch.nn import CrossEntropyLoss
15-
from transformers.modeling_outputs import (
16-
BaseModelOutput,
17-
Seq2SeqLMOutput,
18-
)
14+
from transformers import T5Config, T5ForConditionalGeneration
15+
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
16+
from transformers.models.t5.modeling_t5 import (__HEAD_MASK_WARNING_MSG, T5EncoderModel, T5Stack)
17+
1918

2019
class T5ForMultimodalGeneration(T5ForConditionalGeneration):
2120
_keys_to_ignore_on_load_missing = [
@@ -87,10 +86,13 @@ def forward(
8786
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
8887

8988
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
90-
if head_mask is not None and decoder_head_mask is None:
91-
if self.config.num_layers == self.config.num_decoder_layers:
92-
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
93-
decoder_head_mask = head_mask
89+
if (
90+
head_mask is not None
91+
and decoder_head_mask is None
92+
and self.config.num_layers == self.config.num_decoder_layers
93+
):
94+
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
95+
decoder_head_mask = head_mask
9496

9597
# Encode if needed (training, first prediction pass)
9698
if encoder_outputs is None:
@@ -114,7 +116,7 @@ def forward(
114116

115117

116118
hidden_states = encoder_outputs[0]
117-
119+
118120
image_embedding = self.image_dense(image_ids)
119121
image_att, _ = self.mha_layer(hidden_states, image_embedding, image_embedding)
120122

utils_data.py

Lines changed: 12 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,10 @@
1-
import os
2-
from torch.utils.data import Dataset
3-
import os
41
import json
2+
import os
3+
54
import numpy as np
65
import torch
6+
from torch.utils.data import Dataset
7+
78
from utils_prompt import *
89

910
img_shape = {
@@ -20,9 +21,9 @@ def load_data_std(args):
2021
for qid in problems:
2122
problems[qid]['caption'] = captions[qid] if qid in captions else ""
2223

23-
train_qids = pid_splits['%s' % (args.train_split)]
24-
val_qids = pid_splits['%s' % (args.val_split)]
25-
test_qids = pid_splits['%s' % (args.test_split)]
24+
train_qids = pid_splits[f'{args.train_split}']
25+
val_qids = pid_splits[f'{args.val_split}']
26+
test_qids = pid_splits[f'{args.test_split}']
2627
print(f"number of train problems: {len(train_qids)}\n")
2728
print(f"number of val problems: {len(val_qids)}\n")
2829
print(f"number of test problems: {len(test_qids)}\n")
@@ -43,18 +44,16 @@ def load_data_img(args):
4344
image_features = image_features.repeat(512, axis=1)
4445
elif args.img_type == "clip":
4546
image_features = np.load('vision_features/clip.npy')
46-
elif args.img_type == "detr":
47-
image_features = np.load('vision_features/detr.npy')
4847
else:
4948
image_features = np.load('vision_features/detr.npy')
5049
print("img_features size: ", image_features.shape)
5150

5251
for qid in problems:
5352
problems[qid]['caption'] = captions[qid] if qid in captions else ""
5453

55-
train_qids = pid_splits['%s' % (args.train_split)]
56-
val_qids = pid_splits['%s' % (args.val_split)]
57-
test_qids = pid_splits['%s' % (args.test_split)]
54+
train_qids = pid_splits[f'{args.train_split}']
55+
val_qids = pid_splits[f'{args.val_split}']
56+
test_qids = pid_splits[f'{args.test_split}']
5857
print(f"number of train problems: {len(train_qids)}\n")
5958
print(f"number of val problems: {len(val_qids)}\n")
6059
print(f"number of test problems: {len(test_qids)}\n")
@@ -79,10 +78,7 @@ def __init__(
7978
self.summ_len = target_len
8079
self.target_text = []
8180
self.source_text = []
82-
if test_le is not None:
83-
test_le_data =json.load(open(test_le))["preds"]
84-
else:
85-
test_le_data = None
81+
test_le_data = None if test_le is None else json.load(open(test_le))["preds"]
8682
idx = 0
8783
for qid in self.data:
8884
if test_le_data is not None:
@@ -161,10 +157,7 @@ def __init__(
161157
self.target_text = []
162158
self.source_text = []
163159
self.image_ids = []
164-
if test_le is not None:
165-
test_le_data =json.load(open(test_le))["preds"]
166-
else:
167-
test_le_data = None
160+
test_le_data = None if test_le is None else json.load(open(test_le))["preds"]
168161
idx = 0
169162
for qid in self.data:
170163
if test_le_data is not None:

0 commit comments

Comments
 (0)