diff --git a/train_embeddings/main.py b/train_embeddings/main.py index b5f198e..716d736 100644 --- a/train_embeddings/main.py +++ b/train_embeddings/main.py @@ -62,17 +62,14 @@ def get_batch(self, er_vocab, er_vocab_pairs, idx): def evaluate(self, model, data): model.eval() - hits = [] + hits = [[] for _ in range(10)] ranks = [] - for i in range(10): - hits.append([]) - test_data_idxs = self.get_data_idxs(data) - er_vocab = self.get_er_vocab(self.get_data_idxs(d.data)) + er_vocab = self.get_er_vocab(test_data_idxs) print("Number of data points: %d" % len(test_data_idxs)) for i in tqdm(range(0, len(test_data_idxs), self.batch_size)): - data_batch, _ = self.get_batch(er_vocab, test_data_idxs, i) + data_batch = np.array(test_data_idxs[i: i+self.batch_size]) e1_idx = torch.tensor(data_batch[:,0]) r_idx = torch.tensor(data_batch[:,1]) e2_idx = torch.tensor(data_batch[:,2])