readme & rename function
yikangshen committed Nov 15, 2018
1 parent 4b4e515 commit 04531c4
Showing 4 changed files with 51 additions and 14 deletions.
16 changes: 8 additions & 8 deletions LSTMCell.py → ON_LSTM.py
```diff
@@ -53,10 +53,10 @@ def cumsoftmax(x, dim=-1):
     return torch.cumsum(F.softmax(x, dim=dim), dim=dim)


-class LSTMCell(nn.Module):
+class ONLSTMCell(nn.Module):

     def __init__(self, input_size, hidden_size, chunk_size, dropconnect=0.):
-        super(LSTMCell, self).__init__()
+        super(ONLSTMCell, self).__init__()
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.chunk_size = chunk_size
@@ -118,13 +118,13 @@ def sample_masks(self):
                 m.sample_mask()


-class LSTMStack(nn.Module):
+class ONLSTMStack(nn.Module):
     def __init__(self, layer_sizes, chunk_size, dropout=0., dropconnect=0.):
-        super(LSTMStack, self).__init__()
-        self.cells = nn.ModuleList([LSTMCell(layer_sizes[i],
-                                             layer_sizes[i+1],
-                                             chunk_size,
-                                             dropconnect=dropconnect)
+        super(ONLSTMStack, self).__init__()
+        self.cells = nn.ModuleList([ONLSTMCell(layer_sizes[i],
+                                               layer_sizes[i+1],
+                                               chunk_size,
+                                               dropconnect=dropconnect)
                                    for i in range(len(layer_sizes) - 1)])
         self.lockdrop = LockedDropout()
         self.dropout = dropout
```
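For readers skimming the diff: `cumsoftmax`, visible in the hunk header above, is the activation that distinguishes ON-LSTM from a plain LSTM. It produces monotonically non-decreasing gate values in [0, 1] that the cell uses as master input/forget gates. A minimal sketch of its behavior (the tensor shape is an arbitrary choice for illustration):

```python
import torch
import torch.nn.functional as F

def cumsoftmax(x, dim=-1):
    # Cumulative sum of a softmax along `dim`: values rise
    # monotonically from near 0 to exactly 1 at the last position.
    return torch.cumsum(F.softmax(x, dim=dim), dim=dim)

logits = torch.randn(2, 8)  # (batch, hidden) -- arbitrary demo shape
gates = cumsoftmax(logits)
assert torch.all(gates[:, 1:] >= gates[:, :-1])     # non-decreasing
assert torch.allclose(gates[:, -1], torch.ones(2))  # softmax sums to 1
```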
40 changes: 39 additions & 1 deletion README.md
```diff
@@ -1 +1,39 @@
-# Ordered Neurons
+# ON-LSTM
```

This repository contains the code used for the word-level language modeling and unsupervised parsing experiments in the paper
[Ordered Neurons: Integrating Tree Structures into Recurrent Neural Networks](https://arxiv.org/abs/1810.09536).
It was originally forked from the
[LSTM and QRNN Language Model Toolkit for PyTorch](https://github.com/salesforce/awd-lstm-lm).
If you use this code or our results in your research, we would appreciate it if you cited our paper as follows:

```
@article{shen2018ordered,
  title={Ordered Neurons: Integrating Tree Structures into Recurrent Neural Networks},
  author={Shen, Yikang and Tan, Shawn and Sordoni, Alessandro and Courville, Aaron},
  journal={arXiv preprint arXiv:1810.09536},
  year={2018}
}
```

## Software Requirements
Python 3.6, NLTK and PyTorch 0.4 are required for the current codebase.

## Steps

1. Install PyTorch 0.4 and NLTK

2. Download the PTB data. Note that the two tasks, language modeling and unsupervised parsing, share the same model
structure but require different formats of the PTB data. For language modeling we need the standard 10,000-word
[Penn Treebank corpus](https://github.com/pytorch/examples/tree/75e435f98ab7aaa7f82632d4e633e8e03070e8ac/word_language_model/data/penn) data,
and for parsing we need the [Penn Treebank Parsed](https://catalog.ldc.upenn.edu/ldc99t42) data.

3. Scripts and commands

+ Train the language model:
```python main.py --batch_size 20 --dropout 0.45 --dropouth 0.3 --dropouti 0.5 --wdrop 0.45 --chunk_size 10 --seed 141 --epoch 1000 --data /path/to/your/data```

+ Test unsupervised parsing:
```python test_phrase_grammar.py --cuda```

The default setting in `main.py` achieves a perplexity of approximately `56.17` on the PTB test set
and an unlabeled F1 of approximately `47.7` on the WSJ test set.
5 changes: 2 additions & 3 deletions main.py
```diff
@@ -87,21 +87,20 @@
 ###############################################################################

 def model_save(fn):
-    with open(os.path.join(os.environ['PT_OUTPUT_DIR'], fn), 'wb') as f:
+    with open(fn, 'wb') as f:
         torch.save([model, criterion, optimizer], f)


 def model_load(fn):
     global model, criterion, optimizer
-    with open(os.path.join(os.environ['PT_OUTPUT_DIR'], fn), 'rb') as f:
+    with open(fn, 'rb') as f:
         model, criterion, optimizer = torch.load(f)


 import os
 import hashlib

 fn = 'corpus.{}.data'.format(hashlib.md5(args.data.encode()).hexdigest())
-fn = os.path.join(os.environ['PT_OUTPUT_DIR'], fn)
 if os.path.exists(fn):
     print('Loading cached dataset...')
     corpus = torch.load(fn)
```
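For context, the lines around this hunk cache the preprocessed corpus under a filename derived from an md5 hash of the dataset path, so different `--data` arguments get separate cache files; this commit just drops the `PT_OUTPUT_DIR` prefix. A standalone sketch of that load-or-build pattern, with a placeholder `build_corpus` standing in for the repository's corpus construction:

```python
import hashlib
import os
import torch

def build_corpus(path):
    # Placeholder for the repository's data.Corpus;
    # anything torch.save-able works for this sketch.
    return {'source': path}

data_path = 'data/penn'  # placeholder for args.data

# Stable cache filename: same dataset path -> same md5 -> same file.
fn = 'corpus.{}.data'.format(hashlib.md5(data_path.encode()).hexdigest())

if os.path.exists(fn):
    print('Loading cached dataset...')
    corpus = torch.load(fn)
else:
    print('Producing dataset...')
    corpus = build_corpus(data_path)
    torch.save(corpus, fn)
```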
4 changes: 2 additions & 2 deletions model.py
```diff
@@ -4,7 +4,7 @@
 from embed_regularize import embedded_dropout
 from locked_dropout import LockedDropout
 from weight_drop import WeightDrop
-from LSTMCell import LSTMStack
+from ON_LSTM import ONLSTMStack

 class RNNModel(nn.Module):
     """Container module with an encoder, a recurrent module, and a decoder."""
@@ -17,7 +17,7 @@ def __init__(self, rnn_type, ntoken, ninp, nhid, chunk_size, nlayers, dropout=0.
         self.drop = nn.Dropout(dropout)
         self.encoder = nn.Embedding(ntoken, ninp)
         assert rnn_type in ['LSTM'], 'RNN type is not supported'
-        self.rnn = LSTMStack(
+        self.rnn = ONLSTMStack(
             [ninp] + [nhid] * (nlayers - 1) + [ninp],
             chunk_size=chunk_size,
             dropconnect=wdrop,
```
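One detail worth noting in the unchanged context: the layer-size list gives the stack an hourglass shape, with embedding-sized first and last layers and `nhid`-sized layers in between, so the output layer can share dimensions with the encoder. A quick illustration with assumed placeholder values (these particular numbers are not from this commit):

```python
# Assumed placeholder hyperparameters for illustration only.
ninp, nhid, nlayers = 400, 1150, 3

layer_sizes = [ninp] + [nhid] * (nlayers - 1) + [ninp]
print(layer_sizes)  # [400, 1150, 1150, 400]

# ONLSTMStack builds one ONLSTMCell per adjacent pair of sizes:
# 400 -> 1150, 1150 -> 1150, 1150 -> 400
```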
