
Commit

finalize
haoyu-he committed Nov 7, 2023
1 parent a591af9 commit 931174a
Showing 7 changed files with 44 additions and 41 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -183,4 +183,5 @@ fabric.properties
 
 # idea folder, uncomment if you don't need it
 
-.idea
+.idea
+src
27 changes: 15 additions & 12 deletions config.py
@@ -13,13 +13,13 @@ class Config:
     word_emb_dim = 512
     hidden_dim = 1024
     num_lstm_layers = 1
-    num_transformer_layers = 6
+    num_gpt1_layers = 6
     n_head = 8
 
     batch = 32
     epoch = 5
-    lr_lstm = 2e-3
-    lr_transformer = 2e-4
+    lr_lstm = 5e-4
+    lr_gpt1 = 2e-4
 
     train_size = 0.8
@@ -53,27 +53,30 @@ class Config:
             '_e' + str(epoch) +
             '_lstm.pt'
     )
-    encoder_transformer_file = (
+    encoder_gpt1_file = (
             'src/encoder' +
             '_b' + str(batch) +
             '_h' + str(hidden_dim) +
-            '_l' + str(num_lstm_layers) +
+            '_l' + str(num_gpt1_layers) +
+            '_nh' + str(n_head) +
             '_e' + str(epoch) +
-            '_transformer.pt'
+            '_gpt1.pt'
     )
-    decoder_transformer_file = (
+    decoder_gpt1_file = (
             'src/decoder' +
             '_b' + str(batch) +
             '_h' + str(hidden_dim) +
-            '_l' + str(num_lstm_layers) +
+            '_l' + str(num_gpt1_layers) +
+            '_nh' + str(n_head) +
             '_e' + str(epoch) +
-            '_transformer.pt'
+            '_gpt1.pt'
     )
-    embedding_transformer_file = (
+    embedding_gpt1_file = (
             'src/embedding' +
             '_b' + str(batch) +
             '_h' + str(hidden_dim) +
-            '_l' + str(num_lstm_layers) +
+            '_l' + str(num_gpt1_layers) +
+            '_nh' + str(n_head) +
             '_e' + str(epoch) +
-            '_transformer.pt'
+            '_gpt1.pt'
     )
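Note: the checkpoint paths above are plain string concatenation over the config values, so the renamed files can be predicted directly; this is presumably also why src was added to .gitignore above. A standalone sketch using only the values shown in this diff:

# Resolving the renamed checkpoint path with the values from config.py above.
batch, hidden_dim, num_gpt1_layers, n_head, epoch = 32, 1024, 6, 8, 5

encoder_gpt1_file = (
    'src/encoder' +
    '_b' + str(batch) +
    '_h' + str(hidden_dim) +
    '_l' + str(num_gpt1_layers) +
    '_nh' + str(n_head) +
    '_e' + str(epoch) +
    '_gpt1.pt'
)
print(encoder_gpt1_file)  # src/encoder_b32_h1024_l6_nh8_e5_gpt1.pt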
2 changes: 1 addition & 1 deletion load_dataset.py
@@ -1,7 +1,7 @@
 import os
 
 import torch
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset
 from torchvision import transforms
 from torch.nn.utils.rnn import pad_sequence
 
2 changes: 1 addition & 1 deletion model.py
@@ -100,7 +100,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.dropout(x)
 
 
-class DecoderTransformer(nn.Module):
+class DecoderGPT1(nn.Module):
 
     def __init__(self,
                  word_emb_dim: int,
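The body of the renamed class is collapsed in this diff, so its internals are not visible here. For orientation only, a decoder-only (GPT-1-style) module with the constructor signature used elsewhere in this commit could be sketched as below; everything inside is an assumption, not the repository's code:

import torch
import torch.nn as nn

class DecoderGPT1Sketch(nn.Module):
    # Hypothetical stand-in for model.DecoderGPT1; internals are assumed.

    def __init__(self, word_emb_dim: int, nhead: int, hidden_dim: int,
                 num_layers: int, vocab_size: int):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=word_emb_dim, nhead=nhead,
            dim_feedforward=hidden_dim, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.head = nn.Linear(word_emb_dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Causal mask: -inf above the diagonal so each position attends
        # only to earlier tokens, the defining GPT-style property.
        sz = x.size(1)
        mask = torch.triu(torch.full((sz, sz), float('-inf'), device=x.device), diagonal=1)
        return self.head(self.blocks(x, mask=mask))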
27 changes: 14 additions & 13 deletions show_instance.py
@@ -8,11 +8,11 @@
 
 from config import Config
 from vocab import Vocab
-from model import Encoder, DecoderLSTM, DecoderTransformer
+from model import Encoder, DecoderLSTM, DecoderGPT1
 
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--model', type=str, default='lstm', choices=['lstm', 'transformer'])
+parser.add_argument('--model', type=str, default='lstm', choices=['lstm', 'gpt1'])
 parser.add_argument('--image_file', type=str, default='6261030.jpg')
 args = parser.parse_args()
 
@@ -46,20 +46,20 @@
                           num_layers=config.num_lstm_layers,
                           vocab_size=config.vocab_size).to(config.device)
 else:
-    decoder = DecoderTransformer(word_emb_dim=config.word_emb_dim,
-                                 nhead=config.n_head,
-                                 hidden_dim=config.hidden_dim,
-                                 num_layers=config.num_transformer_layers,
-                                 vocab_size=config.vocab_size).to(config.device)
+    decoder = DecoderGPT1(word_emb_dim=config.word_emb_dim,
+                          nhead=config.n_head,
+                          hidden_dim=config.hidden_dim,
+                          num_layers=config.num_gpt1_layers,
+                          vocab_size=config.vocab_size).to(config.device)
 
 if args.model == 'lstm':
     encoder.load_state_dict(torch.load(config.encoder_lstm_file, map_location=config.device))
     emb_layer.load_state_dict(torch.load(config.embedding_lstm_file, map_location=config.device))
     decoder.load_state_dict(torch.load(config.decoder_lstm_file, map_location=config.device))
 else:
-    encoder.load_state_dict(torch.load(config.encoder_transformer_file, map_location=config.device))
-    emb_layer.load_state_dict(torch.load(config.embedding_transformer_file, map_location=config.device))
-    decoder.load_state_dict(torch.load(config.decoder_transformer_file, map_location=config.device))
+    encoder.load_state_dict(torch.load(config.encoder_gpt1_file, map_location=config.device))
+    emb_layer.load_state_dict(torch.load(config.embedding_gpt1_file, map_location=config.device))
+    decoder.load_state_dict(torch.load(config.decoder_gpt1_file, map_location=config.device))
 
 encoder.eval()
 emb_layer.eval()
@@ -104,9 +104,10 @@
     sentence.append(next_word)
 
 sentence = ' '.join(sentence).strip().capitalize() + '.'
+plt.figure().set_figwidth(50)
 plt.imshow(image_ori.permute(1, 2, 0).cpu())
-plt.title(sentence)
+plt.title('[{}]'.format(args.model) + ' ' + sentence)
 plt.axis('off')
-image_save = args.image_file.split('.')[0] + '_' + args.model + '.pdf'
-plt.savefig(image_save, bbox_inches='tight')
+image_save = args.image_file.split('.')[0] + '_' + args.model + '.png'
+plt.savefig(image_save, dpi=300, bbox_inches='tight')
 print(sentence)
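With this hunk the figure gets a wider canvas, a title tagged with the model name, and a 300-dpi PNG output instead of a PDF. The output name is derived from the input filename and the --model flag; a quick walk-through using the script's own defaults (values taken from the argparse lines above):

# How the saved-figure name is derived after this commit.
image_file, model = '6261030.jpg', 'gpt1'  # default image, 'gpt1' model choice
image_save = image_file.split('.')[0] + '_' + model + '.png'
print(image_save)  # 6261030_gpt1.png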
22 changes: 11 additions & 11 deletions train.py
@@ -9,10 +9,10 @@
 from config import Config
 from vocab import Vocab
 from load_dataset import Flicker30k, preprocess_image, Padding
-from model import Encoder, DecoderLSTM, DecoderTransformer
+from model import Encoder, DecoderLSTM, DecoderGPT1
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--model', type=str, default='lstm', choices=['lstm', 'transformer'])
+parser.add_argument('--model', type=str, default='lstm', choices=['lstm', 'gpt1'])
 args = parser.parse_args()
 
 config = Config
@@ -47,15 +47,15 @@
                           num_layers=config.num_lstm_layers,
                           vocab_size=config.vocab_size).to(config.device)
 else:
-    decoder = DecoderTransformer(word_emb_dim=config.word_emb_dim,
-                                 nhead=config.n_head,
-                                 hidden_dim=config.hidden_dim,
-                                 num_layers=config.num_transformer_layers,
-                                 vocab_size=config.vocab_size).to(config.device)
+    decoder = DecoderGPT1(word_emb_dim=config.word_emb_dim,
+                          nhead=config.n_head,
+                          hidden_dim=config.hidden_dim,
+                          num_layers=config.num_gpt1_layers,
+                          vocab_size=config.vocab_size).to(config.device)
 
 criterion = torch.nn.CrossEntropyLoss().to(config.device)
 parameters = list(encoder.parameters()) + list(emb_layer.parameters()) + list(decoder.parameters())
-optimizer = torch.optim.Adam(params=parameters, lr=config.lr_lstm if args.model == 'lstm' else config.lr_transformer)
+optimizer = torch.optim.Adam(params=parameters, lr=config.lr_lstm if args.model == 'lstm' else config.lr_gpt1)
 
 # training
 print('---Training---')
@@ -164,6 +164,6 @@
     torch.save(emb_layer.state_dict(), config.embedding_lstm_file)
     torch.save(decoder.state_dict(), config.decoder_lstm_file)
 else:
-    torch.save(encoder.state_dict(), config.encoder_transformer_file)
-    torch.save(emb_layer.state_dict(), config.embedding_transformer_file)
-    torch.save(decoder.state_dict(), config.decoder_transformer_file)
+    torch.save(encoder.state_dict(), config.encoder_gpt1_file)
+    torch.save(emb_layer.state_dict(), config.embedding_gpt1_file)
+    torch.save(decoder.state_dict(), config.decoder_gpt1_file)
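A single Adam optimizer covers the encoder, the embedding layer, and the decoder, and after this commit it picks its learning rate from the --model flag (5e-4 for the LSTM, 2e-4 for GPT-1, per config.py above). A self-contained sketch of that wiring, with placeholder modules standing in for the real Encoder, emb_layer, and decoder:

import torch
import torch.nn as nn

lr_lstm, lr_gpt1 = 5e-4, 2e-4   # values from config.py in this diff
model_choice = 'gpt1'           # stands in for args.model

# Placeholder modules; train.py optimizes Encoder, emb_layer, and decoder jointly.
encoder, emb_layer, decoder = nn.Linear(8, 8), nn.Embedding(100, 8), nn.Linear(8, 8)
parameters = list(encoder.parameters()) + list(emb_layer.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params=parameters,
                             lr=lr_lstm if model_choice == 'lstm' else lr_gpt1)
print(optimizer.param_groups[0]['lr'])  # 0.0002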
2 changes: 0 additions & 2 deletions vocab.py
@@ -1,5 +1,3 @@
-import os
-import string
 from nltk.tokenize import RegexpTokenizer
 from collections import Counter
 

