Commit 8a34f11: "first commit"
unknowed-ER committed Jan 9, 2025
Showing 48 changed files with 20,723 additions and 0 deletions.
ActionTopic.py (173 additions)
import torch
import torch.nn as nn
from tools import Tools

class Action(nn.Module):
    def __init__(self, args, p_encoder, a_decoder, hidden_size, main_encoder,
                 n_topic_vocab, bos_idx, max_len, glo2loc, loc2glo, vocab):
        super(Action, self).__init__()
        self.args = args
        self.p_encoder = p_encoder
        self.a_decoder = a_decoder
        self.main_encoder = main_encoder
        self.vocab = vocab
        self.hidden_size = hidden_size
        self.n_topic_vocab = n_topic_vocab
        self.bos_idx = bos_idx
        self.max_len = max_len
        self.glo2loc = glo2loc
        self.loc2glo = loc2glo
        # wrapped in Sequential so the Linear is reachable as gen_proj[0] in gen_prob()
        self.gen_proj = nn.Sequential(nn.Linear(self.hidden_size, self.n_topic_vocab))

    def forward(self, m, l, context, context_len, related_topics, related_topics_len,
                ar_gth, ar_gth_len, tp_path, tp_path_len, user_id, user_embed, mode,
                encoder_embed=None, decoder_embed=None, profile_prob=None, tp_path_embed=None):
        # Encode every available source; m, l and related_topics are optional.
        if l is not None:
            l_mask = l.new_ones(l.size(0), 1, l.size(1))
            l_hidden = self.p_encoder(l, l_mask, embed=encoder_embed)
        if tp_path_embed is None:
            tp_path = one_hot_scatter(tp_path, self.n_topic_vocab)
            tp_mask = Tools.get_mask_via_len(tp_path_len, self.args.state_num)
            tp_hidden = self.p_encoder(tp_path, tp_mask, embed=encoder_embed)
        else:
            # pre-computed topic-path embeddings are fed directly to the encoder
            tp_path = one_hot_scatter(tp_path, self.n_topic_vocab)
            tp_mask = Tools.get_mask_via_len(tp_path_len, self.args.state_num)
            tp_hidden = self.p_encoder(tp_path_embed, tp_mask, embed=encoder_embed, embed_input=True)
        context_mask = Tools.get_mask_via_len(context_len, self.args.context_max_len)
        context_hidden = self.main_encoder(context, context_mask)
        if related_topics is not None:
            related_topics = one_hot_scatter(related_topics, self.n_topic_vocab)
            related_topics_mask = Tools.get_mask_via_len(related_topics_len, self.args.relation_num)
            related_topic_hidden = self.p_encoder(related_topics, related_topics_mask, embed=encoder_embed)
        if m is not None:
            m_mask = m.new_ones(m.size(0), 1, m.size(1))
            m_hidden = self.p_encoder(m, m_mask, embed=encoder_embed)

        # Concatenate whichever sources exist into a single decoder memory.
        if m is None:
            src_hidden = torch.cat([tp_hidden, context_hidden, related_topic_hidden], 1)
            src_mask = torch.cat([tp_mask, context_mask, related_topics_mask], 2)
        elif related_topics is None:
            src_hidden = torch.cat([m_hidden, l_hidden, tp_hidden, context_hidden], 1)
            src_mask = torch.cat([m_mask, l_mask, tp_mask, context_mask], 2)
        elif l is None:
            src_hidden = torch.cat([m_hidden, tp_hidden, context_hidden, related_topic_hidden], 1)
            src_mask = torch.cat([m_mask, tp_mask, context_mask, related_topics_mask], 2)
        else:
            src_hidden = torch.cat([m_hidden, l_hidden, tp_hidden, context_hidden, related_topic_hidden], 1)
            src_mask = torch.cat([m_mask, l_mask, tp_mask, context_mask, related_topics_mask], 2)

        probs = None
        action_mask = Tools.get_mask_via_len(ar_gth_len, self.args.action_num)
        if mode == 'train':
            # Teacher forcing: decode every second position against the gold prefix.
            for i in range(0, self.args.action_num, 2):
                seq_gth = ar_gth[:, 0:i + 1]
                ar_mask = action_mask[:, :, 0:i + 1]
                dec_output = Tools._single_decode(seq_gth.detach(), src_hidden, src_mask, self.a_decoder, ar_mask)
                prob = self.proj(dec_output, src_hidden, src_mask, m, l, context, tp_path, related_topics,
                                 embed=decoder_embed, profile_prob=profile_prob)
                if i == 0:
                    probs = prob
                else:
                    probs = torch.cat([probs, prob], 1)
            return probs
        else:
            # Greedy decoding: every second token is generated and fed back;
            # the alternating tokens are taken from the ground truth.
            seq_gen = None
            for i in range(0, self.args.action_num, 2):
                if i == 0:
                    seq_gen = ar_gth[:, 0:i + 1]
                else:
                    seq_gen = torch.cat([seq_gen, ar_gth[:, i:i + 1]], 1)
                ar_mask = action_mask[:, :, 0:i + 1]
                dec_output = Tools._single_decode(seq_gen.detach(), src_hidden, src_mask, self.a_decoder, ar_mask)
                single_step_prob = self.proj(dec_output, src_hidden, src_mask, m, l, context, tp_path, related_topics,
                                             embed=decoder_embed, profile_prob=profile_prob)
                if i == 0:
                    probs = single_step_prob
                else:
                    probs = torch.cat([probs, single_step_prob], 1)
                single_step_word = torch.argmax(single_step_prob, -1)
                seq_gen = torch.cat([seq_gen, single_step_word], 1)
            return seq_gen, probs

    def proj(self, dec_out, src_hidden, src_mask, pv_m, l, context, tp, related_topics,
             embed=None, profile_prob=None):
        B, L_a = dec_out.size(0), dec_out.size(1)
        # Generation logits over the topic vocabulary plus copy logits over the source memory.
        gen_logit = self.gen_prob(dec_out, embed)
        copy_logit = torch.bmm(dec_out, src_hidden.permute(0, 2, 1))
        copy_logit = copy_logit.masked_fill((src_mask == 0).expand(-1, L_a, -1),
                                            -1e9 if copy_logit.dtype == torch.float32 else -1e4)
        logits = torch.cat([gen_logit, copy_logit], -1)
        if self.args.scale_prj:
            logits *= self.hidden_size ** -0.5
        probs = torch.softmax(logits, -1)
        gen_prob = probs[:, :, :self.n_topic_vocab]
        if self.args.not_copynet:
            probs = gen_prob
        elif pv_m is None:
            # Copy memory layout: [topic path | context | related topics].
            off = self.n_topic_vocab
            copy_tp_prob = torch.bmm(probs[:, :, off:off + self.args.state_num], tp)
            off += self.args.state_num
            copy_context_prob = probs[:, :, off:off + self.args.context_max_len]
            transfer_context_word = torch.gather(self.glo2loc.unsqueeze(0).expand(B, -1), 1, context)
            copy_context_temp = copy_context_prob.new_zeros(B, L_a, self.n_topic_vocab)
            copy_context_prob = copy_context_temp.scatter_add(
                dim=2, index=transfer_context_word.unsqueeze(1).expand(-1, L_a, -1), src=copy_context_prob)
            copy_relation_prob = torch.bmm(probs[:, :, off + self.args.context_max_len:], related_topics)
            probs = gen_prob + copy_tp_prob + copy_context_prob + copy_relation_prob
        elif related_topics is None:
            # Copy memory layout: [m | l | topic path | context].
            off = self.n_topic_vocab
            copy_m_prob = torch.bmm(probs[:, :, off:off + self.args.preference_num], pv_m)
            off += self.args.preference_num
            copy_l_prob = torch.bmm(probs[:, :, off:off + self.args.profile_num], l)
            off += self.args.profile_num
            copy_tp_prob = torch.bmm(probs[:, :, off:off + self.args.state_num], tp)
            off += self.args.state_num
            copy_context_prob = probs[:, :, off:off + self.args.context_max_len]
            transfer_context_word = torch.gather(self.glo2loc.unsqueeze(0).expand(B, -1), 1, context)
            copy_context_temp = copy_context_prob.new_zeros(B, L_a, self.n_topic_vocab)
            copy_context_prob = copy_context_temp.scatter_add(
                dim=2, index=transfer_context_word.unsqueeze(1).expand(-1, L_a, -1), src=copy_context_prob)
            probs = gen_prob + copy_l_prob + copy_tp_prob + copy_context_prob + copy_m_prob
        elif l is None:
            # Copy memory layout: [m | topic path | context | related topics].
            off = self.n_topic_vocab
            copy_m_prob = torch.bmm(probs[:, :, off:off + self.args.preference_num], pv_m)
            off += self.args.preference_num
            copy_tp_prob = torch.bmm(probs[:, :, off:off + self.args.state_num], tp)
            off += self.args.state_num
            copy_context_prob = probs[:, :, off:off + self.args.context_max_len]
            transfer_context_word = torch.gather(self.glo2loc.unsqueeze(0).expand(B, -1), 1, context)
            copy_context_temp = copy_context_prob.new_zeros(B, L_a, self.n_topic_vocab)
            copy_context_prob = copy_context_temp.scatter_add(
                dim=2, index=transfer_context_word.unsqueeze(1).expand(-1, L_a, -1), src=copy_context_prob)
            copy_relation_prob = torch.bmm(probs[:, :, off + self.args.context_max_len:], related_topics)
            probs = gen_prob + copy_tp_prob + copy_context_prob + copy_relation_prob + copy_m_prob
        else:
            # Copy memory layout: [m | l | topic path | context | related topics].
            off = self.n_topic_vocab
            copy_m_prob = torch.bmm(probs[:, :, off:off + self.args.preference_num], pv_m)
            off += self.args.preference_num
            copy_l_prob = torch.bmm(probs[:, :, off:off + self.args.profile_num], l)
            off += self.args.profile_num
            copy_tp_prob = torch.bmm(probs[:, :, off:off + self.args.state_num], tp)
            off += self.args.state_num
            copy_context_prob = probs[:, :, off:off + self.args.context_max_len]
            transfer_context_word = torch.gather(self.glo2loc.unsqueeze(0).expand(B, -1), 1, context)
            copy_context_temp = copy_context_prob.new_zeros(B, L_a, self.n_topic_vocab)
            copy_context_prob = copy_context_temp.scatter_add(
                dim=2, index=transfer_context_word.unsqueeze(1).expand(-1, L_a, -1), src=copy_context_prob)
            copy_relation_prob = torch.bmm(probs[:, :, off + self.args.context_max_len:], related_topics)
            probs = gen_prob + copy_l_prob + copy_tp_prob + copy_context_prob + copy_relation_prob + copy_m_prob
        if self.args.topic_copynet:
            probs = probs + profile_prob.unsqueeze(1)
        return probs

    def gen_prob(self, dec_output, embed=None):
        # Project to the topic vocabulary, either with the trained head or with an
        # externally supplied output embedding matrix.
        if embed is None:
            prob = self.gen_proj(dec_output)
        else:
            assert embed.size(0) == self.gen_proj[0].out_features
            prob = dec_output.matmul(embed.T)
        return prob


def one_hot_scatter(indice, num_classes, dtype=torch.float):
    indice_shape = list(indice.shape)
    placeholder = torch.zeros(*(indice_shape + [num_classes]), device=indice.device, dtype=dtype)
    v = 1 if dtype == torch.long else 1.0
    placeholder.scatter_(-1, indice.unsqueeze(-1), v)
    return placeholder
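
A quick standalone check of one_hot_scatter (illustrative only: the toy ids and vocabulary size are invented, and the import assumes ActionTopic.py and its tools module are importable). The one-hot layout is what lets proj() map copy probabilities back onto the vocabulary with torch.bmm.

import torch
from ActionTopic import one_hot_scatter  # assumes ActionTopic.py is on the path

topics = torch.tensor([[2, 0, 1]])    # (batch=1, seq=3) topic ids
one_hot = one_hot_scatter(topics, 4)  # -> (1, 3, 4), float
print(one_hot[0])
# tensor([[0., 0., 1., 0.],
#         [1., 0., 0., 0.],
#         [0., 1., 0., 0.]])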
Action_all.py (102 additions)
import torch
import torch.nn as nn
from tools import Tools

class Action(nn.Module):
    def __init__(self, args, p_encoder, a_decoder, hidden_size, main_encoder,
                 m_encoder, n_topic_vocab, bos_idx, max_len, glo2loc, loc2glo, vocab):
        super(Action, self).__init__()
        self.args = args
        self.p_encoder = p_encoder
        self.a_decoder = a_decoder
        self.m_encoder = m_encoder
        self.main_encoder = main_encoder
        self.vocab = vocab
        self.hidden_size = hidden_size
        self.n_topic_vocab = n_topic_vocab
        self.bos_idx = bos_idx
        self.max_len = max_len
        self.glo2loc = glo2loc
        self.loc2glo = loc2glo
        self.gen_proj = nn.Linear(self.hidden_size, self.n_topic_vocab)
        self.topic2movie = nn.Linear(2583, self.n_topic_vocab - 2583)
        # Vocabulary indices >= 2583 are movie ids; the mask restricts the final
        # distribution to movies before renormalization in proj().
        self.mask = torch.zeros(self.args.batch_size, 1, self.n_topic_vocab).cuda()
        self.mask[:, :, 2583:] = 1
        self.pad = torch.zeros(self.args.batch_size, 1, 2583).cuda()
    def forward(self, m, l, context, context_len, ar_gth, ar_gth_len,
                tp_path, tp_path_len, related_movies, related_movies_len, mode):
        if mode == 'test':
            # At test time m and l arrive as index tensors and must be one-hot encoded.
            m = one_hot_scatter(m, self.n_topic_vocab)
            l = one_hot_scatter(l, self.n_topic_vocab)
        m_mask = m.new_ones(m.size(0), 1, m.size(1))
        m_hidden = self.p_encoder(m, m_mask)
        l_mask = l.new_ones(l.size(0), 1, l.size(1))
        l_hidden = self.p_encoder(l, l_mask)
        tp_path = one_hot_scatter(tp_path, self.n_topic_vocab)
        tp_mask = Tools.get_mask_via_len(tp_path_len, self.args.state_num)
        tp_path_hidden = self.p_encoder(tp_path, tp_mask)
        context_mask = Tools.get_mask_via_len(context_len, self.args.context_max_len)
        context_hidden = self.main_encoder(context, context_mask)
        related_movies = one_hot_scatter(related_movies, self.n_topic_vocab)
        related_movies_mask = Tools.get_mask_via_len(related_movies_len, self.args.movie_num)
        related_movies_hidden = self.p_encoder(related_movies, related_movies_mask)
        src_hidden = torch.cat([m_hidden, l_hidden, context_hidden, tp_path_hidden, related_movies_hidden], 1)
        src_mask = torch.cat([m_mask, l_mask, context_mask, tp_mask, related_movies_mask], 2)
        action_mask = Tools.get_mask_via_len(ar_gth_len, self.args.action_num)
        if mode == 'train':
            # Single-step teacher forcing on the first action token.
            seq_gth = ar_gth[:, [0]]
            ar_mask = action_mask[:, :, [0]]
            dec_output = Tools._single_decode(seq_gth.detach(), src_hidden, src_mask, self.a_decoder, ar_mask)
            prob = self.proj(dec_out=dec_output, src_hidden=src_hidden, src_mask=src_mask,
                             tp=tp_path, m=m, l=l, context=context, related_movies=related_movies)
            return prob
        else:
            seq_gen = ar_gth[:, [0]]
            ar_mask = action_mask[:, :, [0]]
            dec_output = Tools._single_decode(seq_gen.detach(), src_hidden, src_mask, self.a_decoder, ar_mask)
            prob = self.proj(dec_out=dec_output, src_hidden=src_hidden, src_mask=src_mask,
                             tp=tp_path, m=m, l=l, context=context, related_movies=related_movies)
            word = torch.argmax(prob, -1)
            return word, prob

    def proj(self, dec_out, src_hidden, src_mask, tp, m, l, context, related_movies):
        B, L_a = dec_out.size(0), dec_out.size(1)
        gen_logit = self.gen_proj(dec_out)
        copy_logit = torch.bmm(dec_out, src_hidden.permute(0, 2, 1))
        copy_logit = copy_logit.masked_fill((src_mask == 0).expand(-1, L_a, -1), -1e9)
        logits = torch.cat([gen_logit, copy_logit], -1)
        if self.args.scale_prj:
            logits *= self.hidden_size ** -0.5
        probs = torch.softmax(logits, -1)
        gen_prob = probs[:, :, :self.n_topic_vocab]
        # Copy memory layout: [m | l | context | topic path | related movies].
        copy_m = probs[:, :, self.n_topic_vocab:self.n_topic_vocab + self.args.preference_num]
        copy_m_prob = torch.bmm(copy_m, m)
        copy_l = probs[:, :, self.n_topic_vocab + self.args.preference_num:
                       self.n_topic_vocab + self.args.preference_num + self.args.profile_num]
        copy_l_prob = torch.bmm(copy_l, l)
        off = self.n_topic_vocab + self.args.preference_num + self.args.profile_num
        copy_context_prob = probs[:, :, off:off + self.args.context_max_len]
        transfer_context_word = torch.gather(self.glo2loc.unsqueeze(0).expand(B, -1), 1, context)
        copy_context_temp = copy_context_prob.new_zeros(B, L_a, self.n_topic_vocab)
        copy_context_prob = copy_context_temp.scatter_add(
            dim=2, index=transfer_context_word.unsqueeze(1).expand(-1, L_a, -1), src=copy_context_prob)
        off += self.args.context_max_len
        copy_tp = probs[:, :, off:off + self.args.state_num]
        copy_tp_prob = torch.bmm(copy_tp, tp)
        copy_relation = probs[:, :, off + self.args.state_num:]
        copy_relation_prob = torch.bmm(copy_relation, related_movies)
        probs = gen_prob + copy_m_prob + copy_l_prob + copy_context_prob + copy_tp_prob + copy_relation_prob
        # Zero out non-movie entries and renormalize so the output is a distribution over movies.
        probs = probs.mul(self.mask)
        norm = torch.sum(probs, -1, keepdim=True)
        probs = probs / norm
        return probs

def one_hot_scatter(indice, num_classes, dtype=torch.float):
    indice_shape = list(indice.shape)
    placeholder = torch.zeros(*(indice_shape + [num_classes]), device=indice.device, dtype=dtype)
    v = 1 if dtype == torch.long else 1.0
    placeholder.scatter_(-1, indice.unsqueeze(-1), v)
    return placeholder
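
The mask-and-renormalize step at the end of proj() is the part most worth sanity-checking. Below is a minimal self-contained sketch of the same idea with made-up sizes (the split index 2583 above is the repository's topic/movie boundary; here it is shrunk to 4):

import torch

n_vocab, movie_start = 6, 4                            # toy stand-ins for n_topic_vocab and 2583
probs = torch.softmax(torch.randn(2, 1, n_vocab), -1)  # mixed topic+movie distribution

mask = torch.zeros(2, 1, n_vocab)
mask[:, :, movie_start:] = 1                           # keep only movie ids
movie_probs = probs * mask
movie_probs = movie_probs / movie_probs.sum(-1, keepdim=True)  # renormalize to sum to 1

assert torch.allclose(movie_probs.sum(-1), torch.ones(2, 1))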
Bleu.py (57 additions)

import jieba
from nltk.translate.bleu_score import sentence_bleu

def bleu_cal(sen1, tar1):
    bleu1 = sentence_bleu([tar1], sen1, weights=(1, 0, 0, 0))
    bleu2 = sentence_bleu([tar1], sen1, weights=(0.5, 0.5, 0, 0))
    bleu3 = sentence_bleu([tar1], sen1, weights=(0.33, 0.33, 0.33, 0))
    bleu4 = sentence_bleu([tar1], sen1, weights=(0.25, 0.25, 0.25, 0.25))
    return bleu1, bleu2, bleu3, bleu4

def bleu(args, tokenized_gen, tokenized_tar):
    bleu1_sum, bleu2_sum, bleu3_sum, bleu4_sum, count = 0, 0, 0, 0, 0
    for sen, tar in zip(tokenized_gen, tokenized_tar):
        # Truncate both sequences at the end-of-response token.
        for j, word in enumerate(sen):
            if word == args.EOS_RESPONSE:
                sen = sen[:j]
                break
        tar = tar[1:]  # drop the leading (BOS) token from the reference
        for k, word in enumerate(tar):
            if word == args.EOS_RESPONSE:
                tar = tar[:k]
                break
        full_sen_gen = ''.join(sen)
        full_sen_gth = ''.join(tar)
        # Re-segment with jieba, keeping '<movie>' placeholders as single tokens.
        sen_split_by_movie = full_sen_gen.split('<movie>')
        sen_1 = []
        for i, sen_split in enumerate(sen_split_by_movie):
            sen_1.extend(jieba.cut(sen_split))
            if i != len(sen_split_by_movie) - 1:
                sen_1.append('<movie>')
        tar_split_by_movie = full_sen_gth.split('<movie>')
        tar_1 = []
        for i, tar_split in enumerate(tar_split_by_movie):
            tar_1.extend(jieba.cut(tar_split))
            if i != len(tar_split_by_movie) - 1:
                tar_1.append('<movie>')
        bleu1, bleu2, bleu3, bleu4 = bleu_cal(sen_1, tar_1)
        bleu1_sum += bleu1
        bleu2_sum += bleu2
        bleu3_sum += bleu3
        bleu4_sum += bleu4
        count += 1
    return bleu1_sum / count, bleu2_sum / count, bleu3_sum / count, bleu4_sum / count
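
A minimal usage sketch of the same scoring call on pre-tokenized sentences (the toy tokens are invented; nltk is required). Note that bleu_cal applies no smoothing, so BLEU-4 can be 0 on short sentences and nltk will warn; the smoothed call below is shown only for comparison:

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

hyp = ['我', '喜欢', '这部', '<movie>']        # candidate tokens
ref = ['我', '也', '喜欢', '这部', '<movie>']  # reference tokens

b1 = sentence_bleu([ref], hyp, weights=(1, 0, 0, 0))
b4 = sentence_bleu([ref], hyp, weights=(0.25, 0.25, 0.25, 0.25),
                   smoothing_function=SmoothingFunction().method1)
print(f'BLEU-1={b1:.3f}  BLEU-4={b4:.3f}')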