Skip to content

Commit

Permalink
play around a bit with different kinds of cells
Browse files Browse the repository at this point in the history
  • Loading branch information
DNGros committed Mar 29, 2019
1 parent 0344233 commit e1c0c27
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 38 deletions.
143 changes: 138 additions & 5 deletions ainix_kernel/models/EncoderDecoder/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,22 +124,40 @@ def end_train_session(self):
pass


class TreeRNNCell(nn.Module):
class TreeRNNCell(nn.Module, ABC):
    """Common interface for the cells used to step through a tree decoder.

    Concrete cells subclass this and implement ``forward``. The base class
    only records the AST node embedding size as ``input_size``.
    """
    def __init__(self, ast_node_embed_size: int, hidden_size):
        super().__init__()
        # hidden_size is accepted so every cell shares one constructor
        # signature; the base class itself does not store it.
        self.input_size = ast_node_embed_size

    def forward(
        self,
        last_hidden: torch.Tensor,
        type_to_predict_features: torch.Tensor,
        parent_node_features: torch.Tensor,
        parent_node_hidden: torch.Tensor,
        memory_tokens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one step of the cell. Must be overridden by subclasses.

        Args:
            last_hidden: State returned by the previous step.
            type_to_predict_features: Features of the type being predicted.
            parent_node_features: Features of the parent node.
            parent_node_hidden: Hidden state of the parent node.
            memory_tokens: Context made available to the cell.

        Returns:
            Tuple of tensors. First is the thing to predict on. Second is
            the internal state to pass forward.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # NotImplemented is a comparison sentinel, not an exception; calling
        # it raises TypeError. NotImplementedError is the intended signal.
        raise NotImplementedError()


class TreeRNNCellLSTM(TreeRNNCell):
"""An rnn cell in a tree RNN"""
def __init__(self, ast_node_embed_size: int, hidden_size, attn: bool = True):
super().__init__(ast_node_embed_size, hidden_size)
self.rnn = nn.LSTMCell(ast_node_embed_size, hidden_size)
# self.rnn = nn.GRUCell(ast_node_embed_size, hidden_size)
self.root_node_features = nn.Parameter(torch.rand(hidden_size))
self.dropout = nn.Dropout(p=0.1)
self.attn_query_lin = nn.Linear(hidden_size*3, hidden_size)

def forward(
self,
last_hidden: torch.Tensor,
type_to_predict_features: torch.Tensor,
parent_node_features: torch.Tensor,
parent_node_hidden: torch.Tensor,
memory_tokens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Expand All @@ -154,15 +172,127 @@ def forward(
internal state to pass forward.
"""
# TODO (DNGros): Use parent hidden data

# Normal LSTM
if parent_node_features is None:
num_of_batches = len(type_to_predict_features)
parent_node_features = self.root_node_features.expand(num_of_batches, -1)

attn_query = self.attn_query_lin(
torch.cat((parent_node_features, type_to_predict_features, last_hidden), dim=1))
# The attend function expects dim (batch_size_q, num_queries, hidden)
# right now we are only (batch_size, hidden), so need to unsqueeze
attn_query = attn_query.unsqueeze(0)
attn_result = attend.attend(attn_query, context=memory_tokens)
attn_result = attn_result.squeeze(1)
last_hidden = attn_result

out, next_hidden = self.rnn(type_to_predict_features,
(parent_node_features, last_hidden))
next_hidden = self.dropout(next_hidden)
return out, next_hidden


class TreeCellOnlyAttn(TreeRNNCell):
    """A tree decoder cell with no recurrent unit, only attention.

    The features of the type being predicted are used (after dropout) as an
    attention query over ``memory_tokens``; the attention result is passed
    through a linear layer and returned as both the output and the state.
    """
    def __init__(self, ast_node_embed_size: int, hidden_size, attn: bool = True):
        super().__init__(ast_node_embed_size, hidden_size)
        # Not read by forward(); kept so the module's registered parameters
        # stay the same as before. `attn` is likewise accepted but unused.
        self.root_node_features = nn.Parameter(torch.rand(hidden_size))
        self.dropout = nn.Dropout(p=0.1)
        self.after_attn_lin = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        last_hidden: torch.Tensor,
        type_to_predict_features: torch.Tensor,
        parent_node_features: torch.Tensor,
        parent_node_hidden: torch.Tensor,
        memory_tokens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend over the memory using only the type features as the query.

        Args:
            last_hidden: Ignored by this cell.
            type_to_predict_features: Features of the type being predicted;
                used as the attention query.
            parent_node_features: Ignored by this cell.
            parent_node_hidden: Ignored by this cell.
            memory_tokens: Context the query attends over.

        Returns:
            Tuple of tensors. First is the thing to predict on. Second is
            the internal state to pass forward. This cell returns the same
            tensor for both.
        """
        attn_query = self.dropout(type_to_predict_features)
        # The attend function expects dim (batch_size_q, num_queries, hidden)
        # right now we are only (batch_size, hidden), so need to unsqueeze.
        # NOTE(review): unsqueeze(0) followed by squeeze(1) only round-trips
        # when batch_size is 1 -- confirm before using real batches.
        attn_query = attn_query.unsqueeze(0)
        attn_result = attend.attend(attn_query, context=memory_tokens)
        attn_result = attn_result.squeeze(1)
        attn_result = self.after_attn_lin(attn_result)

        return attn_result, attn_result


class TreeRNNCellGRU(TreeRNNCell):
    """A tree decoder cell built on a GRU cell with attention over memory.

    An attention query is formed from the parent node features, the features
    of the type being predicted and the previous state. The attention result
    is used as the GRU's hidden state for this step.
    """
    def __init__(self, ast_node_embed_size: int, hidden_size, attn: bool = True):
        super().__init__(ast_node_embed_size, hidden_size)
        # `attn` is accepted for signature parity with the other cells but
        # is not read.
        self.rnn = nn.GRUCell(ast_node_embed_size, hidden_size)
        # Stand-in for parent features when the node has no parent.
        self.root_node_features = nn.Parameter(torch.rand(hidden_size))
        self.dropout = nn.Dropout(p=0.1)
        # Input is the concatenation of three feature vectors; this assumes
        # each of them is hidden_size wide.
        self.attn_query_lin = nn.Linear(hidden_size*3, hidden_size)

    def forward(
        self,
        last_hidden: torch.Tensor,
        type_to_predict_features: torch.Tensor,
        parent_node_features: torch.Tensor,
        parent_node_hidden: torch.Tensor,
        memory_tokens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one GRU step whose hidden state comes from attention.

        Args:
            last_hidden: State returned by the previous step.
            type_to_predict_features: Features of the type being predicted;
                fed to the GRU as its input.
            parent_node_features: Features of the parent node, or None, in
                which case the learned root features are used.
            parent_node_hidden: Currently unused.
            memory_tokens: Context the query attends over.

        Returns:
            Tuple of tensors. First is the thing to predict on. Second is
            the internal state to pass forward (the first with dropout
            applied).
        """
        # TODO (DNGros): Use parent hidden data

        if parent_node_features is None:
            num_of_batches = len(type_to_predict_features)
            parent_node_features = self.root_node_features.expand(num_of_batches, -1)

        attn_query = self.attn_query_lin(
            self.dropout(
                torch.cat((parent_node_features, type_to_predict_features, last_hidden), dim=1)
            )
        )
        # The attend function expects dim (batch_size_q, num_queries, hidden)
        # right now we are only (batch_size, hidden), so need to unsqueeze.
        # NOTE(review): unsqueeze(0) followed by squeeze(1) only round-trips
        # when batch_size is 1 -- confirm before using real batches.
        attn_query = attn_query.unsqueeze(0)
        attn_result = attend.attend(attn_query, context=memory_tokens)
        attn_result = attn_result.squeeze(1)

        # A GRU cell has a single hidden tensor and returns a single tensor.
        # The previous call passed an LSTM-style (h, c) tuple and unpacked
        # two results, which nn.GRUCell does not support.
        out = self.rnn(type_to_predict_features, attn_result)
        next_hidden = self.dropout(out)
        return out, next_hidden


#@attr.s
Expand Down Expand Up @@ -219,7 +349,8 @@ def _inference_objectchoice_step(
last_hidden=last_hidden,
type_to_predict_features=self._get_obj_choice_features(current_leaf),
parent_node_features=parent_node_features,
parent_node_hidden=None
parent_node_hidden=None,
memory_tokens=memory_tokens
)
if len(outs) != 1:
raise NotImplemented("Batches not implemented")
Expand Down Expand Up @@ -257,6 +388,7 @@ def _train_objectchoice_step(
type_to_predict_features=self._get_obj_choice_features(teacher_force_path),
parent_node_features=parent_node_features,
parent_node_hidden=None,
memory_tokens=memory_tokens
)

loss = self.action_selector.forward_train(
Expand Down Expand Up @@ -478,7 +610,8 @@ def get_default_nonretrieval_decoder(
) -> TreeDecoder:
object_vectorizer = vectorizers.TorchDeepEmbed(type_context.get_object_count(), rnn_hidden_size)
type_vectorizer = vectorizers.TorchDeepEmbed(type_context.get_type_count(), rnn_hidden_size)
rnn_cell = TreeRNNCell(rnn_hidden_size, rnn_hidden_size)
rnn_cell = TreeRNNCellLSTM(rnn_hidden_size, rnn_hidden_size)
#rnn_cell = TreeCellOnlyAttn(rnn_hidden_size, rnn_hidden_size)
action_selector = SimpleActionSelector(rnn_cell.input_size,
objectselector.get_default_object_selector(
type_context, object_vectorizer), type_context)
Expand Down
2 changes: 1 addition & 1 deletion ainix_kernel/models/EncoderDecoder/encdecmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def get_default_tokenizers() -> Tuple[
tokenizers.Tokenizer
]:
"""Returns tuple (default x tokenizer, default y tokenizer)"""
return (NonLetterTokenizer(), None), AstValTokenizer()
#return (NonLetterTokenizer(), None), AstValTokenizer()
word_piece_tok, word_list = get_default_pieced_tokenizer_word_list()
x_vocab = BasicVocab(word_list + parse_constants.ALL_SPECIALS)
return (word_piece_tok, x_vocab), AstValTokenizer()
Expand Down
64 changes: 32 additions & 32 deletions ainix_kernel/training/opennmt/expir3.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

BATCH_SIZE=64
REPLACE_SAMPLES=10
TRAIN_EPOCHS=40
#TRAIN_EPOCHS=40
WORD_VEC_SIZE=300
TRAIN_STEPS=7500

#echo "Exporting latest data"
#cd ../../..
Expand All @@ -28,39 +29,39 @@ WORD_VEC_SIZE=300
# --tgt_words_min_frequency 3 \
# || exit 1

echo "prepare glove"
cd ./OpenNMT-py/
python3 -m tools.embeddings_to_torch \
-emb_file_both "../glove_dir/glove.840B.${WORD_VEC_SIZE}d.txt" \
-dict_file "../expirs/exp1.vocab.pt" \
-output_file "../data/embeddings" \
|| exit 1
cd ..

echo "Train"
data_size=$(wc -l < data_train_x.txt)
steps_to_do=$[(TRAIN_EPOCHS*BATCH_SIZE)/REPLACE_SAMPLES/BATCH_SIZE]
echo ${steps_to_do}
CUDA_VISIBLE_DEVICES=0 python3 ./OpenNMT-py/train.py \
-data expirs/exp1 \
-save_model data/demo-model \
--src_word_vec_size 64 \
--tgt_word_vec_size 64 \
--rnn_size 128 \
--batch_size ${BATCH_SIZE} \
--train_steps 7500 \
--report_every 50 \
--start_decay_steps 4000 \
--decay_steps 2000 \
--gpu_rank 0 \
--word_vec_size ${WORD_VEC_SIZE} \
--pre_word_vecs_enc "data/embeddings.enc.pt" \
--pre_word_vecs_dec "data/embeddings.dec.pt" \
|| exit 1
#echo "prepare glove"
#cd ./OpenNMT-py/
#python3 -m tools.embeddings_to_torch \
# -emb_file_both "../glove_dir/glove.840B.${WORD_VEC_SIZE}d.txt" \
# -dict_file "../expirs/exp1.vocab.pt" \
# -output_file "../data/embeddings" \
# || exit 1
#cd ..
#
#echo "Train"
#data_size=$(wc -l < data_train_x.txt)
##steps_to_do=$[(TRAIN_EPOCHS*BATCH_SIZE)/REPLACE_SAMPLES/BATCH_SIZE]
#echo ${steps_to_do}
#CUDA_VISIBLE_DEVICES=0 python3 ./OpenNMT-py/train.py \
# -data expirs/exp1 \
# -save_model data/demo-model \
# --src_word_vec_size 64 \
# --tgt_word_vec_size 64 \
# --rnn_size 128 \
# --batch_size ${BATCH_SIZE} \
# --train_steps ${TRAIN_STEPS} \
# --report_every 50 \
# --start_decay_steps 4000 \
# --decay_steps 2000 \
# --gpu_rank 0 \
# --word_vec_size ${WORD_VEC_SIZE} \
# --pre_word_vecs_enc "data/embeddings.enc.pt" \
# --pre_word_vecs_dec "data/embeddings.dec.pt" \
# || exit 1

echo "Predict"
python3 ./OpenNMT-py/translate.py \
-model data/demo-model_step_5000.pt \
-model data/demo-model_step_${TRAIN_STEPS}.pt \
-src data_val_x.txt \
-tgt data_val_y.txt \
-output pred.txt \
Expand All @@ -84,5 +85,4 @@ python3 -m ainix_kernel.training.eval_external \
#--optim adagrad \
#--learning_rate 1 \


echo "Done."

0 comments on commit e1c0c27

Please sign in to comment.