KoBART Tag Generation
Inhwan Lee committed Apr 4, 2021
1 parent a2bc78c commit 1d8ef90
Showing 5 changed files with 732 additions and 0 deletions.
82 changes: 82 additions & 0 deletions dataloader.py
@@ -0,0 +1,82 @@
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

class SummaryDataset(Dataset):
    def __init__(self, context, summary, tok, enc_max_len, dec_max_len, ignore_index=-100):
        super().__init__()
        self.tok = tok
        self.enc_max_len = enc_max_len
        self.dec_max_len = dec_max_len
        self.context = context
        self.summary = summary
        self.pad_index = tok.pad_token_id
        self.ignore_index = ignore_index

    def add_padding_data(self, inputs, max_len):
        # Pad with the tokenizer's pad id up to max_len, or truncate.
        if len(inputs) < max_len:
            pad = np.array([self.pad_index] * (max_len - len(inputs)))
            inputs = np.concatenate([inputs, pad])
        else:
            inputs = inputs[:max_len]

        return inputs

    def add_ignored_data(self, inputs, max_len):
        # Pad labels with ignore_index (-100) so padded positions are excluded from the loss.
        if len(inputs) < max_len:
            pad = np.array([self.ignore_index] * (max_len - len(inputs)))
            inputs = np.concatenate([inputs, pad])
        else:
            inputs = inputs[:max_len]

        return inputs

    def __getitem__(self, idx):
        context = self.context[idx]
        summary = self.summary[idx]
        input_ids = self.tok.encode(context)
        input_ids = self.add_padding_data(input_ids, self.enc_max_len)

        # Decoder input is the label sequence shifted right, starting with <eos>
        # (KoBART uses the eos token as the decoder start token).
        label_ids = self.tok.encode(summary, add_special_tokens=False)
        label_ids.append(self.tok.eos_token_id)
        dec_input_ids = [self.tok.eos_token_id]
        dec_input_ids += label_ids[:-1]
        dec_input_ids = self.add_padding_data(dec_input_ids, self.dec_max_len)
        label_ids = self.add_ignored_data(label_ids, self.dec_max_len)

        return {'input_ids': np.array(input_ids, dtype=np.int_),
                'decoder_input_ids': np.array(dec_input_ids, dtype=np.int_),
                'labels': np.array(label_ids, dtype=np.int_)}

    def __len__(self):
        return len(self.context)

class SummaryBatchGenerator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        input_ids = torch.tensor([item['input_ids'] for item in batch])
        decoder_input_ids = torch.tensor([item['decoder_input_ids'] for item in batch])
        labels = torch.tensor([item['labels'] for item in batch])

        # Mask out padding positions for both encoder and decoder.
        attention_mask = (input_ids != self.tokenizer.pad_token_id).int()
        decoder_attention_mask = (decoder_input_ids != self.tokenizer.pad_token_id).int()

        return {'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': labels,
                'decoder_input_ids': decoder_input_ids,
                'decoder_attention_mask': decoder_attention_mask}


def get_dataloader(dataset, batch_generator, batch_size=16, shuffle=True):
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             collate_fn=batch_generator,
                             num_workers=4)
    return data_loader
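
For reference, a minimal usage sketch (not part of the commit) showing how SummaryDataset, SummaryBatchGenerator, and get_dataloader fit together. The sample text/tag strings are placeholders; the tokenizer comes from the kobart package already used in this repository.

# Usage sketch (illustrative only): build a dataset and loader with dummy data.
from kobart import get_kobart_tokenizer
from dataloader import SummaryDataset, SummaryBatchGenerator, get_dataloader

tokenizer = get_kobart_tokenizer()
texts = ["브런치 글 본문 예시입니다."]  # placeholder article text
tags = ["예시 태그"]                    # placeholder target tag string

dataset = SummaryDataset(texts, tags, tokenizer,
                         enc_max_len=1024, dec_max_len=128, ignore_index=-100)
loader = get_dataloader(dataset, SummaryBatchGenerator(tokenizer),
                        batch_size=1, shuffle=False)

batch = next(iter(loader))
print(batch['input_ids'].shape, batch['decoder_input_ids'].shape, batch['labels'].shape)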
50 changes: 50 additions & 0 deletions inference.py
@@ -0,0 +1,50 @@
import argparse
import json
import pandas as pd
import torch
from transformers import set_seed, AutoTokenizer, AutoModelForSeq2SeqLM
from kobart import get_pytorch_kobart_model, get_kobart_tokenizer

# parser = argparse.ArgumentParser()
# parser.add_argument('--text', type=str, required=True)
# parser.add_argument('--device', type=str, default='cpu')

SEED = 42
set_seed(SEED)

def inference(model, tokenizer, test_df, device):
    model.to(device)
    model.eval()
    results = []

    with torch.no_grad():
        for text, gd in zip(test_df['text'], test_df['tag']):
            # Truncate to the encoder's maximum length of 1024 tokens.
            inputs = tokenizer([text], return_tensors='pt', truncation=True, max_length=1024)
            # KoBART is a BART model and does not take token_type_ids.
            del inputs['token_type_ids']
            # Move inputs to the same device as the model.
            inputs = {k: v.to(device) for k, v in inputs.items()}
            res = model.generate(**inputs, do_sample=True, num_return_sequences=10)
            generated_summary = list(set(tokenizer.decode(r, skip_special_tokens=True) for r in res))
            generated = {"text": text, "golden tag": gd, "generated tag": generated_summary}
            results.append(generated)
    return results

if __name__ == '__main__':
    # args = parser.parse_args()
    # text = args.text
    # device = args.device

    with open('data/Brunch_accm_20210328_test.json', 'r') as f:
        test_data = json.load(f)
    test_df = pd.DataFrame(test_data)

    device = 'cpu'
    tokenizer = get_kobart_tokenizer()
    model = AutoModelForSeq2SeqLM.from_pretrained("model_checkpoint/checkpoint_20210328_large/saved_checkpoint_5")

    res = inference(model, tokenizer, test_df, device)
    print(res)
    with open('test_result.json', 'w') as f:
        # ensure_ascii=False keeps the Korean text readable in the output file.
        json.dump(res, f, ensure_ascii=False)
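
For a quick single-example check outside the test set, a hedged sketch along the same lines; the checkpoint path and input text are placeholders, and the beam-search settings are illustrative rather than what this commit uses.

import torch
from transformers import AutoModelForSeq2SeqLM
from kobart import get_kobart_tokenizer

tokenizer = get_kobart_tokenizer()
model = AutoModelForSeq2SeqLM.from_pretrained("model_checkpoint/your_checkpoint")  # placeholder path
model.eval()

text = "태그를 뽑아 볼 브런치 글 본문 ..."  # placeholder input
inputs = tokenizer([text], return_tensors='pt', truncation=True, max_length=1024)
inputs.pop('token_type_ids', None)  # BART-style models do not use token_type_ids

with torch.no_grad():
    out = model.generate(**inputs, num_beams=5, max_length=32, early_stopping=True)
print(tokenizer.decode(out[0], skip_special_tokens=True))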
247 changes: 247 additions & 0 deletions main.py
@@ -0,0 +1,247 @@

import os
import re
import sys
import math
import yaml
import logging
import argparse
import datetime
import json

import torch
import transformers
from transformers import BartModel, AutoModelForSeq2SeqLM
from kobart import get_pytorch_kobart_model, get_kobart_tokenizer

# from configloader import train_config
from dataloader import get_dataloader, SummaryDataset, SummaryBatchGenerator
from train import train


def gen_checkpoint_id(args):
    # Note: relies on args.ENGINE_ID and args.TASK_ID, which get_args() below does not
    # define; this path is only reached when --checkpoint is empty.
    engines = "".join([engine.capitalize() for engine in args.ENGINE_ID.split("-")])
    tasks = "".join([task.capitalize() for task in args.TASK_ID.split("-")])
    timez = datetime.datetime.now().strftime("%Y%m%d%H%M")
    checkpoint_id = "_".join([engines, tasks, timez])
    return checkpoint_id

def get_logger(args):
    log_path = f"{args.checkpoint}/info/"

    if not os.path.isdir(log_path):
        os.mkdir(log_path)
    train_instance_log_files = os.listdir(log_path)
    train_instance_count = len(train_instance_log_files)

    logging.basicConfig(
        filename=f'{args.checkpoint}/info/train_instance_{train_instance_count}_info.log',
        filemode='w',
        format="%(asctime)s | %(filename)15s | %(levelname)7s | %(funcName)10s | %(message)s",
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)

    streamHandler = logging.StreamHandler()
    streamHandler.setLevel(logging.INFO)
    logger.addHandler(streamHandler)

    logger.info("-" * 40)
    for arg in vars(args):
        logger.info(f"{arg}: {getattr(args, arg)}")
    logger.info("-" * 40)

    return logger

def checkpoint_count(checkpoint):
    _, folders, files = next(iter(os.walk(checkpoint)))
    files = list(filter(lambda x: "saved_checkpoint_" in x, files))
    # Extract the integer suffix of each "saved_checkpoint_N" file so training can
    # resume from the most recent checkpoint.
    checkpoints = map(lambda x: int(re.search(r"[0-9]+", x).group()), files)

    try:
        last_checkpoint = sorted(checkpoints)[-1]
    except IndexError:
        last_checkpoint = 0
    return last_checkpoint

def get_args():
    global train_config  # unused; leftover from the commented-out configloader import

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_class", type=str, default='AutoModelForSeq2SeqLM')
    parser.add_argument("--tokenizer_class", type=str, default='AutoTokenizer')
    parser.add_argument("--optimizer_class", type=str, default='AdamW')
    parser.add_argument("--device", type=str, default='cuda')
    parser.add_argument("--checkpoint", type=str, default='checkpoint_20210328_large')
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--train_batch_size", type=int, default=2)
    parser.add_argument("--eval_batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
    parser.add_argument("--log_every", type=int, default=10)
    parser.add_argument("--data_dir", type=str, default='data')
    parser.add_argument("--enc_max_len", type=int, default=1024)
    parser.add_argument("--dec_max_len", type=int, default=128)
    args = parser.parse_args()
    args.device = args.device if args.device else 'cpu'

    return args

def main():
    # Get ArgParse
    args = get_args()
    if args.checkpoint:
        # Normalize to ./model_checkpoint/<name>, stripping a trailing slash if present.
        args.checkpoint = (
            "./model_checkpoint/" + args.checkpoint[:-1]
            if args.checkpoint[-1] == "/"
            else "./model_checkpoint/" + args.checkpoint
        )
    else:
        args.checkpoint = "./model_checkpoint/" + gen_checkpoint_id(args)


    # If checkpoint path exists, load the last model
    if os.path.isdir(args.checkpoint):
        # EXAMPLE: "{engine_name}_{task_name}_{timestamp}/saved_checkpoint_1"
        args.checkpoint_count = checkpoint_count(args.checkpoint)
        logger = get_logger(args)
        logger.info("Checkpoint path directory exists")
        logger.info(f"Loading model from saved_checkpoint_{args.checkpoint_count}")
        model = torch.load(f"{args.checkpoint}/saved_checkpoint_{args.checkpoint_count}")

        args.checkpoint_count += 1
    # If there is none, create a checkpoint folder and train from scratch
    else:
        try:
            os.makedirs(args.checkpoint)
        except FileExistsError:
            print("Ignoring Existing File Path ...")

        # model = BartModel.from_pretrained(get_pytorch_kobart_model())
        model = AutoModelForSeq2SeqLM.from_pretrained(get_pytorch_kobart_model())

        args.checkpoint_count = 0
        logger = get_logger(args)

        logger.info(f"Creating a new directory for {args.checkpoint}")

    args.logger = logger

    model.to(args.device)

    # Define Tokenizer
    tokenizer = get_kobart_tokenizer()

    # Add Additional Special Tokens
    # special_tokens_dict = {"sep_token": "<sep>"}
    # tokenizer.add_special_tokens(special_tokens_dict)
    # model.resize_token_embeddings(new_num_tokens=len(tokenizer))

    # Define Optimizer
    optimizer_class = getattr(transformers, args.optimizer_class)
    optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)

    logger.info(f"Loading data from {args.data_dir} ...")
    with open("data/Brunch_accm_20210328_train.json", 'r') as f:
        train_data = json.load(f)
    train_context = [data['text'] for data in train_data]
    train_tag = [data['tag'] for data in train_data]
    with open("data/Brunch_accm_20210328_test.json", 'r') as f:
        test_data = json.load(f)
    test_context = [data['text'] for data in test_data]
    test_tag = [data['tag'] for data in test_data]

    train_dataset = SummaryDataset(train_context, train_tag, tokenizer, args.enc_max_len, args.dec_max_len, ignore_index=-100)
    test_dataset = SummaryDataset(test_context, test_tag, tokenizer, args.enc_max_len, args.dec_max_len, ignore_index=-100)
    # train_dataset = Seq2SeqDataset(data_path=os.path.join(args.data_dir, "train.json"))
    # valid_dataset = Seq2SeqDataset(data_path=os.path.join(args.data_dir, "valid.json"))
    # test_dataset = Seq2SeqDataset(data_path=os.path.join(args.data_dir, "test.json"))

    batch_generator = SummaryBatchGenerator(tokenizer)

    train_loader = get_dataloader(
        train_dataset,
        batch_generator=batch_generator,
        batch_size=args.train_batch_size,
        shuffle=True,
    )

    test_loader = get_dataloader(
        test_dataset,
        batch_generator=batch_generator,
        batch_size=args.eval_batch_size,
        shuffle=False,
    )

    train(model, optimizer, tokenizer, train_loader, test_loader, test_tag, args)


if __name__ == "__main__":
    main()
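
train.py is not included in the files shown above. For context, here is a hedged sketch of a checkpoint-saving helper that would be compatible with the torch.load resume logic in main(); the helper name and the count handling are assumptions, not the commit's actual code. Note that inference.py instead loads a checkpoint directory via from_pretrained, which would correspond to model.save_pretrained rather than torch.save.

import os
import torch

def save_checkpoint(model, args, count):
    # Hypothetical helper: saves the whole model object as
    # "<checkpoint>/saved_checkpoint_<count>", matching what main() later
    # reloads with torch.load(...).
    path = os.path.join(args.checkpoint, f"saved_checkpoint_{count}")
    torch.save(model, path)
    return path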