diff --git a/dataloader.py b/dataloader.py
new file mode 100644
index 0000000..4bdd189
--- /dev/null
+++ b/dataloader.py
@@ -0,0 +1,82 @@
+import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader
+
+
+class SummaryDataset(Dataset):
+    """Pairs a Brunch post (encoder side) with its tag string (decoder side)."""
+
+    def __init__(self, context, summary, tok, enc_max_len, dec_max_len, ignore_index=-100):
+        super().__init__()
+        self.tok = tok
+        self.enc_max_len = enc_max_len
+        self.dec_max_len = dec_max_len
+        self.context = context
+        self.summary = summary
+        self.pad_index = tok.pad_token_id
+        self.ignore_index = ignore_index
+
+    def add_padding_data(self, inputs, max_len):
+        # Pad with the tokenizer's pad id, or truncate to max_len.
+        if len(inputs) < max_len:
+            pad = np.array([self.pad_index] * (max_len - len(inputs)))
+            inputs = np.concatenate([inputs, pad])
+        else:
+            inputs = inputs[:max_len]
+        return inputs
+
+    def add_ignored_data(self, inputs, max_len):
+        # Pad labels with ignore_index so the padding is excluded from the loss.
+        if len(inputs) < max_len:
+            pad = np.array([self.ignore_index] * (max_len - len(inputs)))
+            inputs = np.concatenate([inputs, pad])
+        else:
+            inputs = inputs[:max_len]
+        return inputs
+
+    def __getitem__(self, idx):
+        context = self.context[idx]
+        summary = self.summary[idx]
+        input_ids = self.tok.encode(context)
+        input_ids = self.add_padding_data(input_ids, self.enc_max_len)
+
+        # Labels end with the eos token; decoder inputs are the labels shifted
+        # right, seeded with the eos token.
+        label_ids = self.tok.encode(summary, add_special_tokens=False)
+        label_ids.append(self.tok.eos_token_id)
+        dec_input_ids = [self.tok.eos_token_id]
+        dec_input_ids += label_ids[:-1]
+        dec_input_ids = self.add_padding_data(dec_input_ids, self.dec_max_len)
+        label_ids = self.add_ignored_data(label_ids, self.dec_max_len)
+
+        return {'input_ids': np.array(input_ids, dtype=np.int_),
+                'decoder_input_ids': np.array(dec_input_ids, dtype=np.int_),
+                'labels': np.array(label_ids, dtype=np.int_)}
+
+    def __len__(self):
+        return len(self.context)
+
+
+class SummaryBatchGenerator:
+    """Collate function: stacks examples and builds the attention masks."""
+
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+
+    def __call__(self, batch):
+        input_ids = torch.as_tensor(np.stack([item['input_ids'] for item in batch]))
+        decoder_input_ids = torch.as_tensor(np.stack([item['decoder_input_ids'] for item in batch]))
+        labels = torch.as_tensor(np.stack([item['labels'] for item in batch]))
+
+        attention_mask = (input_ids != self.tokenizer.pad_token_id).int()
+        decoder_attention_mask = (decoder_input_ids != self.tokenizer.pad_token_id).int()
+
+        return {'input_ids': input_ids,
+                'attention_mask': attention_mask,
+                'labels': labels,
+                'decoder_input_ids': decoder_input_ids,
+                'decoder_attention_mask': decoder_attention_mask}
+
+
+def get_dataloader(dataset, batch_generator, batch_size=16, shuffle=True):
+    return DataLoader(dataset,
+                      batch_size=batch_size,
+                      shuffle=shuffle,
+                      collate_fn=batch_generator,
+                      num_workers=4)
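The sketch below (not part of the patch) shows how the dataset, collate function and loader above are meant to compose. It assumes the kobart package used in main.py and substitutes two toy documents for the Brunch corpus; the tags are passed as plain strings here, while main.py passes comma-split lists.

# Usage sketch (not part of the patch); assumes the kobart package is installed.
from kobart import get_kobart_tokenizer
from dataloader import SummaryDataset, SummaryBatchGenerator, get_dataloader

tok = get_kobart_tokenizer()
contexts = ["첫 번째 브런치 글 본문입니다.", "두 번째 브런치 글 본문입니다."]  # toy stand-ins for posts
tags = ["여행,산문", "일상"]  # plain tag strings for illustration

dataset = SummaryDataset(contexts, tags, tok, enc_max_len=1024, dec_max_len=128)
loader = get_dataloader(dataset, SummaryBatchGenerator(tok), batch_size=2, shuffle=False)

batch = next(iter(loader))
print(batch['input_ids'].shape)          # torch.Size([2, 1024])
print(batch['decoder_input_ids'].shape)  # torch.Size([2, 128])
print(batch['labels'].shape)             # torch.Size([2, 128])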
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..1f6ebf8
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,50 @@
+import json
+
+import pandas as pd
+import torch
+from transformers import set_seed, AutoModelForSeq2SeqLM
+from kobart import get_kobart_tokenizer
+
+SEED = 42
+set_seed(SEED)
+
+
+def inference(model, tokenizer, test_df, device):
+    model.to(device)
+    model.eval()
+    results = []
+
+    with torch.no_grad():
+        for text, gd in zip(test_df['text'], test_df['tag']):
+            # Truncate on the token side rather than on raw characters.
+            inputs = tokenizer([text], return_tensors='pt',
+                               truncation=True, max_length=1024)
+            inputs.pop('token_type_ids', None)  # BART does not take token_type_ids
+            inputs = {k: v.to(device) for k, v in inputs.items()}
+            res = model.generate(**inputs, do_sample=True, num_return_sequences=10)
+            generated_tags = list(set(tokenizer.decode(r, skip_special_tokens=True) for r in res))
+            results.append({"text": text, "golden tag": gd, "generated tag": generated_tags})
+    return results
+
+
+if __name__ == '__main__':
+    with open('data/Brunch_accm_20210328_test.json', 'r') as f:
+        test_data = json.load(f)
+    test_df = pd.DataFrame(test_data)
+
+    device = 'cpu'
+    tokenizer = get_kobart_tokenizer()
+    model = AutoModelForSeq2SeqLM.from_pretrained(
+        "model_checkpoint/checkpoint_20210328_large/saved_checkpoint_5")
+
+    res = inference(model, tokenizer, test_df, device)
+    print(res)
+    with open('test_result.json', 'w') as f:
+        json.dump(res, f, ensure_ascii=False)
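A quick usage sketch (not part of the patch): running inference() on a one-row frame instead of the Brunch test split, loading the same saved_checkpoint_5 directory the script itself points at.

# Usage sketch; the checkpoint path is the one referenced in inference.py and
# can be swapped for any saved_checkpoint_N directory written by train.py.
import pandas as pd
from transformers import AutoModelForSeq2SeqLM
from kobart import get_kobart_tokenizer
from inference import inference

tokenizer = get_kobart_tokenizer()
model = AutoModelForSeq2SeqLM.from_pretrained("model_checkpoint/checkpoint_20210328_large/saved_checkpoint_5")
df = pd.DataFrame([{"text": "브런치 글 본문 예시입니다.", "tag": ["여행"]}])

for row in inference(model, tokenizer, df, device="cpu"):
    print(row["golden tag"], "->", row["generated tag"])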
"--device", + type=str, + default='cuda' + ) + parser.add_argument( + "--checkpoint", + type=str, + default='checkpoint_20210328_large' + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-5 + ) + parser.add_argument( + "--epochs", + type=int, + default=20 + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=2 + ) + parser.add_argument( + "--eval_batch_size", + type=int, + default=4 + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=16 + ) + parser.add_argument( + "--log_every", + type=int, + default=10 + ) + parser.add_argument( + "--data_dir", + type=str, + default='data' + ) + parser.add_argument( + "--enc_max_len", + type=int, + default=1024 + ) + parser.add_argument( + "--dec_max_len", + type=int, + default=128 + ) + args = parser.parse_args() + args.device = args.device if args.device else 'cpu' + + return args + +def main(): + # Get ArgParse + args = get_args() + if args.checkpoint: + args.checkpoint = ( + "./model_checkpoint/" + args.checkpoint[-1] + if args.checkpoint[-1] == "/" + else "./model_checkpoint/" + args.checkpoint + ) + else: + args.checkpoint = "./model_checkpoint/" + gen_checkpoint_id(args) + + + # If checkpoint path exists, load the last model + if os.path.isdir(args.checkpoint): + # EXAMPLE: "{engine_name}_{task_name}_{timestamp}/saved_checkpoint_1" + args.checkpoint_count = checkpoint_count(args.checkpoint) + logger = get_logger(args) + logger.info(f"Checkpoint path directory exists") + logger.info(f"Loading model from saved_checkpoint_{args.checkpoint_count}") + model = torch.load(f"{args.checkpoint}/saved_checkpoint_{args.checkpoint_count}") + + args.checkpoint_count += 1 # + # If there is none, create a checkpoint folder and train from scratch + else: + try: + os.makedirs(args.checkpoint) + except: + print("Ignoring Existing File Path ...") + +# model = BartModel.from_pretrained(get_pytorch_kobart_model()) + model = AutoModelForSeq2SeqLM.from_pretrained(get_pytorch_kobart_model()) + + args.checkpoint_count = 0 + logger = get_logger(args) + + logger.info(f"Creating a new directory for {args.checkpoint}") + + args.logger = logger + + model.to(args.device) + + # Define Tokenizer + tokenizer = get_kobart_tokenizer() + + # Add Additional Special Tokens + #special_tokens_dict = {"sep_token": ""} + #tokenizer.add_special_tokens(special_tokens_dict) + #model.resize_token_embeddings(new_num_tokens=len(tokenizer)) + + # Define Optimizer + optimizer_class = getattr(transformers, args.optimizer_class) + optimizer = optimizer_class(model.parameters(), lr=args.learning_rate) + + logger.info(f"Loading data from {args.data_dir} ...") + with open("data/Brunch_accm_20210328_train.json", 'r') as f: + train_data = json.load(f) + train_context = [data['text'] for data in train_data] + train_tag = [data['tag'] for data in train_data] + with open("data/Brunch_accm_20210328_test.json", 'r') as f: + test_data = json.load(f) + test_context = [data['text'] for data in test_data] + test_tag = [data['tag'] for data in test_data] + + train_dataset = SummaryDataset(train_context, train_tag, tokenizer, args.enc_max_len, args.dec_max_len, ignore_index=-100) + test_dataset = SummaryDataset(test_context, test_tag, tokenizer, args.enc_max_len, args.dec_max_len, ignore_index=-100) +# train_dataset = Seq2SeqDataset(data_path=os.path.join(args.data_dir, "train.json")) +# valid_dataset = Seq2SeqDataset(data_path=os.path.join(args.data_dir, "valid.json")) +# test_dataset = Seq2SeqDataset(data_path=os.path.join(args.data_dir, 
"test.json")) + + + batch_generator = SummaryBatchGenerator(tokenizer) + + train_loader = get_dataloader( + train_dataset, + batch_generator=batch_generator, + batch_size=args.train_batch_size, + shuffle=True, + ) + + test_loader = get_dataloader( + test_dataset, + batch_generator=batch_generator, + batch_size=args.eval_batch_size, + shuffle=False, + ) + +# test_loader = get_dataloader( +# test_dataset, +# batch_generator=batch_generator, +# batch_size=args.eval_batch_size, +# shuffle=False, +# ) + + + train(model, optimizer, tokenizer, train_loader, test_loader, test_tag, args)# test_loader, args) + +if __name__ == "__main__": + main() diff --git a/preprocess.py b/preprocess.py new file mode 100644 index 0000000..66d0a98 --- /dev/null +++ b/preprocess.py @@ -0,0 +1,125 @@ +""" +브런치 데이터 전처리 + +https://github.com/MrBananaHuman/KorNlpTutorial/blob/main/0_%ED%95%9C%EA%B5%AD%EC%96%B4_%EC%A0%84%EC%B2%98%EB%A6%AC.ipynb 를 참고함 +""" + +import json +import re +import os + +from pykospacing import spacing +from hanspell import spell_checker +from soynlp.normalizer import * + +# 문단 단위로 분리 +def paragraph_tokenize(text): + paragraphs = text.split('\n') + return paragraphs + +# 기호 전처리 +punct = "/-'?!.,#$%\'()*+-/:;<=>@[\\]^_`{|}~" + '""“”’' + '∞θ÷α•à−β∅³π‘₹´°£€\×™√²—–&' +punct_mapping = {"‘": "'", "₹": "e", "´": "'", "°": "", "€": "e", "™": "tm", "√": " sqrt ", "×": "x", "²": "2", "—": "-", "–": "-", "’": "'", "_": "-", "`": "'", '“': '"', '”': '"', '“': '"', "£": "e", '∞': 'infinity', 'θ': 'theta', '÷': '/', 'α': 'alpha', '•': '.', 'à': 'a', '−': '-', 'β': 'beta', '∅': '', '³': '3', 'π': 'pi', } + +def clean_punc(text, punct, mapping): + for p in mapping: + text = text.replace(p, mapping[p]) + + for p in punct: + text = text.replace(p, f'{p}') + + specials = {'\u200b': ' ', '…': ' ... 
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..477505e
--- /dev/null
+++ b/train.py
@@ -0,0 +1,228 @@
+import time
+import json
+
+import tqdm
+import torch
+from transformers import set_seed
+from rouge_score import rouge_scorer
+
+
+SEED = 42
+set_seed(SEED)
+# Note: rouge_score's stemmer only affects English tokens; it is a no-op for Korean.
+scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
+
+
+def serialize_args(args):
+    """Keep only the JSON-serializable entries of the args namespace."""
+    def is_jsonable(x):
+        try:
+            json.dumps(x)
+            return True
+        except (TypeError, OverflowError):
+            return False
+
+    return {k: v for k, v in args.__dict__.items() if is_jsonable(v)}
+
+
+def single_epoch_train(model, optimizer, train_loader, args):
+    """Fine-tune for a single epoch so that validation can run after each epoch."""
+    model.train()
+    loader = tqdm.tqdm(train_loader)
+    device = args.device
+
+    loss_accum = 0
+
+    for idx, batch in enumerate(loader):
+        input_ids, attention_mask, labels, decoder_input_ids, decoder_attention_mask = (
+            batch['input_ids'].to(device),
+            batch['attention_mask'].to(device),
+            batch['labels'].to(device),
+            batch['decoder_input_ids'].to(device),
+            batch['decoder_attention_mask'].to(device),
+        )
+
+        outputs = model(input_ids=input_ids,
+                        attention_mask=attention_mask,
+                        labels=labels,
+                        decoder_input_ids=decoder_input_ids,
+                        decoder_attention_mask=decoder_attention_mask)
+        loss = outputs.loss
+
+        # Update loss on the tqdm description
+        loader.set_description(f"Train Batch Loss: {loss.item():.3f}")
+        loader.refresh()
+        try:
+            import wandb
+            wandb.log({'loss': loss.item()})
+        except Exception:
+            pass
+
+        loss_accum += loss.item()
+
+        # Backward with gradient accumulation
+        (loss / args.gradient_accumulation_steps).backward()
+
+        # Step the optimizer once every gradient_accumulation_steps batches
+        if (idx + 1) % args.gradient_accumulation_steps == 0:
+            optimizer.step()
+            optimizer.zero_grad()
+
+    return loss_accum / len(loader)
+
+
+def single_epoch_validate(model, tokenizer, valid_loader, args):
+    """Validate for a single epoch: compute the loss and generate tag sequences."""
+    model.eval()
+    logger = args.logger
+    loader = tqdm.tqdm(valid_loader)
+    device = args.device
+
+    gend_outputs = []
+    loss_accum = 0
+
+    with torch.no_grad():
+        for idx, batch in enumerate(loader):
+            input_ids, attention_mask, labels, decoder_input_ids, decoder_attention_mask = (
+                batch['input_ids'].to(device),
+                batch['attention_mask'].to(device),
+                batch['labels'].to(device),
+                batch['decoder_input_ids'].to(device),
+                batch['decoder_attention_mask'].to(device),
+            )
+
+            # Greedy decoding; beam search, sampling, repetition/length penalties
+            # and similar options were also tried and can be passed to generate().
+            pred_ids = model.generate(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                max_length=args.dec_max_len
+            )
+            gend_outputs += [tokenizer.decode(ids, skip_special_tokens=True) for ids in pred_ids]
+
+            outputs = model(input_ids=input_ids,
+                            attention_mask=attention_mask,
+                            labels=labels,
+                            decoder_input_ids=decoder_input_ids,
+                            decoder_attention_mask=decoder_attention_mask)
+            loss = outputs.loss
+
+            loader.set_description(f"Valid Batch Loss: {loss.item():.3f}")
+            loader.refresh()
+            try:
+                import wandb
+                wandb.log({'val_loss': loss.item()})
+            except Exception:
+                pass
+
+            if not idx % args.log_every:
+                logger.info(f"Loss: {loss.item()}")
+
+            loss_accum += loss.item()
+
+    return loss_accum / len(loader), gend_outputs
+
+
+def train(model, optimizer, tokenizer, train_loader, valid_loader, valid_texts, args):
+    logger = args.logger
+
+    with open(f"{args.checkpoint}/args.json", "w") as f:
+        json.dump(serialize_args(args), f)
+
+    for epoch in range(args.epochs):
+        start_time = time.time()
+        logger.info(f"Epoch {epoch + 1} (Globally {args.checkpoint_count})")
+
+        # Training
+        logger.info("Begin Training ...")
+        train_loss = single_epoch_train(model, optimizer, train_loader, args)
+        mins = round((time.time() - start_time) / 60, 2)
+        valid_loss, valid_output = single_epoch_validate(model, tokenizer, valid_loader, args)
+
+        logger.info("Training Finished!")
+        logger.info(f"Time taken for training epoch {epoch + 1} (globally {args.checkpoint_count}): {mins} min(s)")
+        logger.info(f"Epoch : {epoch + 1}, Training Average Loss : {train_loss}, Validation Average Loss : {valid_loss}")
+
+        # ROUGE recall between generated and gold tags.
+        # The gold tags arrive as lists of strings, so join them before scoring.
+        scores = []
+        for gend, gold in zip(valid_output, valid_texts):
+            gold = ' '.join(gold) if isinstance(gold, list) else gold
+            scores.append(scorer.score(gold, gend))
+        score_r1 = sum(score['rouge1'].recall for score in scores) / len(scores)
+        score_rL = sum(score['rougeL'].recall for score in scores) / len(scores)
+
+        logger.info(f"Validation Rouge1 : {score_r1}, Validation RougeL : {score_rL}")
+
+        # Save a checkpoint after every epoch
+        model.save_pretrained(f"{args.checkpoint}/saved_checkpoint_{args.checkpoint_count}")
+        logger.info(f"Checkpoint saved at {args.checkpoint}/saved_checkpoint_{args.checkpoint_count}")
+
+        args.checkpoint_count += 1
+
+    return
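Finally, a hedged sketch (not part of the patch) of the per-epoch ROUGE bookkeeping in train(): rouge1/rougeL recall between each generated tag string and the gold tags, which arrive as lists and are therefore joined before scoring. rouge_score's tokenizer and stemmer are designed around English text, so absolute numbers on Korean tags should be read with caution.

# Standalone ROUGE sketch mirroring the loop in train(); the gold/predicted
# tags below are made-up examples.
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
golds = [["여행", "산문"], ["일상"]]   # gold tags, as lists of strings
preds = ["여행 에세이", "일상 기록"]    # generated tag strings

scores = [scorer.score(" ".join(g), p) for g, p in zip(golds, preds)]
print(sum(s['rouge1'].recall for s in scores) / len(scores))
print(sum(s['rougeL'].recall for s in scores) / len(scores))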