diff --git a/.gitignore b/.gitignore
index b714610..3b019f2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -183,4 +183,5 @@ fabric.properties
 # idea folder, uncomment if you don't need it
-.idea
\ No newline at end of file
+.idea
+src
\ No newline at end of file
diff --git a/config.py b/config.py
index d9f17a7..9d80686 100644
--- a/config.py
+++ b/config.py
@@ -13,13 +13,13 @@ class Config:
     word_emb_dim = 512
     hidden_dim = 1024
     num_lstm_layers = 1
-    num_transformer_layers = 6
+    num_gpt1_layers = 6
     n_head = 8
     batch = 32
     epoch = 5
-    lr_lstm = 2e-3
-    lr_transformer = 2e-4
+    lr_lstm = 5e-4
+    lr_gpt1 = 2e-4
     train_size = 0.8
@@ -53,27 +53,30 @@ class Config:
         '_e' + str(epoch) +
         '_lstm.pt'
     )
-    encoder_transformer_file = (
+    encoder_gpt1_file = (
         'src/encoder' +
         '_b' + str(batch) +
         '_h' + str(hidden_dim) +
-        '_l' + str(num_lstm_layers) +
+        '_l' + str(num_gpt1_layers) +
+        '_nh' + str(n_head) +
         '_e' + str(epoch) +
-        '_transformer.pt'
+        '_gpt1.pt'
     )
-    decoder_transformer_file = (
+    decoder_gpt1_file = (
         'src/decoder' +
         '_b' + str(batch) +
         '_h' + str(hidden_dim) +
-        '_l' + str(num_lstm_layers) +
+        '_l' + str(num_gpt1_layers) +
+        '_nh' + str(n_head) +
         '_e' + str(epoch) +
-        '_transformer.pt'
+        '_gpt1.pt'
     )
-    embedding_transformer_file = (
+    embedding_gpt1_file = (
         'src/embedding' +
         '_b' + str(batch) +
         '_h' + str(hidden_dim) +
-        '_l' + str(num_lstm_layers) +
+        '_l' + str(num_gpt1_layers) +
+        '_nh' + str(n_head) +
         '_e' + str(epoch) +
-        '_transformer.pt'
+        '_gpt1.pt'
     )
diff --git a/load_dataset.py b/load_dataset.py
index 1709f1f..50ac26c 100644
--- a/load_dataset.py
+++ b/load_dataset.py
@@ -1,7 +1,7 @@
 import os
 import torch
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset
 from torchvision import transforms
 from torch.nn.utils.rnn import pad_sequence
diff --git a/model.py b/model.py
index 03c3393..1a1fcdb 100644
--- a/model.py
+++ b/model.py
@@ -100,7 +100,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.dropout(x)
-class DecoderTransformer(nn.Module):
+class DecoderGPT1(nn.Module):
     def __init__(self,
                  word_emb_dim: int,
diff --git a/show_instance.py b/show_instance.py
index 02ffe38..c849e79 100644
--- a/show_instance.py
+++ b/show_instance.py
@@ -8,11 +8,11 @@
 from config import Config
 from vocab import Vocab
-from model import Encoder, DecoderLSTM, DecoderTransformer
+from model import Encoder, DecoderLSTM, DecoderGPT1
 parser = argparse.ArgumentParser()
-parser.add_argument('--model', type=str, default='lstm', choices=['lstm', 'transformer'])
+parser.add_argument('--model', type=str, default='lstm', choices=['lstm', 'gpt1'])
 parser.add_argument('--image_file', type=str, default='6261030.jpg')
 args = parser.parse_args()
@@ -46,20 +46,20 @@
                           num_layers=config.num_lstm_layers,
                           vocab_size=config.vocab_size).to(config.device)
 else:
-    decoder = DecoderTransformer(word_emb_dim=config.word_emb_dim,
-                                 nhead=config.n_head,
-                                 hidden_dim=config.hidden_dim,
-                                 num_layers=config.num_transformer_layers,
-                                 vocab_size=config.vocab_size).to(config.device)
+    decoder = DecoderGPT1(word_emb_dim=config.word_emb_dim,
+                          nhead=config.n_head,
+                          hidden_dim=config.hidden_dim,
+                          num_layers=config.num_gpt1_layers,
+                          vocab_size=config.vocab_size).to(config.device)
 if args.model == 'lstm':
     encoder.load_state_dict(torch.load(config.encoder_lstm_file, map_location=config.device))
     emb_layer.load_state_dict(torch.load(config.embedding_lstm_file, map_location=config.device))
     decoder.load_state_dict(torch.load(config.decoder_lstm_file, map_location=config.device))
 else:
-    encoder.load_state_dict(torch.load(config.encoder_transformer_file, map_location=config.device))
-    emb_layer.load_state_dict(torch.load(config.embedding_transformer_file, map_location=config.device))
-    decoder.load_state_dict(torch.load(config.decoder_transformer_file, map_location=config.device))
+    encoder.load_state_dict(torch.load(config.encoder_gpt1_file, map_location=config.device))
+    emb_layer.load_state_dict(torch.load(config.embedding_gpt1_file, map_location=config.device))
+    decoder.load_state_dict(torch.load(config.decoder_gpt1_file, map_location=config.device))
 encoder.eval()
 emb_layer.eval()
@@ -104,9 +104,10 @@
     sentence.append(next_word)
 sentence = ' '.join(sentence).strip().capitalize() + '.'
+plt.figure().set_figwidth(50)
 plt.imshow(image_ori.permute(1, 2, 0).cpu())
-plt.title(sentence)
+plt.title('[{}]'.format(args.model) + ' ' + sentence)
 plt.axis('off')
-image_save = args.image_file.split('.')[0] + '_' + args.model + '.pdf'
-plt.savefig(image_save, bbox_inches='tight')
+image_save = args.image_file.split('.')[0] + '_' + args.model + '.png'
+plt.savefig(image_save, dpi=300, bbox_inches='tight')
 print(sentence)
\ No newline at end of file
diff --git a/train.py b/train.py
index 95753af..827f4e5 100644
--- a/train.py
+++ b/train.py
@@ -9,10 +9,10 @@
 from config import Config
 from vocab import Vocab
 from load_dataset import Flicker30k, preprocess_image, Padding
-from model import Encoder, DecoderLSTM, DecoderTransformer
+from model import Encoder, DecoderLSTM, DecoderGPT1
 parser = argparse.ArgumentParser()
-parser.add_argument('--model', type=str, default='lstm', choices=['lstm', 'transformer'])
+parser.add_argument('--model', type=str, default='lstm', choices=['lstm', 'gpt1'])
 args = parser.parse_args()
 config = Config
@@ -47,15 +47,15 @@
                           num_layers=config.num_lstm_layers,
                           vocab_size=config.vocab_size).to(config.device)
 else:
-    decoder = DecoderTransformer(word_emb_dim=config.word_emb_dim,
-                                 nhead=config.n_head,
-                                 hidden_dim=config.hidden_dim,
-                                 num_layers=config.num_transformer_layers,
-                                 vocab_size=config.vocab_size).to(config.device)
+    decoder = DecoderGPT1(word_emb_dim=config.word_emb_dim,
+                          nhead=config.n_head,
+                          hidden_dim=config.hidden_dim,
+                          num_layers=config.num_gpt1_layers,
+                          vocab_size=config.vocab_size).to(config.device)
 criterion = torch.nn.CrossEntropyLoss().to(config.device)
 parameters = list(encoder.parameters()) + list(emb_layer.parameters()) + list(decoder.parameters())
-optimizer = torch.optim.Adam(params=parameters, lr=config.lr_lstm if args.model == 'lstm' else config.lr_transformer)
+optimizer = torch.optim.Adam(params=parameters, lr=config.lr_lstm if args.model == 'lstm' else config.lr_gpt1)
 # training
 print('---Training---')
@@ -164,6 +164,6 @@
     torch.save(emb_layer.state_dict(), config.embedding_lstm_file)
     torch.save(decoder.state_dict(), config.decoder_lstm_file)
 else:
-    torch.save(encoder.state_dict(), config.encoder_transformer_file)
-    torch.save(emb_layer.state_dict(), config.embedding_transformer_file)
-    torch.save(decoder.state_dict(), config.decoder_transformer_file)
+    torch.save(encoder.state_dict(), config.encoder_gpt1_file)
+    torch.save(emb_layer.state_dict(), config.embedding_gpt1_file)
+    torch.save(decoder.state_dict(), config.decoder_gpt1_file)
diff --git a/vocab.py b/vocab.py
index d6c4256..da93d60 100644
--- a/vocab.py
+++ b/vocab.py
@@ -1,5 +1,3 @@
-import os
-import string
 from nltk.tokenize import RegexpTokenizer
 from collections import Counter