|
| 1 | +from posixpath import join |
| 2 | +import numpy |
| 3 | +from numpy.lib.npyio import save |
| 4 | +from script.data_iterator import DataIterator |
| 5 | +import tensorflow as tf |
| 6 | +import time |
| 7 | +import random |
| 8 | +import sys |
| 9 | +from script.utils import * |
| 10 | +from tensorflow.python.framework import ops |
| 11 | +import os |
| 12 | +import json |
| 13 | + |
# Model hyper-parameters. The hidden and attention widths are both twice the
# embedding width (18 -> 36).
EMBEDDING_DIM = 18
HIDDEN_SIZE = 2 * EMBEDDING_DIM
ATTENTION_SIZE = 2 * EMBEDDING_DIM

# Running "best so far" metric slots; not updated anywhere in the visible code.
best_auc = 0.0
best_case_acc = 0.0

# Arguments handed to DataIterator further down in this script.
batch_size = 1
maxlen = 100
| 21 | + |
def prepare_data(input, target, maxlen=None, return_neg=False):
    """Pad a batch of variable-length histories into dense numpy arrays.

    Each element of ``input`` is indexed positionally:
    ``[0]`` uid, ``[1]`` mid, ``[2]`` cat, ``[3]`` mid history,
    ``[4]`` cat history, ``[5]`` no-click mid history, ``[6]`` no-click cat
    history. The history length of a sample is taken from ``len(inp[4])``;
    the other histories are assumed to have the same length (TODO confirm
    against the data iterator).

    Args:
        input: sequence of samples laid out as described above. (The name
            shadows the builtin but is part of the public signature.)
        target: per-sample labels; returned as ``numpy.array(target)``.
        maxlen: if not None, histories longer than this keep only their
            last ``maxlen`` entries.
        return_neg: when True, also return the padded no-click histories.

    Returns:
        ``(None, None, None, None)`` when ``input`` is empty. Otherwise
        ``(uids, mids, cats, mid_his, cat_his, mid_mask, target, lengths)``
        and, if ``return_neg`` is True, additionally
        ``(noclk_mid_his, noclk_cat_his)``. Histories are int64 arrays of
        shape ``(n_samples, max_len)`` (no-click ones have an extra trailing
        axis for the negatives); ``mid_mask`` is float32 with 1.0 on the
        valid positions.
    """
    lengths_x = [len(s[4]) for s in input]
    seqs_mid = [inp[3] for inp in input]
    seqs_cat = [inp[4] for inp in input]
    noclk_seqs_mid = [inp[5] for inp in input]
    noclk_seqs_cat = [inp[6] for inp in input]

    if maxlen is not None:
        new_seqs_mid = []
        new_seqs_cat = []
        new_noclk_seqs_mid = []
        new_noclk_seqs_cat = []
        new_lengths_x = []
        for l_x, inp in zip(lengths_x, input):
            if l_x > maxlen:
                # Keep the tail: drop the oldest (l_x - maxlen) entries.
                start = l_x - maxlen
                new_seqs_mid.append(inp[3][start:])
                new_seqs_cat.append(inp[4][start:])
                new_noclk_seqs_mid.append(inp[5][start:])
                new_noclk_seqs_cat.append(inp[6][start:])
                new_lengths_x.append(maxlen)
            else:
                new_seqs_mid.append(inp[3])
                new_seqs_cat.append(inp[4])
                new_noclk_seqs_mid.append(inp[5])
                new_noclk_seqs_cat.append(inp[6])
                new_lengths_x.append(l_x)
        lengths_x = new_lengths_x
        seqs_mid = new_seqs_mid
        seqs_cat = new_seqs_cat
        noclk_seqs_mid = new_noclk_seqs_mid
        noclk_seqs_cat = new_noclk_seqs_cat

    if len(lengths_x) < 1:
        # Empty batch: the 4-tuple of Nones is kept for backward
        # compatibility with existing callers.
        return None, None, None, None

    n_samples = len(seqs_mid)
    maxlen_x = numpy.max(lengths_x)
    # Width of the negative-sample axis. Take it from the first sample that
    # actually has history; indexing [0][0] unconditionally raised IndexError
    # whenever the first sample's history was empty.
    neg_samples = next(
        (len(seq[0]) for seq in noclk_seqs_mid if len(seq) > 0), 0)

    mid_his = numpy.zeros((n_samples, maxlen_x)).astype('int64')
    cat_his = numpy.zeros((n_samples, maxlen_x)).astype('int64')
    noclk_mid_his = numpy.zeros(
        (n_samples, maxlen_x, neg_samples)).astype('int64')
    noclk_cat_his = numpy.zeros(
        (n_samples, maxlen_x, neg_samples)).astype('int64')
    mid_mask = numpy.zeros((n_samples, maxlen_x)).astype('float32')
    for idx, (s_x, s_y, no_sx, no_sy) in enumerate(
            zip(seqs_mid, seqs_cat, noclk_seqs_mid, noclk_seqs_cat)):
        cur_len = lengths_x[idx]
        if cur_len == 0:
            # Nothing to copy; the row stays all-zero padding. Assigning an
            # empty list into the 3-D no-click arrays would fail to broadcast.
            continue
        mid_mask[idx, :cur_len] = 1.
        mid_his[idx, :cur_len] = s_x
        cat_his[idx, :cur_len] = s_y
        noclk_mid_his[idx, :cur_len, :] = no_sx
        noclk_cat_his[idx, :cur_len, :] = no_sy

    uids = numpy.array([inp[0] for inp in input])
    mids = numpy.array([inp[1] for inp in input])
    cats = numpy.array([inp[2] for inp in input])

    if return_neg:
        return uids, mids, cats, mid_his, cat_his, mid_mask, numpy.array(
            target), numpy.array(lengths_x), noclk_mid_his, noclk_cat_his

    else:
        return uids, mids, cats, mid_his, cat_his, mid_mask, numpy.array(
            target), numpy.array(lengths_x)
| 88 | + |
| 89 | + |
# Paths to the evaluation split and the vocabulary pickles, all resolved
# relative to `data_location`.
data_location = 'data'
test_file = os.path.join(data_location, "local_test_splitByUser")
uid_voc = os.path.join(data_location, "uid_voc.pkl")
mid_voc = os.path.join(data_location, "mid_voc.pkl")
cat_voc = os.path.join(data_location, "cat_voc.pkl")

test_data = DataIterator(test_file,
                         uid_voc,
                         mid_voc,
                         cat_voc,
                         batch_size,
                         maxlen,
                         data_location=data_location)

# Counts completed batches; the loop below stops after the second one.
counter = 0

# Dump the first two batches to CSV, one line per batch. Each of the eight
# arrays is flattened and written as comma-separated values followed by a
# ",k," terminator. The context manager guarantees the file is closed even if
# the iterator or prepare_data raises.
with open("./test_data.csv", "w") as f:
    for src, tgt in test_data:
        uids, mids, cats, mid_his, cat_his, mid_mask, target, sl = prepare_data(src, tgt)

        all_data = [uids, mids, cats, mid_his, cat_his, mid_mask, target, sl]

        for cur_data in all_data:
            flat = numpy.squeeze(cur_data).reshape(-1)
            if flat.shape[0] == 0:
                # An empty array contributes nothing, not even the marker.
                continue
            f.write(",".join(str(value) for value in flat) + ",k,")

        f.write("\n")
        if counter >= 1:
            break
        counter += 1
| 130 | + |
| 131 | + |
| 132 | + |
| 133 | + |
| 134 | + |
| 135 | + |
| 136 | + |
0 commit comments