diff --git a/scripts/downstream/train_token_classification_lm_finetuning.py b/scripts/downstream/train_token_classification_lm_finetuning.py
new file mode 100644
index 0000000..683a26b
--- /dev/null
+++ b/scripts/downstream/train_token_classification_lm_finetuning.py
@@ -0,0 +1,622 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Wed Jan 6 13:12:12 2021
+
+@author: z
+"""
+
+import pprint
+import numpy as np
+import torch
+import logging
+import transformers
+import datasets
+from dataclasses import dataclass, field
+from typing import Optional
+from custom_data_collator import DataCollatorForTokenClassification
+from datasets import load_dataset, load_metric, Dataset
+from functools import lru_cache
+from seqeval.metrics import classification_report
+from sklearn.metrics import classification_report as sk_classification_report
+from sklearn.metrics import precision_recall_fscore_support, accuracy_score
+
+# thai2transformers
+from thai2transformers import metrics as t2f_metrics
+from thai2transformers.tokenizers import (
+    ThaiRobertaTokenizer, ThaiWordsNewmmTokenizer,
+    ThaiWordsSyllableTokenizer, FakeSefrCutTokenizer,
+    SPACE_TOKEN as DEFAULT_SPACE_TOKEN, SEFR_SPLIT_TOKEN)
+
+from transformers import (Trainer, TrainingArguments,
+                          AutoModelForTokenClassification, AutoTokenizer,
+                          HfArgumentParser, CamembertTokenizer)
+
+logger = logging.getLogger(__name__)
+
+
+def is_main_process(rank):
+    return rank in [-1, 0]
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune,
+    or train from scratch.
+    """
+
+    model_name_or_path: str = field(
+        metadata={
+            "help": "The model checkpoint for weights initialization."
+        },
+    )
+    tokenizer_name_or_path: str = field(
+        metadata={
+            "help": "Pretrained tokenizer name or path if not the same as model_name"
+        }
+    )
+    tokenizer_type: Optional[str] = field(
+        default='AutoTokenizer',
+        metadata={'help': 'type of tokenizer'}
+    )
+
+
+@dataclass
+class DataTrainingArguments:
+    dataset_name: str = field(
+        metadata={'help': 'name of dataset'}
+    )
+    label_name: str = field(
+        metadata={'help': 'name of label column (ex. ner_tags, pos_tags)'}
+    )
+    max_length: Optional[int] = field(
+        default=None,
+        metadata={'help': 'max length of a sequence'}
+    )
+
+
+@dataclass
+class CustomArguments:
+    no_train_report: bool = field(
+        default=False,
+        metadata={'help': 'do not report training set metrics'}
+    )
+    no_eval_report: bool = field(
+        default=False,
+        metadata={'help': 'do not report evaluation set metrics'}
+    )
+    no_test_report: bool = field(
+        default=False,
+        metadata={'help': 'do not report test set metrics'}
+    )
+    lst20_data_dir: Optional[str] = field(
+        default=None,
+        metadata={'help': 'path to lst20 dataset'}
+    )
+    skip_do_eval: Optional[bool] = field(
+        default=False,
+        metadata={'help': 'skip eval loop'}
+    )
+    space_token: str = field(
+        default=DEFAULT_SPACE_TOKEN,
+        metadata={'help': 'specify custom space token'}
+    )
+    lowercase: bool = field(
+        default=False,
+        metadata={'help': 'Apply lowercase to input texts'}
+    )
+    filter_thainer_with_mbert_tokenizer_threshold: Optional[int] = field(
+        default=None,
+        metadata={'help': 'keep only thainer test examples shorter than this many mbert subword tokens'}
+    )
+
+
+parser = HfArgumentParser((ModelArguments, DataTrainingArguments,
+                           TrainingArguments, CustomArguments))
+
+model_args, data_args, training_args, custom_args = parser.parse_args_into_dataclasses()
+
+# Set seed
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+torch.manual_seed(training_args.seed)
+np.random.seed(training_args.seed)
+
+# Setup logging
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
+)
+# Log a small summary on each process:
+logger.warning(
+    f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
+    f"n_gpu: {training_args.n_gpu}, distributed training: {bool(training_args.local_rank != -1)}, "
+    f"16-bits training: {training_args.fp16}"
+)
+# Set the verbosity to info of the Transformers logger (on main process only):
+if is_main_process(training_args.local_rank):
+    transformers.utils.logging.set_verbosity_info()
+logger.info("Training/evaluation parameters %s", training_args)
+
+logger.info("Data parameters %s", data_args)
+logger.info("Model parameters %s", model_args)
+logger.info("Custom args %s", custom_args)
+
+if model_args.tokenizer_type == 'AutoTokenizer':
+    # e.g. bert-base-multilingual-cased
+    tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
+    tokenizer.add_tokens(custom_args.space_token)
+elif model_args.tokenizer_type == 'ThaiRobertaTokenizer':
+    tokenizer = ThaiRobertaTokenizer.from_pretrained(
+        model_args.tokenizer_name_or_path)
+elif model_args.tokenizer_type == 'ThaiWordsNewmmTokenizer':
+    tokenizer = ThaiWordsNewmmTokenizer.from_pretrained(
+        model_args.tokenizer_name_or_path)
+elif model_args.tokenizer_type == 'ThaiWordsSyllableTokenizer':
+    tokenizer = ThaiWordsSyllableTokenizer.from_pretrained(
+        model_args.tokenizer_name_or_path)
+elif model_args.tokenizer_type == 'FakeSefrCutTokenizer':
+    tokenizer = FakeSefrCutTokenizer.from_pretrained(
+        model_args.tokenizer_name_or_path)
+elif model_args.tokenizer_type == 'CamembertTokenizer':
+    tokenizer = CamembertTokenizer.from_pretrained(
+        model_args.tokenizer_name_or_path)
+    tokenizer.additional_special_tokens = ['NOTUSED', 'NOTUSED', custom_args.space_token]
+    logger.info("[INFO] space_token = `%s`", custom_args.space_token)
+elif model_args.tokenizer_type == 'skip':
+    logging.info('Skip tokenizer')
+else:
+    raise NotImplementedError(
+        f'tokenizer_type {model_args.tokenizer_type} is not implemented.')
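+
+# Note (added for clarity; the exact vocabulary growth depends on the pretrained
+# tokenizer): add_tokens() above makes the space token a single unsplittable token,
+# so len(tokenizer) can become larger than model.config.vocab_size, and the
+# embedding matrix may then need model.resize_token_embeddings(len(tokenizer)),
+# which is handled further below.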
+
+text_col = 'tokens'
+label_col = data_args.label_name
+
+if data_args.dataset_name == 'thainer':
+    dataset = load_dataset("thainer")
+    # Remove the tag "ไม่ยืนยัน" (unconfirmed) by remapping label ids 13 and 26 to 27
+    if label_col == 'ner_tags':
+        dataset['train'] = dataset['train'].map(
+            lambda examples: {'ner_tags': [i if i not in [13, 26] else 27
+                                           for i in examples[label_col]]}
+        )
+    label_maps = {i: name for i, name in
+                  enumerate(dataset['train'].features[label_col].feature.names)}
+    label_names = dataset['train'].features[label_col].feature.names
+    num_labels = dataset['train'].features[label_col].feature.num_classes
+elif data_args.dataset_name == 'lst20':
+    dataset = load_dataset('lst20', data_dir=custom_args.lst20_data_dir)
+    label_maps = {i: name for i, name in
+                  enumerate(dataset['train'].features[label_col].feature.names)}
+    label_names = dataset['train'].features[label_col].feature.names
+    num_labels = dataset['train'].features[label_col].feature.num_classes
+elif data_args.dataset_name == 'dummytest':
+    def generate_dummy_dataset(size, max_length, max_token_length, label_names, label_sizes):
+        d = {'tokens': []}
+        c = {}
+        chars = [chr(i) for i in range(97, 123, 1)]
+        for label_name, label_size in zip(label_names, label_sizes):
+            d[label_name] = []
+            c[label_name] = list(range(label_size))
+        for i in range(size):
+            length = np.random.randint(1, max_length)
+            d['tokens'].append([''.join(np.random.choice(chars, size=max_token_length))
+                                for _ in range(length)])
+            for label_name in label_names:
+                dummy_labels = np.random.choice(c[label_name],
+                                                size=length)
+                d[label_name].append(dummy_labels)
+        return Dataset.from_dict(d)
+    dataset = datasets.DatasetDict(
+        {'train': generate_dummy_dataset(50, 50, 8, ['ner_tags', 'pos_tags'], [10, 20]),
+         'validation': generate_dummy_dataset(10, 50, 8, ['ner_tags', 'pos_tags'], [10, 20]),
+         'test': generate_dummy_dataset(10, 50, 8, ['ner_tags', 'pos_tags'], [10, 20])
+         })
+    label_maps = {i: str(name) for i, name in
+                  enumerate(range(20))}
+    if 'ner' in label_col.lower():
+        label_names = ['O'] + ['B-' + str(name) for i, name in enumerate(range(1, 20))]
+    else:
+        label_names = [str(name) for i, name in enumerate(range(20))]
+    num_labels = 20
+else:
+    raise NotImplementedError
+
+
+def pre_tokenize(token, space_token=custom_args.space_token):
+    token = token.replace(' ', space_token)
+    return token
+
+
+if model_args.tokenizer_type == 'FakeSefrCutTokenizer':
+    from sefr_cache import SefrCacheTokenizer
+    sefr_cache_tokenizer = SefrCacheTokenizer()
+    sefr_cache_tokenizer.load('sefr_cache_tokenizer_dict.pkl')
+
+    @lru_cache(maxsize=None)
+    def sefr_cached_tokenize(token, space_token=custom_args.space_token,
+                             lowercase=custom_args.lowercase):
+        if lowercase:
+            token = token.lower()
+        token = pre_tokenize(token, space_token)
+        tokens = sefr_cache_tokenizer.tokenize(token)
+        text = SEFR_SPLIT_TOKEN.join(tokens)
+        ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
+        return ids
+
+    cached_tokenize = sefr_cached_tokenize
+else:
+    @lru_cache(maxsize=None)
+    def cached_tokenize(token, space_token=custom_args.space_token,
+                        lowercase=custom_args.lowercase):
+        if lowercase:
+            token = token.lower()
+        token = pre_tokenize(token, space_token)
+        ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(token))
+        return ids
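+
+# Illustrative example (hypothetical input; the actual subword pieces depend on
+# the pretrained vocabulary): cached_tokenize('เชียงใหม่ ') first maps the literal
+# space to the space token via pre_tokenize() and only then converts the resulting
+# subword pieces to ids, so whitespace inside a pretokenized token survives
+# subword tokenization; lru_cache memoises repeated tokens across the corpus.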
+
+
+def preprocess(examples, space_token=custom_args.space_token, lowercase=custom_args.lowercase):
+    tokens = []
+    labels = []
+    old_positions = []
+    for example_tokens, example_labels in zip(examples[text_col], examples[label_col]):
+        new_example_tokens = []
+        new_example_labels = []
+        old_position = []
+        for i, (token, label) in enumerate(zip(example_tokens, example_labels)):
+            # tokenize each already-pretokenized token with our own tokenizer.
+            toks = cached_tokenize(token, space_token, lowercase=custom_args.lowercase)
+            n_toks = len(toks)
+            new_example_tokens.extend(toks)
+            # expand the label to cover every subword that the pretokenized token was split into
+            new_example_labels.extend([label] * n_toks)
+            # keep track of the original token position
+            old_position.extend([i] * n_toks)
+        tokens.append(new_example_tokens)
+        labels.append(new_example_labels)
+        old_positions.append(old_position)
+    tokenized_inputs = tokenizer._batch_prepare_for_model(
+        [(e, None) for e in tokens],
+        truncation_strategy=transformers.tokenization_utils_base.TruncationStrategy.LONGEST_FIRST,
+        add_special_tokens=True, max_length=data_args.max_length)
+    # in case truncation happened we need to chop off some of the labels manually
+    max_length = max(len(e) for e in tokenized_inputs['input_ids'])
+    # add -100 at the first and last positions, which hold the special tokens
+    # (e.g. <s> and </s>); -100 is the huggingface transformers convention for
+    # positions that the loss calculation should skip
+    tokenized_inputs['old_positions'] = [[-100] + e[:max_length - 2] + [-100]
+                                         for e in old_positions]
+    tokenized_inputs['labels'] = [[-100] + e[:max_length - 2] + [-100]
+                                  for e in labels]
+    return tokenized_inputs
+
+
+data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
+
+if data_args.dataset_name == 'thainer':
+    train_dataset = dataset['train']
+
+    split = train_dataset.train_test_split(
+        train_size=0.8, test_size=0.2, seed=2020)
+    train_dataset = split['train']
+    non_train_dataset = split['test']
+
+    split = non_train_dataset.train_test_split(
+        train_size=0.5, test_size=0.5, seed=2020)
+    val_dataset = split['train']
+    test_dataset = split['test']
+
+    if custom_args.filter_thainer_with_mbert_tokenizer_threshold is not None:
+        mbert_tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')
+        # mbert_tokenizer.add_special_tokens(custom_args.space_token)
+
+        def is_not_too_long(example,
+                            max_length=custom_args.filter_thainer_with_mbert_tokenizer_threshold):
+            tokens = sum([mbert_tokenizer.tokenize(
+                pre_tokenize(token, '<_>'))
+                for token in example[text_col]], [])
+            return len(tokens) < max_length
+
+        test_dataset = test_dataset.filter(is_not_too_long)
+
+    # preprocess
+    train_dataset = Dataset.from_dict(preprocess(train_dataset))
+    val_dataset = Dataset.from_dict(preprocess(val_dataset))
+    test_dataset = Dataset.from_dict(preprocess(test_dataset))
+    # the val/test sets need padding up front to avoid problems in the Trainer
+    val_dataset = Dataset.from_dict(data_collator(val_dataset))
+    test_dataset = Dataset.from_dict(data_collator(test_dataset))
+elif data_args.dataset_name == 'lst20':
+    train_dataset = dataset['train']
+    val_dataset = dataset['validation']
+    test_dataset = dataset['test']
+
+    train_dataset = Dataset.from_dict(preprocess(train_dataset))
+    val_dataset = Dataset.from_dict(preprocess(val_dataset))
+    test_dataset = Dataset.from_dict(preprocess(test_dataset))
+    # the val/test sets need padding up front to avoid problems in the Trainer
+    val_dataset = Dataset.from_dict(data_collator(val_dataset))
+    test_dataset = Dataset.from_dict(data_collator(test_dataset))
+elif data_args.dataset_name == 'dummytest':
+    train_dataset = dataset['train']
+    val_dataset = dataset['validation']
+    test_dataset = dataset['test']
+
+    train_dataset = Dataset.from_dict(preprocess(train_dataset))
+    val_dataset = Dataset.from_dict(preprocess(val_dataset))
+    test_dataset = Dataset.from_dict(preprocess(test_dataset))
+    val_dataset = Dataset.from_dict(data_collator(val_dataset))
+    test_dataset = Dataset.from_dict(data_collator(test_dataset))
+else:
+    raise NotImplementedError
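+
+# Shape sketch (hypothetical values, for illustration only; padding behaviour is
+# assumed to follow the custom data collator): after preprocess() and the data
+# collator, one example is aligned per subword, e.g.
+#   input_ids     = [<s>,  51,  52,  53, </s>, <pad>]
+#   labels        = [-100,   3,   3,   0, -100, -100]
+#   old_positions = [-100,   0,   0,   1, -100, -100]
+# Every subword of original token 0 carries that token's label, and -100 marks
+# positions that both the loss and the report aggregation ignore.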
+
+model = AutoModelForTokenClassification.from_pretrained(
+    model_args.model_name_or_path, num_labels=num_labels)
+
+if model.config.vocab_size == len(tokenizer) - len(tokenizer.get_added_vocab()):
+    # resize to accommodate the added tokens
+    model.resize_token_embeddings(len(tokenizer))
+elif model.config.vocab_size == len(tokenizer) and len(tokenizer.get_added_vocab()) > 0:
+    logger.warning('model might already accommodate the added tokens')
+else:
+    logger.warning(f'model vocab size ({model.config.vocab_size}) is not equal to '
+                   f'tokenizer vocab size ({len(tokenizer)}); '
+                   'this might be caused by a tokenizer/model mismatch or by added vocabulary')
+    raise ValueError
+
+metric = load_metric("seqeval")
+
+
+def get_batch(obj, batch_size):
+    i = 0
+    r = obj[i * batch_size: i * batch_size + batch_size]
+    yield r
+    i += 1
+    while i * batch_size < len(obj):
+        r = obj[i * batch_size: i * batch_size + batch_size]
+        yield r
+        i += 1
+
+
+def agg_preds_labels(model, dataset, device=torch.device('cuda')):
+    agg_chunk_preds = []
+    agg_chunk_labels = []
+    model.to(device)
+    for step, batch in enumerate(get_batch(dataset, training_args.per_device_eval_batch_size)):
+        labels = batch['labels']
+        old_positions = batch['old_positions']
+        dont_include = ['labels', 'old_positions']
+        batch = {k: torch.tensor(v, dtype=torch.int64).to(device) for k, v in batch.items()
+                 if k not in dont_include}
+
+        preds, = model(**batch)
+        preds = preds.argmax(2)
+        preds = preds.tolist()
+
+        use_idxs = [[i for i, e in enumerate(label) if e != -100]
+                    for label in labels]
+        true_preds = [[preds[j][i] for i in use_idx]
+                      for j, use_idx in enumerate(use_idxs)]
+        true_labels = [[labels[j][i] for i in use_idx]
+                       for j, use_idx in enumerate(use_idxs)]
+        true_old_positions = [[old_positions[j][i] for i in use_idx]
+                              for j, use_idx in enumerate(use_idxs)]
+        chunk_preds = []
+        chunk_labels = []
+        for i, old_position in enumerate(true_old_positions):
+            cur_pos = -100
+            chunk_preds.append([])
+            chunk_labels.append([])
+            for j, pos in enumerate(old_position):
+                # sanity check: original token positions must be non-decreasing
+                if pos < cur_pos:
+                    raise ValueError('a later position has a lower value than the previous one')
+                elif pos != cur_pos:
+                    cur_pos = pos
+                    chunk_preds[-1].append(true_preds[i][j])
+                    chunk_labels[-1].append(true_labels[i][j])
+        agg_chunk_preds.extend(chunk_preds)
+        agg_chunk_labels.extend(chunk_labels)
+        print(f'\rProcessed: {len(agg_chunk_preds)} / {len(dataset)}',
+              flush=True, end=' ')
+    return agg_chunk_labels, agg_chunk_preds
+
+
+def sk_classification_metrics(labels, preds):
+    precision_macro, recall_macro, f1_macro, _ = \
+        precision_recall_fscore_support(labels, preds, average='macro')
+    precision_micro, recall_micro, f1_micro, _ = \
+        precision_recall_fscore_support(labels, preds, average='micro')
+    acc = accuracy_score(labels, preds)
+    return {
+        'accuracy': acc,
+        'f1_micro': f1_micro,
+        'precision_micro': precision_micro,
+        'recall_micro': recall_micro,
+        'f1_macro': f1_macro,
+        'precision_macro': precision_macro,
+        'recall_macro': recall_macro,
+        'nb_samples': len(labels)
+    }
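+
+# Aggregation sketch (hypothetical values): for one sentence where
+#   old_positions = [0, 0, 1, 2, 2] and true_preds = [3, 4, 0, 5, 5],
+# agg_preds_labels() keeps only the first subword of each original position,
+# giving chunk_preds = [3, 0, 5]; labels are collapsed the same way, so the
+# reports below are computed per original token rather than per subword.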
+
+
+def compute_token_metrics(agg_chunk_labels, agg_chunk_preds):
+    report = sk_classification_report(sum(agg_chunk_labels, []),
+                                      sum(agg_chunk_preds, []), target_names=label_names)
+    results = sk_classification_metrics(sum(agg_chunk_labels, []),
+                                        sum(agg_chunk_preds, []))
+    return results, report
+
+
+def compute_chunk_metrics(agg_chunk_labels, agg_chunk_preds):
+    results = metric.compute(predictions=[[label_maps[e] for e in a] for a in agg_chunk_preds],
+                             references=[[label_maps[e] for e in a] for a in agg_chunk_labels])
+    report = classification_report([[label_maps[e] for e in a] for a in agg_chunk_labels],
+                                   [[label_maps[e] for e in a] for a in agg_chunk_preds])
+    return results, report
+
+
+def t2t_chunk_metrics(agg_chunk_labels, agg_chunk_preds):
+    class LabelsPreds:
+        label_ids = agg_chunk_labels
+        predictions = agg_chunk_preds
+    return t2f_metrics.seqeval_classification_metrics(LabelsPreds)
+
+
+def t2t_sk_classification_metrics(agg_chunk_labels, agg_chunk_preds):
+    class LabelsPreds:
+        label_ids = agg_chunk_labels
+        predictions = agg_chunk_preds
+    return t2f_metrics.sk_classification_metrics(LabelsPreds, pred_labs=True)
+
+
+def compute_metrics(p):
+    predictions, labels = p
+    predictions = np.argmax(predictions, axis=2)
+
+    # Remove ignored index (special tokens)
+    true_predictions = [
+        [label_maps[p] for (p, l) in zip(prediction, label) if l != -100]
+        for prediction, label in zip(predictions, labels)
+    ]
+    true_labels = [
+        [label_maps[l] for (p, l) in zip(prediction, label) if l != -100]
+        for prediction, label in zip(predictions, labels)
+    ]
+    if 'ner' in data_args.label_name:
+        results = metric.compute(predictions=true_predictions, references=true_labels)
+        return {
+            "precision": results["overall_precision"],
+            "recall": results["overall_recall"],
+            "f1": results["overall_f1"],
+            "accuracy": results["overall_accuracy"],
+        }
+    else:
+        result = t2t_sk_classification_metrics(sum(true_labels, []),
+                                               sum(true_predictions, []))
+        result = {k: v for k, v in result.items() if k != 'classification_report'}
+        return result
+
+
+trainer = Trainer(
+    model=model,
+    args=training_args,
+    train_dataset=train_dataset,
+    eval_dataset=val_dataset,
+    tokenizer=tokenizer,
+    data_collator=data_collator,
+    compute_metrics=compute_metrics,
+)
+
+if training_args.do_train:
+    trainer.train()
+    trainer.save_model()
+
+if training_args.do_eval and not custom_args.skip_do_eval:
+    trainer.evaluate()
+
+
+# Preprocess the datasets again: sometimes something unexpected happens that
+# causes errors in the report step. This is possibly because the huggingface
+# transformers Trainer drops variables that are irrelevant to training, such as
+# 'old_positions'. We also pad the training dataset in this step, so that the
+# prediction loop below can run without a padding step inside it.
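+# (Added note, assumption about library behaviour: this column dropping comes from
+# TrainingArguments.remove_unused_columns, which defaults to True, so the Trainer
+# keeps only the columns accepted by model.forward().)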
+if data_args.dataset_name == 'thainer':
+    train_dataset = dataset['train']
+
+    split = train_dataset.train_test_split(
+        train_size=0.8, test_size=0.2, seed=2020)
+    train_dataset = split['train']
+    non_train_dataset = split['test']
+
+    split = non_train_dataset.train_test_split(
+        train_size=0.5, test_size=0.5, seed=2020)
+    val_dataset = split['train']
+    test_dataset = split['test']
+
+    if custom_args.filter_thainer_with_mbert_tokenizer_threshold is not None:
+        mbert_tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')
+        # mbert_tokenizer.add_special_tokens(custom_args.space_token)
+
+        def is_not_too_long(example,
+                            max_length=custom_args.filter_thainer_with_mbert_tokenizer_threshold):
+            tokens = sum([mbert_tokenizer.tokenize(
+                pre_tokenize(token, '<_>'))
+                for token in example[text_col]], [])
+            return len(tokens) < max_length
+
+        test_dataset = test_dataset.filter(is_not_too_long)
+
+    # preprocess
+    train_dataset = Dataset.from_dict(preprocess(train_dataset))
+    val_dataset = Dataset.from_dict(preprocess(val_dataset))
+    test_dataset = Dataset.from_dict(preprocess(test_dataset))
+    # pad all splits here so the report loops below do not need a padding step
+    train_dataset = Dataset.from_dict(data_collator(train_dataset))
+    val_dataset = Dataset.from_dict(data_collator(val_dataset))
+    test_dataset = Dataset.from_dict(data_collator(test_dataset))
+elif data_args.dataset_name == 'lst20':
+    train_dataset = dataset['train']
+    val_dataset = dataset['validation']
+    test_dataset = dataset['test']
+
+    train_dataset = Dataset.from_dict(preprocess(train_dataset))
+    val_dataset = Dataset.from_dict(preprocess(val_dataset))
+    test_dataset = Dataset.from_dict(preprocess(test_dataset))
+    # pad all splits here so the report loops below do not need a padding step
+    train_dataset = Dataset.from_dict(data_collator(train_dataset))
+    val_dataset = Dataset.from_dict(data_collator(val_dataset))
+    test_dataset = Dataset.from_dict(data_collator(test_dataset))
+elif data_args.dataset_name == 'dummytest':
+    train_dataset = dataset['train']
+    val_dataset = dataset['validation']
+    test_dataset = dataset['test']
+    train_dataset = Dataset.from_dict(preprocess(train_dataset))
+    val_dataset = Dataset.from_dict(preprocess(val_dataset))
+    test_dataset = Dataset.from_dict(preprocess(test_dataset))
+    val_dataset = Dataset.from_dict(data_collator(val_dataset))
+    test_dataset = Dataset.from_dict(data_collator(test_dataset))
+else:
+    raise NotImplementedError
+
+
+if not custom_args.no_train_report:
+    agg_chunk_labels, agg_chunk_preds = agg_preds_labels(model, train_dataset)
+    agg_chunk_labels = [[label_maps[e] for e in a] for a in agg_chunk_labels]
+    agg_chunk_preds = [[label_maps[e] for e in a] for a in agg_chunk_preds]
+    if 'ner' in data_args.label_name:
+        result = t2t_chunk_metrics(agg_chunk_labels, agg_chunk_preds)
+    else:
+        result = t2t_sk_classification_metrics(sum(agg_chunk_labels, []),
+                                               sum(agg_chunk_preds, []))
+    print('[ Train Result ]')
+    pprint.pprint({k: v for k, v in result.items() if k != 'classification_report'})
+    print(result['classification_report'])
+
+if not custom_args.no_eval_report:
+    agg_chunk_labels, agg_chunk_preds = agg_preds_labels(model, val_dataset)
+    agg_chunk_labels = [[label_maps[e] for e in a] for a in agg_chunk_labels]
+    agg_chunk_preds = [[label_maps[e] for e in a] for a in agg_chunk_preds]
+    if 'ner' in data_args.label_name:
+        result = t2t_chunk_metrics(agg_chunk_labels, agg_chunk_preds)
+    else:
+        result = t2t_sk_classification_metrics(sum(agg_chunk_labels, []),
+                                               sum(agg_chunk_preds, []))
+    print('[ Val Result ]')
+    pprint.pprint({k: v for k, v in result.items() if k != 'classification_report'})
+    print(result['classification_report'])
+
+
+if not custom_args.no_test_report:
+    agg_chunk_labels, agg_chunk_preds = agg_preds_labels(model, test_dataset)
+    agg_chunk_labels = [[label_maps[e] for e in a] for a in agg_chunk_labels]
+    agg_chunk_preds = [[label_maps[e] for e in a] for a in agg_chunk_preds]
+    if 'ner' in data_args.label_name:
+        result = t2t_chunk_metrics(agg_chunk_labels, agg_chunk_preds)
+    else:
+        result = t2t_sk_classification_metrics(sum(agg_chunk_labels, []),
+                                               sum(agg_chunk_preds, []))
+    print('[ Test Result ]')
+    pprint.pprint({k: v for k, v in result.items() if k != 'classification_report'})
+    print(result['classification_report'])
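+
+
+# Example invocation (for reference only; the checkpoint names and output path are
+# placeholders, and the flag names follow the dataclasses defined above plus the
+# standard huggingface TrainingArguments):
+#   python train_token_classification_lm_finetuning.py \
+#       --model_name_or_path bert-base-multilingual-cased \
+#       --tokenizer_name_or_path bert-base-multilingual-cased \
+#       --tokenizer_type AutoTokenizer \
+#       --dataset_name thainer \
+#       --label_name ner_tags \
+#       --output_dir ./ner_output \
+#       --do_train --do_eval \
+#       --per_device_train_batch_size 16 --num_train_epochs 3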