Skip to content

Commit

Permalink
add 'raise NotImplementedError' in alg_type
Browse files Browse the repository at this point in the history
Raise NotImplementedError when alg_type is not set as one of ['bi', 'kgat', 'gcn', 'graphsage'], in order to solve some errors like 'no ua_embeddings'.
  • Loading branch information
xiangwang1223 authored Jun 19, 2020
1 parent b10ed2b commit 530327a
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion Model/KGAT.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,17 @@ def _build_weights(self):
return all_weights

def _build_model_phase_I(self):
if self.alg_type in ['bi']:
if self.alg_type in ['bi', 'kgat']:
self.ua_embeddings, self.ea_embeddings = self._create_bi_interaction_embed()

elif self.alg_type in ['gcn']:
self.ua_embeddings, self.ea_embeddings = self._create_gcn_embed()

elif self.alg_type in ['graphsage']:
self.ua_embeddings, self.ea_embeddings = self._create_graphsage_embed()
else:
print('please check the the alg_type argument, which should be bi, kgat, gcn, or graphsage.')
raise NotImplementedError

self.u_e = tf.nn.embedding_lookup(self.ua_embeddings, self.users)
self.pos_i_e = tf.nn.embedding_lookup(self.ea_embeddings, self.pos_items)
Expand Down

0 comments on commit 530327a

Please sign in to comment.