Skip to content

Commit 2b7da3e

Browse files
committed
selecting data to train
1 parent 28971c0 commit 2b7da3e

File tree

3 files changed

+15
-12
lines changed

3 files changed

+15
-12
lines changed

config.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ class Config:
88
def __init__(self):
99
"""Experiment configuration"""
1010

11-
self.mode = "train" # TODO: Why do we want this in config?
11+
self.mode = "dev"
1212

1313
# Device params
1414
# self.use_cuda = torch.cuda.is_available()
@@ -17,7 +17,7 @@ def __init__(self):
1717
self.use_cuda = False
1818

1919
# Global dimension params
20-
self.embedding_dim = 100
20+
self.embedding_dim = 50
2121
self.hidden_size = self.embedding_dim
2222
self.context_len = 200 # TODO: Why do we need this?
2323
self.question_len = 20 # TODO: Why do we need this?

data_utils.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ def get_data(batch, is_train=True):
2323
qn_seq_var = torch.from_numpy(batch.qn_ids).long()
2424
context_seq_var = torch.from_numpy(batch.context_ids).long()
2525

26-
if config.mode == "train":
27-
span_var = torch.from_numpy(batch.ans_span).long()
26+
span_var = torch.from_numpy(batch.ans_span).long()
2827

2928
if is_train:
3029
qn_mask_var = qn_mask_var.to(config.device)
@@ -34,10 +33,7 @@ def get_data(batch, is_train=True):
3433
if is_train:
3534
span_var = span_var.to(config.device)
3635

37-
if config.mode == "train":
38-
return qn_seq_var, qn_mask_var, context_seq_var, context_mask_var, span_var
39-
else:
40-
return qn_seq_var, qn_mask_var, context_seq_var, context_mask_var
36+
return qn_seq_var, qn_mask_var, context_seq_var, context_mask_var, span_var
4137

4238

4339
import time

train.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def evaluate():
5656
pass
5757

5858

59-
def train(*args, **kwargs):
59+
def train(context_path, qn_path, ans_path):
6060
""" Train the network """
6161

6262
model = N.CoattentionNetwork(device=config.device,
@@ -109,8 +109,7 @@ def train(*args, **kwargs):
109109
# For each epoch
110110
for epoch in range(config.num_epochs):
111111
# For each batch
112-
for i, batch in enumerate(get_batch_generator(word2index, train_context_path,
113-
train_qn_path, train_ans_path,
112+
for i, batch in enumerate(get_batch_generator(word2index, context_path, qn_path, ans_path,
114113
config.batch_size, config.context_len,
115114
config.question_len, discard_long=True)):
116115
# Take step in training
@@ -136,4 +135,12 @@ def train(*args, **kwargs):
136135

137136

138137
if __name__ == '__main__':
139-
train()
138+
if config.mode == 'train':
139+
context_path = train_context_path
140+
qn_path = train_qn_path
141+
ans_path = train_ans_path
142+
else:
143+
context_path = dev_context_path
144+
qn_path = dev_qn_path
145+
ans_path = dev_ans_path
146+
train(context_path, qn_path, ans_path)

0 commit comments

Comments
 (0)